Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add hamming, jensen-shannon, kl-divergence, correlation and russellrao distance metrics #4155

Merged
merged 23 commits into from
Sep 14, 2021
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
dfc28a4
Add interfaces in cuml for hamming, correlation, jensen-shannon, kl-d…
mdoijade Aug 6, 2021
ad67053
fix clang formatting issues
mdoijade Aug 6, 2021
38729dc
add all new distances to main API and fix function name in correlatio…
mdoijade Aug 6, 2021
0a734fd
add python interfaces for all new dist metrics, with tests for all wo…
mdoijade Aug 10, 2021
aa072ac
add test support for kl-divergence dist metric
mdoijade Aug 11, 2021
38e1067
Merge branch 'branch-21.10' into additionalDistPrims
mdoijade Aug 11, 2021
6dd14bf
pin mdoijade raft fork for testing change
mdoijade Aug 11, 2021
911eef8
fix flake formating issues in test_metrics
mdoijade Aug 11, 2021
ae9f7a2
temp commit to trigger ci
mdoijade Aug 12, 2021
5fc77f7
temp commit to trigger ci
mdoijade Aug 12, 2021
90c8f16
temp commit to trigger ci to check updated raft changes
mdoijade Aug 12, 2021
1d5c966
temp commit to trigger ci to check updated raft changes
mdoijade Aug 12, 2021
8225ee0
Merge branch 'branch-21.10' into additionalDistPrims
mdoijade Aug 13, 2021
3f6de71
temp commit to test new raft commits
mdoijade Aug 23, 2021
19d44db
Merge branch 'branch-21.10' into additionalDistPrims
mdoijade Aug 23, 2021
b71866a
revert raft mdoijade fork as raft PR is merged now
mdoijade Aug 26, 2021
3ef11bc
Merge branch 'branch-21.10' into additionalDistPrims
mdoijade Aug 26, 2021
1a1dd71
remove redundant metric arg and switch based on it on the APIs which …
mdoijade Aug 30, 2021
96c0240
merge branch-21.10
mdoijade Aug 31, 2021
675b661
Add udevice_vector changes to new distances
mdoijade Aug 31, 2021
ac4f067
fix clang format issues
mdoijade Aug 31, 2021
8e2f01d
Merge branch 'branch-21.10' into additionalDistPrims
mdoijade Sep 1, 2021
6427584
Merge branch 'branch-21.10' into additionalDistPrims
mdoijade Sep 7, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -250,11 +250,16 @@ if(BUILD_CUML_CPP_LIBRARY)
src/metrics/pairwise_distance.cu
src/metrics/pairwise_distance_canberra.cu
src/metrics/pairwise_distance_chebyshev.cu
src/metrics/pairwise_distance_correlation.cu
src/metrics/pairwise_distance_cosine.cu
src/metrics/pairwise_distance_euclidean.cu
src/metrics/pairwise_distance_hamming.cu
src/metrics/pairwise_distance_hellinger.cu
src/metrics/pairwise_distance_jensen_shannon.cu
src/metrics/pairwise_distance_kl_divergence.cu
src/metrics/pairwise_distance_l1.cu
src/metrics/pairwise_distance_minkowski.cu
src/metrics/pairwise_distance_russell_rao.cu
src/metrics/r2_score.cu
src/metrics/rand_index.cu
src/metrics/silhouette_score.cu
Expand Down
4 changes: 2 additions & 2 deletions cpp/cmake/thirdparty/get_raft.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ function(find_and_configure_raft)
BUILD_EXPORT_SET cuml-exports
INSTALL_EXPORT_SET cuml-exports
CPM_ARGS
GIT_REPOSITORY https://github.com/${PKG_FORK}/raft.git
GIT_TAG ${PKG_PINNED_TAG}
GIT_REPOSITORY https://github.com/mdoijade/raft.git
GIT_TAG additionalDistPrims
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reminder to update this now that the raft side has been merged

SOURCE_SUBDIR cpp
OPTIONS
"BUILD_TESTS OFF"
Expand Down
35 changes: 35 additions & 0 deletions cpp/src/metrics/pairwise_distance.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,16 @@
#include <raft/sparse/distance/distance.cuh>
#include "pairwise_distance_canberra.cuh"
#include "pairwise_distance_chebyshev.cuh"
#include "pairwise_distance_correlation.cuh"
#include "pairwise_distance_cosine.cuh"
#include "pairwise_distance_euclidean.cuh"
#include "pairwise_distance_hamming.cuh"
#include "pairwise_distance_hellinger.cuh"
#include "pairwise_distance_jensen_shannon.cuh"
#include "pairwise_distance_kl_divergence.cuh"
#include "pairwise_distance_l1.cuh"
#include "pairwise_distance_minkowski.cuh"
#include "pairwise_distance_russell_rao.cuh"

namespace ML {

Expand Down Expand Up @@ -67,6 +72,21 @@ void pairwise_distance(const raft::handle_t& handle,
case raft::distance::DistanceType::Canberra:
pairwise_distance_canberra(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::CorrelationExpanded:
pairwise_distance_correlation(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::HammingUnexpanded:
pairwise_distance_hamming(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::JensenShannon:
pairwise_distance_jensen_shannon(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::KLDivergence:
pairwise_distance_kl_divergence(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::RusselRaoExpanded:
pairwise_distance_russell_rao(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg);
break;
default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric);
};
}
Expand Down Expand Up @@ -107,6 +127,21 @@ void pairwise_distance(const raft::handle_t& handle,
case raft::distance::DistanceType::Canberra:
pairwise_distance_canberra(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::CorrelationExpanded:
pairwise_distance_correlation(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::HammingUnexpanded:
pairwise_distance_hamming(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::JensenShannon:
pairwise_distance_jensen_shannon(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::KLDivergence:
pairwise_distance_kl_divergence(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::RusselRaoExpanded:
pairwise_distance_russell_rao(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg);
break;
default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric);
};
}
Expand Down
10 changes: 1 addition & 9 deletions cpp/src/metrics/pairwise_distance_canberra.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
* limitations under the License.
*/

//#include <cuml/metrics/metrics.hpp>
#include <raft/distance/distance.cuh>
#include <raft/handle.hpp>
#include "pairwise_distance_canberra.cuh"

namespace ML {

Expand All @@ -37,10 +37,6 @@ void pairwise_distance_canberra(const raft::handle_t& handle,
raft::mr::device::buffer<char> workspace(handle.get_device_allocator(), handle.get_stream(), 1);

// Call the distance function
/* raft::distance::pairwise_distance(x, y, dist, m, n, k, workspace, metric,
handle.get_stream(), isRowMajor,
metric_arg);*/

switch (metric) {
case raft::distance::DistanceType::Canberra:
raft::distance::pairwise_distance_impl<double, int, raft::distance::DistanceType::Canberra>(
Expand All @@ -65,10 +61,6 @@ void pairwise_distance_canberra(const raft::handle_t& handle,
raft::mr::device::buffer<char> workspace(handle.get_device_allocator(), handle.get_stream(), 1);

// Call the distance function
/* raft::distance::pairwise_distance(x, y, dist, m, n, k, workspace, metric,
handle.get_stream(), isRowMajor,
metric_arg);*/

switch (metric) {
case raft::distance::DistanceType::Canberra:
raft::distance::pairwise_distance_impl<float, int, raft::distance::DistanceType::Canberra>(
Expand Down
76 changes: 76 additions & 0 deletions cpp/src/metrics/pairwise_distance_correlation.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@

/*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <raft/distance/distance.cuh>
#include <raft/handle.hpp>
#include "pairwise_distance_correlation.cuh"

namespace ML {

namespace Metrics {
/**
 * @brief Double-precision correlation pairwise distance.
 *
 * Thin wrapper that forwards to
 * raft::distance::pairwise_distance_impl instantiated for
 * DistanceType::CorrelationExpanded, using the stream and device allocator
 * obtained from @p handle.
 *
 * @param handle     supplies the device allocator and stream used below
 * @param x          first input, forwarded to RAFT unchanged
 * @param y          second input, forwarded to RAFT unchanged
 * @param dist       output pointer, forwarded to RAFT unchanged
 * @param m          forwarded to RAFT (presumably row count of x - confirm)
 * @param n          forwarded to RAFT (presumably row count of y - confirm)
 * @param k          forwarded to RAFT (presumably shared column count - confirm)
 * @param metric     must be DistanceType::CorrelationExpanded; anything else
 *                   is reported through THROW
 * @param isRowMajor forwarded to RAFT unchanged
 * @param metric_arg not read by this function
 */
void pairwise_distance_correlation(const raft::handle_t& handle,
                                   const double* x,
                                   const double* y,
                                   double* dist,
                                   int m,
                                   int n,
                                   int k,
                                   raft::distance::DistanceType metric,
                                   bool isRowMajor,
                                   double metric_arg)
{
  using raft::distance::DistanceType;

  // Scratch buffer handed to RAFT; created before the metric check so the
  // sequence of operations matches the switch-based formulation.
  raft::mr::device::buffer<char> scratch(handle.get_device_allocator(), handle.get_stream(), 1);

  // Only one metric is served by this entry point.
  if (metric != DistanceType::CorrelationExpanded) {
    THROW("Unknown or unsupported distance metric '%d'!", (int)metric);
  }

  raft::distance::pairwise_distance_impl<double, int, DistanceType::CorrelationExpanded>(
    x, y, dist, m, n, k, scratch, handle.get_stream(), isRowMajor);
}

void pairwise_distance_correlation(const raft::handle_t& handle,
const float* x,
const float* y,
float* dist,
int m,
int n,
int k,
raft::distance::DistanceType metric,
bool isRowMajor,
float metric_arg)
{
// Allocate workspace
raft::mr::device::buffer<char> workspace(handle.get_device_allocator(), handle.get_stream(), 1);

// Call the distance function
switch (metric) {
case raft::distance::DistanceType::CorrelationExpanded:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the switch state propagated into this in case there is also an unexpanded version in the future? (Asking for the other distances as well)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, there is no unexpanded version planned for this or the other distance metrics. Should I replace the switch() with an if condition?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These functions are named after the associated distance measure so I would call a function named pairwise_distance_correlation expecting it to compute the correlation distance. If it's not possible that this will ever compute a distance other than correlation, why are we passing the DistanceType into these functions at all? Rather than using an if condition, I would propose removing the metric argument altogether and just having the function compute the distance that its name implies. What do you think?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point! I think the metric arg got propagated to all distances due to the Euclidean distance having multiple sub-implementations.
I've now removed this redundant check and arg for all non-Euclidean metrics, which don't have multiple implementations.

raft::distance::
pairwise_distance_impl<float, int, raft::distance::DistanceType::CorrelationExpanded>(
x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor);
break;
default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric);
}
}

} // namespace Metrics
} // namespace ML
49 changes: 49 additions & 0 deletions cpp/src/metrics/pairwise_distance_correlation.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@

/*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/distance/distance.cuh>
#include <raft/handle.hpp>

namespace ML {

namespace Metrics {
/**
 * @brief Correlation pairwise distance, double-precision overload.
 *
 * The definition in pairwise_distance_correlation.cu forwards to
 * raft::distance::pairwise_distance_impl instantiated for
 * DistanceType::CorrelationExpanded.
 *
 * @param handle     its device allocator and stream are used by the definition
 * @param x          first input pointer, forwarded to RAFT
 * @param y          second input pointer, forwarded to RAFT
 * @param dist       output pointer, forwarded to RAFT
 * @param m          forwarded to RAFT (presumably rows of x - TODO confirm)
 * @param n          forwarded to RAFT (presumably rows of y - TODO confirm)
 * @param k          forwarded to RAFT (presumably columns of x and y - TODO confirm)
 * @param metric     the definition accepts only DistanceType::CorrelationExpanded
 *                   and calls THROW for any other value
 * @param isRowMajor forwarded to RAFT
 * @param metric_arg not read by the definition
 */
void pairwise_distance_correlation(const raft::handle_t& handle,
const double* x,
const double* y,
double* dist,
int m,
int n,
int k,
raft::distance::DistanceType metric,
bool isRowMajor,
double metric_arg);

/**
 * @brief Correlation pairwise distance, single-precision overload.
 *
 * The definition in pairwise_distance_correlation.cu forwards to
 * raft::distance::pairwise_distance_impl instantiated for
 * DistanceType::CorrelationExpanded.
 *
 * @param handle     its device allocator and stream are used by the definition
 * @param x          first input pointer, forwarded to RAFT
 * @param y          second input pointer, forwarded to RAFT
 * @param dist       output pointer, forwarded to RAFT
 * @param m          forwarded to RAFT (presumably rows of x - TODO confirm)
 * @param n          forwarded to RAFT (presumably rows of y - TODO confirm)
 * @param k          forwarded to RAFT (presumably columns of x and y - TODO confirm)
 * @param metric     the definition accepts only DistanceType::CorrelationExpanded
 *                   and calls THROW for any other value
 * @param isRowMajor forwarded to RAFT
 * @param metric_arg not read by the definition
 */
void pairwise_distance_correlation(const raft::handle_t& handle,
const float* x,
const float* y,
float* dist,
int m,
int n,
int k,
raft::distance::DistanceType metric,
bool isRowMajor,
float metric_arg);

} // namespace Metrics
} // namespace ML
76 changes: 76 additions & 0 deletions cpp/src/metrics/pairwise_distance_hamming.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@

/*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <raft/distance/distance.cuh>
#include <raft/handle.hpp>
#include "pairwise_distance_hamming.cuh"

namespace ML {

namespace Metrics {
/**
 * @brief Double-precision Hamming pairwise distance.
 *
 * Thin wrapper that forwards to
 * raft::distance::pairwise_distance_impl instantiated for
 * DistanceType::HammingUnexpanded, using the stream and device allocator
 * obtained from @p handle.
 *
 * @param handle     supplies the device allocator and stream used below
 * @param x          first input, forwarded to RAFT unchanged
 * @param y          second input, forwarded to RAFT unchanged
 * @param dist       output pointer, forwarded to RAFT unchanged
 * @param m          forwarded to RAFT (presumably row count of x - confirm)
 * @param n          forwarded to RAFT (presumably row count of y - confirm)
 * @param k          forwarded to RAFT (presumably shared column count - confirm)
 * @param metric     must be DistanceType::HammingUnexpanded; anything else is
 *                   reported through THROW
 * @param isRowMajor forwarded to RAFT unchanged
 * @param metric_arg not read by this function
 */
void pairwise_distance_hamming(const raft::handle_t& handle,
                               const double* x,
                               const double* y,
                               double* dist,
                               int m,
                               int n,
                               int k,
                               raft::distance::DistanceType metric,
                               bool isRowMajor,
                               double metric_arg)
{
  using raft::distance::DistanceType;

  // Scratch buffer handed to RAFT, constructed with a size argument of 0.
  // It is created before the metric check so the sequence of operations
  // matches the switch-based formulation.
  raft::mr::device::buffer<char> scratch(handle.get_device_allocator(), handle.get_stream(), 0);

  // Only one metric is served by this entry point.
  if (metric != DistanceType::HammingUnexpanded) {
    THROW("Unknown or unsupported distance metric '%d'!", (int)metric);
  }

  raft::distance::pairwise_distance_impl<double, int, DistanceType::HammingUnexpanded>(
    x, y, dist, m, n, k, scratch, handle.get_stream(), isRowMajor);
}

/**
 * @brief Single-precision Hamming pairwise distance.
 *
 * Thin wrapper that forwards to
 * raft::distance::pairwise_distance_impl instantiated for
 * DistanceType::HammingUnexpanded, using the stream and device allocator
 * obtained from @p handle.
 *
 * @param handle     supplies the device allocator and stream used below
 * @param x          first input, forwarded to RAFT unchanged
 * @param y          second input, forwarded to RAFT unchanged
 * @param dist       output pointer, forwarded to RAFT unchanged
 * @param m          forwarded to RAFT (presumably row count of x - confirm)
 * @param n          forwarded to RAFT (presumably row count of y - confirm)
 * @param k          forwarded to RAFT (presumably shared column count - confirm)
 * @param metric     must be DistanceType::HammingUnexpanded; anything else is
 *                   reported through THROW
 * @param isRowMajor forwarded to RAFT unchanged
 * @param metric_arg not read by this function
 */
void pairwise_distance_hamming(const raft::handle_t& handle,
                               const float* x,
                               const float* y,
                               float* dist,
                               int m,
                               int n,
                               int k,
                               raft::distance::DistanceType metric,
                               bool isRowMajor,
                               float metric_arg)
{
  using raft::distance::DistanceType;

  // Scratch buffer handed to RAFT, constructed with a size argument of 0.
  // It is created before the metric check so the sequence of operations
  // matches the switch-based formulation.
  raft::mr::device::buffer<char> scratch(handle.get_device_allocator(), handle.get_stream(), 0);

  // Only one metric is served by this entry point.
  if (metric != DistanceType::HammingUnexpanded) {
    THROW("Unknown or unsupported distance metric '%d'!", (int)metric);
  }

  raft::distance::pairwise_distance_impl<float, int, DistanceType::HammingUnexpanded>(
    x, y, dist, m, n, k, scratch, handle.get_stream(), isRowMajor);
}

} // namespace Metrics
} // namespace ML
49 changes: 49 additions & 0 deletions cpp/src/metrics/pairwise_distance_hamming.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@

/*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/distance/distance.cuh>
#include <raft/handle.hpp>

namespace ML {

namespace Metrics {
/**
 * @brief Hamming pairwise distance, double-precision overload.
 *
 * The definition in pairwise_distance_hamming.cu forwards to
 * raft::distance::pairwise_distance_impl instantiated for
 * DistanceType::HammingUnexpanded.
 *
 * @param handle     its device allocator and stream are used by the definition
 * @param x          first input pointer, forwarded to RAFT
 * @param y          second input pointer, forwarded to RAFT
 * @param dist       output pointer, forwarded to RAFT
 * @param m          forwarded to RAFT (presumably rows of x - TODO confirm)
 * @param n          forwarded to RAFT (presumably rows of y - TODO confirm)
 * @param k          forwarded to RAFT (presumably columns of x and y - TODO confirm)
 * @param metric     the definition accepts only DistanceType::HammingUnexpanded
 *                   and calls THROW for any other value
 * @param isRowMajor forwarded to RAFT
 * @param metric_arg not read by the definition
 */
void pairwise_distance_hamming(const raft::handle_t& handle,
const double* x,
const double* y,
double* dist,
int m,
int n,
int k,
raft::distance::DistanceType metric,
bool isRowMajor,
double metric_arg);

/**
 * @brief Hamming pairwise distance, single-precision overload.
 *
 * The definition in pairwise_distance_hamming.cu forwards to
 * raft::distance::pairwise_distance_impl instantiated for
 * DistanceType::HammingUnexpanded.
 *
 * @param handle     its device allocator and stream are used by the definition
 * @param x          first input pointer, forwarded to RAFT
 * @param y          second input pointer, forwarded to RAFT
 * @param dist       output pointer, forwarded to RAFT
 * @param m          forwarded to RAFT (presumably rows of x - TODO confirm)
 * @param n          forwarded to RAFT (presumably rows of y - TODO confirm)
 * @param k          forwarded to RAFT (presumably columns of x and y - TODO confirm)
 * @param metric     the definition accepts only DistanceType::HammingUnexpanded
 *                   and calls THROW for any other value
 * @param isRowMajor forwarded to RAFT
 * @param metric_arg not read by the definition
 */
void pairwise_distance_hamming(const raft::handle_t& handle,
const float* x,
const float* y,
float* dist,
int m,
int n,
int k,
raft::distance::DistanceType metric,
bool isRowMajor,
float metric_arg);

} // namespace Metrics
} // namespace ML
Loading