From 55ddc8b5c5882e8b8e51df0eff66cb622f0d67ff Mon Sep 17 00:00:00 2001 From: Mahesh Doijade <36705640+mdoijade@users.noreply.github.com> Date: Tue, 14 Sep 2021 18:46:59 +0530 Subject: [PATCH] Add hamming, jensen-shannon, kl-divergence, correlation and russellrao distance metrics (#4155) -- This PR depends on RAFT PR - https://github.com/rapidsai/raft/pull/306 -- Adds cpp & python interfaces for these distance metrics with pytest support for each of them. -- also remove redundant commented code in canberra distance metric Authors: - Mahesh Doijade (https://github.com/mdoijade) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/cuml/pull/4155 --- cpp/CMakeLists.txt | 5 ++ cpp/src/metrics/pairwise_distance.cu | 59 +++++++++++++---- cpp/src/metrics/pairwise_distance_canberra.cu | 30 ++------- .../metrics/pairwise_distance_canberra.cuh | 2 - .../metrics/pairwise_distance_chebyshev.cu | 20 ++---- .../metrics/pairwise_distance_chebyshev.cuh | 2 - .../metrics/pairwise_distance_correlation.cu | 65 +++++++++++++++++++ .../metrics/pairwise_distance_correlation.cuh | 47 ++++++++++++++ cpp/src/metrics/pairwise_distance_cosine.cu | 23 ++----- cpp/src/metrics/pairwise_distance_cosine.cuh | 2 - cpp/src/metrics/pairwise_distance_hamming.cu | 65 +++++++++++++++++++ cpp/src/metrics/pairwise_distance_hamming.cuh | 47 ++++++++++++++ .../metrics/pairwise_distance_hellinger.cu | 24 ++----- .../metrics/pairwise_distance_hellinger.cuh | 2 - .../pairwise_distance_jensen_shannon.cu | 63 ++++++++++++++++++ .../pairwise_distance_jensen_shannon.cuh | 47 ++++++++++++++ .../pairwise_distance_kl_divergence.cu | 63 ++++++++++++++++++ .../pairwise_distance_kl_divergence.cuh | 47 ++++++++++++++ cpp/src/metrics/pairwise_distance_l1.cu | 20 ++---- cpp/src/metrics/pairwise_distance_l1.cuh | 2 - .../metrics/pairwise_distance_minkowski.cu | 22 ++----- .../metrics/pairwise_distance_minkowski.cuh | 2 - .../metrics/pairwise_distance_russell_rao.cu | 65 +++++++++++++++++++ .../metrics/pairwise_distance_russell_rao.cuh | 47 ++++++++++++++ python/cuml/metrics/distance_type.pxd | 3 + python/cuml/metrics/pairwise_distances.pyx | 18 ++++- python/cuml/test/test_metrics.py | 27 ++++++-- 27 files changed, 675 insertions(+), 144 deletions(-) create mode 100644 cpp/src/metrics/pairwise_distance_correlation.cu create mode 100644 cpp/src/metrics/pairwise_distance_correlation.cuh create mode 100644 cpp/src/metrics/pairwise_distance_hamming.cu create mode 100644 cpp/src/metrics/pairwise_distance_hamming.cuh create mode 100644 cpp/src/metrics/pairwise_distance_jensen_shannon.cu create mode 100644 cpp/src/metrics/pairwise_distance_jensen_shannon.cuh create mode 100644 cpp/src/metrics/pairwise_distance_kl_divergence.cu create mode 100644 cpp/src/metrics/pairwise_distance_kl_divergence.cuh create mode 100644 cpp/src/metrics/pairwise_distance_russell_rao.cu create mode 100644 cpp/src/metrics/pairwise_distance_russell_rao.cuh diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 209cfd29db..ffaefc92f4 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -249,11 +249,16 @@ if(BUILD_CUML_CPP_LIBRARY) src/metrics/pairwise_distance.cu src/metrics/pairwise_distance_canberra.cu src/metrics/pairwise_distance_chebyshev.cu + src/metrics/pairwise_distance_correlation.cu src/metrics/pairwise_distance_cosine.cu src/metrics/pairwise_distance_euclidean.cu + src/metrics/pairwise_distance_hamming.cu src/metrics/pairwise_distance_hellinger.cu + src/metrics/pairwise_distance_jensen_shannon.cu + src/metrics/pairwise_distance_kl_divergence.cu src/metrics/pairwise_distance_l1.cu src/metrics/pairwise_distance_minkowski.cu + src/metrics/pairwise_distance_russell_rao.cu src/metrics/r2_score.cu src/metrics/rand_index.cu src/metrics/silhouette_score.cu diff --git a/cpp/src/metrics/pairwise_distance.cu b/cpp/src/metrics/pairwise_distance.cu index 47af2985c4..3a7f89f263 100644 --- a/cpp/src/metrics/pairwise_distance.cu +++ b/cpp/src/metrics/pairwise_distance.cu @@ -22,11 +22,16 @@ #include #include "pairwise_distance_canberra.cuh" #include "pairwise_distance_chebyshev.cuh" +#include "pairwise_distance_correlation.cuh" #include "pairwise_distance_cosine.cuh" #include "pairwise_distance_euclidean.cuh" +#include "pairwise_distance_hamming.cuh" #include "pairwise_distance_hellinger.cuh" +#include "pairwise_distance_jensen_shannon.cuh" +#include "pairwise_distance_kl_divergence.cuh" #include "pairwise_distance_l1.cuh" #include "pairwise_distance_minkowski.cuh" +#include "pairwise_distance_russell_rao.cuh" namespace ML { @@ -50,22 +55,37 @@ void pairwise_distance(const raft::handle_t& handle, pairwise_distance_euclidean(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg); break; case raft::distance::DistanceType::CosineExpanded: - pairwise_distance_cosine(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg); + pairwise_distance_cosine(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); break; case raft::distance::DistanceType::L1: - pairwise_distance_l1(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg); + pairwise_distance_l1(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); break; case raft::distance::DistanceType::Linf: - pairwise_distance_chebyshev(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg); + pairwise_distance_chebyshev(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); break; case raft::distance::DistanceType::HellingerExpanded: - pairwise_distance_hellinger(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg); + pairwise_distance_hellinger(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); break; case raft::distance::DistanceType::LpUnexpanded: - pairwise_distance_minkowski(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg); + pairwise_distance_minkowski(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); break; case raft::distance::DistanceType::Canberra: - pairwise_distance_canberra(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg); + pairwise_distance_canberra(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); + break; + case raft::distance::DistanceType::CorrelationExpanded: + pairwise_distance_correlation(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); + break; + case raft::distance::DistanceType::HammingUnexpanded: + pairwise_distance_hamming(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); + break; + case raft::distance::DistanceType::JensenShannon: + pairwise_distance_jensen_shannon(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); + break; + case raft::distance::DistanceType::KLDivergence: + pairwise_distance_kl_divergence(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); + break; + case raft::distance::DistanceType::RusselRaoExpanded: + pairwise_distance_russell_rao(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); break; default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); }; @@ -90,22 +110,37 @@ void pairwise_distance(const raft::handle_t& handle, pairwise_distance_euclidean(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg); break; case raft::distance::DistanceType::CosineExpanded: - pairwise_distance_cosine(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg); + pairwise_distance_cosine(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); break; case raft::distance::DistanceType::L1: - pairwise_distance_l1(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg); + pairwise_distance_l1(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); break; case raft::distance::DistanceType::Linf: - pairwise_distance_chebyshev(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg); + pairwise_distance_chebyshev(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); break; case raft::distance::DistanceType::HellingerExpanded: - pairwise_distance_hellinger(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg); + pairwise_distance_hellinger(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); break; case raft::distance::DistanceType::LpUnexpanded: - pairwise_distance_minkowski(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg); + pairwise_distance_minkowski(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); break; case raft::distance::DistanceType::Canberra: - pairwise_distance_canberra(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg); + pairwise_distance_canberra(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); + break; + case raft::distance::DistanceType::CorrelationExpanded: + pairwise_distance_correlation(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); + break; + case raft::distance::DistanceType::HammingUnexpanded: + pairwise_distance_hamming(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); + break; + case raft::distance::DistanceType::JensenShannon: + pairwise_distance_jensen_shannon(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); + break; + case raft::distance::DistanceType::KLDivergence: + pairwise_distance_kl_divergence(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); + break; + case raft::distance::DistanceType::RusselRaoExpanded: + pairwise_distance_russell_rao(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); break; default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); }; diff --git a/cpp/src/metrics/pairwise_distance_canberra.cu b/cpp/src/metrics/pairwise_distance_canberra.cu index fb0520c4bd..4e4ac45857 100644 --- a/cpp/src/metrics/pairwise_distance_canberra.cu +++ b/cpp/src/metrics/pairwise_distance_canberra.cu @@ -15,10 +15,10 @@ * limitations under the License. */ -//#include #include #include #include +#include "pairwise_distance_canberra.cuh" namespace ML { @@ -30,7 +30,6 @@ void pairwise_distance_canberra(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, double metric_arg) { @@ -38,17 +37,8 @@ void pairwise_distance_canberra(const raft::handle_t& handle, rmm::device_uvector workspace(1, handle.get_stream()); // Call the distance function - /* raft::distance::pairwise_distance(x, y, dist, m, n, k, workspace, metric, - handle.get_stream(), isRowMajor, - metric_arg);*/ - - switch (metric) { - case raft::distance::DistanceType::Canberra: - raft::distance::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); - break; - default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); - } + raft::distance::pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); } void pairwise_distance_canberra(const raft::handle_t& handle, @@ -58,7 +48,6 @@ void pairwise_distance_canberra(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, float metric_arg) { @@ -66,17 +55,8 @@ void pairwise_distance_canberra(const raft::handle_t& handle, rmm::device_uvector workspace(1, handle.get_stream()); // Call the distance function - /* raft::distance::pairwise_distance(x, y, dist, m, n, k, workspace, metric, - handle.get_stream(), isRowMajor, - metric_arg);*/ - - switch (metric) { - case raft::distance::DistanceType::Canberra: - raft::distance::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); - break; - default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); - } + raft::distance::pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); } } // namespace Metrics diff --git a/cpp/src/metrics/pairwise_distance_canberra.cuh b/cpp/src/metrics/pairwise_distance_canberra.cuh index 3d1454cfcc..24bba4906f 100644 --- a/cpp/src/metrics/pairwise_distance_canberra.cuh +++ b/cpp/src/metrics/pairwise_distance_canberra.cuh @@ -30,7 +30,6 @@ void pairwise_distance_canberra(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, double metric_arg); @@ -41,7 +40,6 @@ void pairwise_distance_canberra(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, float metric_arg); diff --git a/cpp/src/metrics/pairwise_distance_chebyshev.cu b/cpp/src/metrics/pairwise_distance_chebyshev.cu index d3bd683c89..917487f60c 100644 --- a/cpp/src/metrics/pairwise_distance_chebyshev.cu +++ b/cpp/src/metrics/pairwise_distance_chebyshev.cu @@ -29,20 +29,14 @@ void pairwise_distance_chebyshev(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, double metric_arg) { // Allocate workspace rmm::device_uvector workspace(1, handle.get_stream()); // Call the distance function - switch (metric) { - case raft::distance::DistanceType::Linf: - raft::distance::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); - break; - default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); - } + raft::distance::pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); } void pairwise_distance_chebyshev(const raft::handle_t& handle, @@ -52,20 +46,14 @@ void pairwise_distance_chebyshev(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, float metric_arg) { // Allocate workspace rmm::device_uvector workspace(1, handle.get_stream()); // Call the distance function - switch (metric) { - case raft::distance::DistanceType::Linf: - raft::distance::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); - break; - default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); - } + raft::distance::pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); } } // namespace Metrics diff --git a/cpp/src/metrics/pairwise_distance_chebyshev.cuh b/cpp/src/metrics/pairwise_distance_chebyshev.cuh index 6f95dbba30..d8b385808f 100644 --- a/cpp/src/metrics/pairwise_distance_chebyshev.cuh +++ b/cpp/src/metrics/pairwise_distance_chebyshev.cuh @@ -28,7 +28,6 @@ void pairwise_distance_chebyshev(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, double metric_arg); @@ -39,7 +38,6 @@ void pairwise_distance_chebyshev(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, float metric_arg); diff --git a/cpp/src/metrics/pairwise_distance_correlation.cu b/cpp/src/metrics/pairwise_distance_correlation.cu new file mode 100644 index 0000000000..5e972553e4 --- /dev/null +++ b/cpp/src/metrics/pairwise_distance_correlation.cu @@ -0,0 +1,65 @@ + +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "pairwise_distance_correlation.cuh" + +namespace ML { + +namespace Metrics { +void pairwise_distance_correlation(const raft::handle_t& handle, + const double* x, + const double* y, + double* dist, + int m, + int n, + int k, + bool isRowMajor, + double metric_arg) +{ + // Allocate workspace + rmm::device_uvector workspace(1, handle.get_stream()); + + // Call the distance function + raft::distance:: + pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); +} + +void pairwise_distance_correlation(const raft::handle_t& handle, + const float* x, + const float* y, + float* dist, + int m, + int n, + int k, + bool isRowMajor, + float metric_arg) +{ + // Allocate workspace + rmm::device_uvector workspace(1, handle.get_stream()); + + // Call the distance function + raft::distance:: + pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); +} + +} // namespace Metrics +} // namespace ML diff --git a/cpp/src/metrics/pairwise_distance_correlation.cuh b/cpp/src/metrics/pairwise_distance_correlation.cuh new file mode 100644 index 0000000000..8db0d59556 --- /dev/null +++ b/cpp/src/metrics/pairwise_distance_correlation.cuh @@ -0,0 +1,47 @@ + +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace ML { + +namespace Metrics { +void pairwise_distance_correlation(const raft::handle_t& handle, + const double* x, + const double* y, + double* dist, + int m, + int n, + int k, + bool isRowMajor, + double metric_arg); + +void pairwise_distance_correlation(const raft::handle_t& handle, + const float* x, + const float* y, + float* dist, + int m, + int n, + int k, + bool isRowMajor, + float metric_arg); + +} // namespace Metrics +} // namespace ML diff --git a/cpp/src/metrics/pairwise_distance_cosine.cu b/cpp/src/metrics/pairwise_distance_cosine.cu index 5d94fe7a26..058652d139 100644 --- a/cpp/src/metrics/pairwise_distance_cosine.cu +++ b/cpp/src/metrics/pairwise_distance_cosine.cu @@ -30,7 +30,6 @@ void pairwise_distance_cosine(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, double metric_arg) { @@ -38,14 +37,8 @@ void pairwise_distance_cosine(const raft::handle_t& handle, rmm::device_uvector workspace(1, handle.get_stream()); // Call the distance function - switch (metric) { - case raft::distance::DistanceType::CosineExpanded: - raft::distance:: - pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); - break; - default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); - } + raft::distance::pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); } void pairwise_distance_cosine(const raft::handle_t& handle, @@ -55,20 +48,14 @@ void pairwise_distance_cosine(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, float metric_arg) { // Allocate workspace rmm::device_uvector workspace(1, handle.get_stream()); - switch (metric) { - case raft::distance::DistanceType::CosineExpanded: - raft::distance:: - pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); - break; - default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); - } + + raft::distance::pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); } } // namespace Metrics diff --git a/cpp/src/metrics/pairwise_distance_cosine.cuh b/cpp/src/metrics/pairwise_distance_cosine.cuh index 04f07e7de7..58388ea4a9 100644 --- a/cpp/src/metrics/pairwise_distance_cosine.cuh +++ b/cpp/src/metrics/pairwise_distance_cosine.cuh @@ -29,7 +29,6 @@ void pairwise_distance_cosine(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, double metric_arg); @@ -40,7 +39,6 @@ void pairwise_distance_cosine(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, float metric_arg); diff --git a/cpp/src/metrics/pairwise_distance_hamming.cu b/cpp/src/metrics/pairwise_distance_hamming.cu new file mode 100644 index 0000000000..c99cda5479 --- /dev/null +++ b/cpp/src/metrics/pairwise_distance_hamming.cu @@ -0,0 +1,65 @@ + +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "pairwise_distance_hamming.cuh" + +namespace ML { + +namespace Metrics { +void pairwise_distance_hamming(const raft::handle_t& handle, + const double* x, + const double* y, + double* dist, + int m, + int n, + int k, + bool isRowMajor, + double metric_arg) +{ + // Allocate workspace + rmm::device_uvector workspace(1, handle.get_stream()); + + // Call the distance function + raft::distance:: + pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); +} + +void pairwise_distance_hamming(const raft::handle_t& handle, + const float* x, + const float* y, + float* dist, + int m, + int n, + int k, + bool isRowMajor, + float metric_arg) +{ + // Allocate workspace + rmm::device_uvector workspace(1, handle.get_stream()); + + // Call the distance function + raft::distance:: + pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); +} + +} // namespace Metrics +} // namespace ML diff --git a/cpp/src/metrics/pairwise_distance_hamming.cuh b/cpp/src/metrics/pairwise_distance_hamming.cuh new file mode 100644 index 0000000000..59b6aad019 --- /dev/null +++ b/cpp/src/metrics/pairwise_distance_hamming.cuh @@ -0,0 +1,47 @@ + +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace ML { + +namespace Metrics { +void pairwise_distance_hamming(const raft::handle_t& handle, + const double* x, + const double* y, + double* dist, + int m, + int n, + int k, + bool isRowMajor, + double metric_arg); + +void pairwise_distance_hamming(const raft::handle_t& handle, + const float* x, + const float* y, + float* dist, + int m, + int n, + int k, + bool isRowMajor, + float metric_arg); + +} // namespace Metrics +} // namespace ML diff --git a/cpp/src/metrics/pairwise_distance_hellinger.cu b/cpp/src/metrics/pairwise_distance_hellinger.cu index 44c50e57c9..f1e15830d4 100644 --- a/cpp/src/metrics/pairwise_distance_hellinger.cu +++ b/cpp/src/metrics/pairwise_distance_hellinger.cu @@ -30,21 +30,15 @@ void pairwise_distance_hellinger(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, double metric_arg) { // Allocate workspace rmm::device_uvector workspace(1, handle.get_stream()); // Call the distance function - switch (metric) { - case raft::distance::DistanceType::HellingerExpanded: - raft::distance:: - pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); - break; - default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); - } + raft::distance:: + pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); } void pairwise_distance_hellinger(const raft::handle_t& handle, @@ -54,21 +48,15 @@ void pairwise_distance_hellinger(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, float metric_arg) { // Allocate workspace rmm::device_uvector workspace(1, handle.get_stream()); // Call the distance function - switch (metric) { - case raft::distance::DistanceType::HellingerExpanded: - raft::distance:: - pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); - break; - default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); - } + raft::distance:: + pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); } } // namespace Metrics diff --git a/cpp/src/metrics/pairwise_distance_hellinger.cuh b/cpp/src/metrics/pairwise_distance_hellinger.cuh index 70521b6578..92b820a6a0 100644 --- a/cpp/src/metrics/pairwise_distance_hellinger.cuh +++ b/cpp/src/metrics/pairwise_distance_hellinger.cuh @@ -29,7 +29,6 @@ void pairwise_distance_hellinger(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, double metric_arg); @@ -40,7 +39,6 @@ void pairwise_distance_hellinger(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, float metric_arg); } // namespace Metrics diff --git a/cpp/src/metrics/pairwise_distance_jensen_shannon.cu b/cpp/src/metrics/pairwise_distance_jensen_shannon.cu new file mode 100644 index 0000000000..c78a52ffbf --- /dev/null +++ b/cpp/src/metrics/pairwise_distance_jensen_shannon.cu @@ -0,0 +1,63 @@ + +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "pairwise_distance_jensen_shannon.cuh" + +namespace ML { + +namespace Metrics { +void pairwise_distance_jensen_shannon(const raft::handle_t& handle, + const double* x, + const double* y, + double* dist, + int m, + int n, + int k, + bool isRowMajor, + double metric_arg) +{ + // Allocate workspace + rmm::device_uvector workspace(1, handle.get_stream()); + + // Call the distance function + raft::distance::pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); +} + +void pairwise_distance_jensen_shannon(const raft::handle_t& handle, + const float* x, + const float* y, + float* dist, + int m, + int n, + int k, + bool isRowMajor, + float metric_arg) +{ + // Allocate workspace + rmm::device_uvector workspace(1, handle.get_stream()); + + // Call the distance function + raft::distance::pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); +} + +} // namespace Metrics +} // namespace ML diff --git a/cpp/src/metrics/pairwise_distance_jensen_shannon.cuh b/cpp/src/metrics/pairwise_distance_jensen_shannon.cuh new file mode 100644 index 0000000000..4f6f55af35 --- /dev/null +++ b/cpp/src/metrics/pairwise_distance_jensen_shannon.cuh @@ -0,0 +1,47 @@ + +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace ML { + +namespace Metrics { +void pairwise_distance_jensen_shannon(const raft::handle_t& handle, + const double* x, + const double* y, + double* dist, + int m, + int n, + int k, + bool isRowMajor, + double metric_arg); + +void pairwise_distance_jensen_shannon(const raft::handle_t& handle, + const float* x, + const float* y, + float* dist, + int m, + int n, + int k, + bool isRowMajor, + float metric_arg); + +} // namespace Metrics +} // namespace ML diff --git a/cpp/src/metrics/pairwise_distance_kl_divergence.cu b/cpp/src/metrics/pairwise_distance_kl_divergence.cu new file mode 100644 index 0000000000..2a734145e6 --- /dev/null +++ b/cpp/src/metrics/pairwise_distance_kl_divergence.cu @@ -0,0 +1,63 @@ + +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "pairwise_distance_kl_divergence.cuh" + +namespace ML { + +namespace Metrics { +void pairwise_distance_kl_divergence(const raft::handle_t& handle, + const double* x, + const double* y, + double* dist, + int m, + int n, + int k, + bool isRowMajor, + double metric_arg) +{ + // Allocate workspace + rmm::device_uvector workspace(1, handle.get_stream()); + + // Call the distance function + raft::distance::pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); +} + +void pairwise_distance_kl_divergence(const raft::handle_t& handle, + const float* x, + const float* y, + float* dist, + int m, + int n, + int k, + bool isRowMajor, + float metric_arg) +{ + // Allocate workspace + rmm::device_uvector workspace(1, handle.get_stream()); + + // Call the distance function + raft::distance::pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); +} + +} // namespace Metrics +} // namespace ML diff --git a/cpp/src/metrics/pairwise_distance_kl_divergence.cuh b/cpp/src/metrics/pairwise_distance_kl_divergence.cuh new file mode 100644 index 0000000000..80125c710b --- /dev/null +++ b/cpp/src/metrics/pairwise_distance_kl_divergence.cuh @@ -0,0 +1,47 @@ + +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace ML { + +namespace Metrics { +void pairwise_distance_kl_divergence(const raft::handle_t& handle, + const double* x, + const double* y, + double* dist, + int m, + int n, + int k, + bool isRowMajor, + double metric_arg); + +void pairwise_distance_kl_divergence(const raft::handle_t& handle, + const float* x, + const float* y, + float* dist, + int m, + int n, + int k, + bool isRowMajor, + float metric_arg); + +} // namespace Metrics +} // namespace ML diff --git a/cpp/src/metrics/pairwise_distance_l1.cu b/cpp/src/metrics/pairwise_distance_l1.cu index 1863f582af..80711626a3 100644 --- a/cpp/src/metrics/pairwise_distance_l1.cu +++ b/cpp/src/metrics/pairwise_distance_l1.cu @@ -30,20 +30,14 @@ void pairwise_distance_l1(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, double metric_arg) { // Allocate workspace rmm::device_uvector workspace(1, handle.get_stream()); // Call the distance function - switch (metric) { - case raft::distance::DistanceType::L1: - raft::distance::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); - break; - default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); - } + raft::distance::pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); } void pairwise_distance_l1(const raft::handle_t& handle, @@ -53,20 +47,14 @@ void pairwise_distance_l1(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, float metric_arg) { // Allocate workspace rmm::device_uvector workspace(1, handle.get_stream()); // Call the distance function - switch (metric) { - case raft::distance::DistanceType::L1: - raft::distance::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); - break; - default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); - } + raft::distance::pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); } } // namespace Metrics diff --git a/cpp/src/metrics/pairwise_distance_l1.cuh b/cpp/src/metrics/pairwise_distance_l1.cuh index f451df5cc8..f93de2bb2d 100644 --- a/cpp/src/metrics/pairwise_distance_l1.cuh +++ b/cpp/src/metrics/pairwise_distance_l1.cuh @@ -28,7 +28,6 @@ void pairwise_distance_l1(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, double metric_arg); @@ -39,7 +38,6 @@ void pairwise_distance_l1(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, float metric_arg); diff --git a/cpp/src/metrics/pairwise_distance_minkowski.cu b/cpp/src/metrics/pairwise_distance_minkowski.cu index 6772edeff2..b7d6b09f4e 100644 --- a/cpp/src/metrics/pairwise_distance_minkowski.cu +++ b/cpp/src/metrics/pairwise_distance_minkowski.cu @@ -30,21 +30,14 @@ void pairwise_distance_minkowski(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, double metric_arg) { // Allocate workspace rmm::device_uvector workspace(1, handle.get_stream()); // Call the distance function - switch (metric) { - case raft::distance::DistanceType::LpUnexpanded: - raft::distance:: - pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor, metric_arg); - break; - default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); - } + raft::distance::pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor, metric_arg); } void pairwise_distance_minkowski(const raft::handle_t& handle, @@ -54,21 +47,14 @@ void pairwise_distance_minkowski(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, float metric_arg) { // Allocate workspace rmm::device_uvector workspace(1, handle.get_stream()); // Call the distance function - switch (metric) { - case raft::distance::DistanceType::LpUnexpanded: - raft::distance:: - pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor, metric_arg); - break; - default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); - } + raft::distance::pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor, metric_arg); } } // namespace Metrics diff --git a/cpp/src/metrics/pairwise_distance_minkowski.cuh b/cpp/src/metrics/pairwise_distance_minkowski.cuh index 013205e67b..dd0ff59b25 100644 --- a/cpp/src/metrics/pairwise_distance_minkowski.cuh +++ b/cpp/src/metrics/pairwise_distance_minkowski.cuh @@ -29,7 +29,6 @@ void pairwise_distance_minkowski(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, double metric_arg); @@ -40,7 +39,6 @@ void pairwise_distance_minkowski(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, float metric_arg); diff --git a/cpp/src/metrics/pairwise_distance_russell_rao.cu b/cpp/src/metrics/pairwise_distance_russell_rao.cu new file mode 100644 index 0000000000..3b73a89c01 --- /dev/null +++ b/cpp/src/metrics/pairwise_distance_russell_rao.cu @@ -0,0 +1,65 @@ + +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "pairwise_distance_russell_rao.cuh" + +namespace ML { + +namespace Metrics { +void pairwise_distance_russell_rao(const raft::handle_t& handle, + const double* x, + const double* y, + double* dist, + int m, + int n, + int k, + bool isRowMajor, + double metric_arg) +{ + // Allocate workspace + rmm::device_uvector workspace(1, handle.get_stream()); + + // Call the distance function + raft::distance:: + pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); +} + +void pairwise_distance_russell_rao(const raft::handle_t& handle, + const float* x, + const float* y, + float* dist, + int m, + int n, + int k, + bool isRowMajor, + float metric_arg) +{ + // Allocate workspace + rmm::device_uvector workspace(1, handle.get_stream()); + + // Call the distance function + raft::distance:: + pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); +} + +} // namespace Metrics +} // namespace ML diff --git a/cpp/src/metrics/pairwise_distance_russell_rao.cuh b/cpp/src/metrics/pairwise_distance_russell_rao.cuh new file mode 100644 index 0000000000..1d25194f42 --- /dev/null +++ b/cpp/src/metrics/pairwise_distance_russell_rao.cuh @@ -0,0 +1,47 @@ + +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace ML { + +namespace Metrics { +void pairwise_distance_russell_rao(const raft::handle_t& handle, + const double* x, + const double* y, + double* dist, + int m, + int n, + int k, + bool isRowMajor, + double metric_arg); + +void pairwise_distance_russell_rao(const raft::handle_t& handle, + const float* x, + const float* y, + float* dist, + int m, + int n, + int k, + bool isRowMajor, + float metric_arg); + +} // namespace Metrics +} // namespace ML diff --git a/python/cuml/metrics/distance_type.pxd b/python/cuml/metrics/distance_type.pxd index 93cf1ad9e9..4286ea1c9d 100644 --- a/python/cuml/metrics/distance_type.pxd +++ b/python/cuml/metrics/distance_type.pxd @@ -33,5 +33,8 @@ cdef extern from "raft/linalg/distance_type.h" namespace "raft::distance": Haversine "raft::distance::DistanceType::Haversine" BrayCurtis "raft::distance::DistanceType::BrayCurtis" JensenShannon "raft::distance::DistanceType::JensenShannon" + HammingUnexpanded "raft::distance::DistanceType::HammingUnexpanded" + KLDivergence "raft::distance::DistanceType::KLDivergence" + RusselRaoExpanded "raft::distance::DistanceType::RusselRaoExpanded" DiceExpanded "raft::distance::DistanceType::DiceExpanded" Precomputed "raft::distance::DistanceType::Precomputed" diff --git a/python/cuml/metrics/pairwise_distances.pyx b/python/cuml/metrics/pairwise_distances.pyx index d43fcb4329..29f9d492ac 100644 --- a/python/cuml/metrics/pairwise_distances.pyx +++ b/python/cuml/metrics/pairwise_distances.pyx @@ -73,7 +73,12 @@ PAIRWISE_DISTANCE_METRICS = { "canberra": DistanceType.Canberra, "chebyshev": DistanceType.Linf, "minkowski": DistanceType.LpUnexpanded, - "hellinger": DistanceType.HellingerExpanded + "hellinger": DistanceType.HellingerExpanded, + "correlation": DistanceType.CorrelationExpanded, + "jensenshannon": DistanceType.JensenShannon, + "hamming": DistanceType.HammingUnexpanded, + "kldivergence": DistanceType.KLDivergence, + "russellrao": DistanceType.RusselRaoExpanded } PAIRWISE_DISTANCE_SPARSE_METRICS = { @@ -217,6 +222,11 @@ def pairwise_distances(X, Y=None, metric="euclidean", handle=None, handle = Handle() if handle is None else handle cdef handle_t *handle_ = handle.getHandle() + if metric in ['russellrao'] and not np.all(X.data == 1.): + warnings.warn("X was converted to boolean for metric {}" + .format(metric)) + X = np.where(X != 0., 1.0, 0.0) + # Get the input arrays, preserve order and type where possible X_m, n_samples_x, n_features_x, dtype_x = \ input_to_cuml_array(X, order="K", check_dtype=[np.float32, np.float64]) @@ -235,12 +245,16 @@ def pairwise_distances(X, Y=None, metric="euclidean", handle=None, if (n_samples_x == 1 or n_features_x == 1): input_order = "K" + if metric in ['russellrao'] and not np.all(Y.data == 1.): + warnings.warn("Y was converted to boolean for metric {}" + .format(metric)) + Y = np.where(Y != 0., 1.0, 0.0) + Y_m, n_samples_y, n_features_y, dtype_y = \ input_to_cuml_array(Y, order=input_order, convert_to_dtype=(dtype_x if convert_dtype else None), check_dtype=[dtype_x]) - # Get the order from Y if necessary (It's possible to set order="F" in # input_to_cuml_array and have Y_m.order=="C") if (input_order == "K"): diff --git a/python/cuml/test/test_metrics.py b/python/cuml/test/test_metrics.py index bb18ecddd4..028b2897c8 100644 --- a/python/cuml/test/test_metrics.py +++ b/python/cuml/test/test_metrics.py @@ -77,6 +77,8 @@ from cuml.metrics import pairwise_distances, sparse_pairwise_distances, \ PAIRWISE_DISTANCE_METRICS, PAIRWISE_DISTANCE_SPARSE_METRICS from sklearn.metrics import pairwise_distances as sklearn_pairwise_distances +from scipy.spatial import distance as scipy_pairwise_distances +from scipy.special import rel_entr as scipy_kl_divergence @pytest.fixture(scope='module') @@ -862,19 +864,29 @@ def test_log_loss_at_limits(): log_loss(y_true, y_pred) -def ref_dense_pairwise_dist(X, Y=None, metric=None): +def naive_kl_divergence_dist(X, Y): + return 0.5 * np.array([[np.sum(np.where(yj != 0, + scipy_kl_divergence(xi, yj), 0.0)) for yj in Y] + for xi in X]) + + +def ref_dense_pairwise_dist(X, Y=None, metric=None, convert_dtype=False): # Select sklearn except for Hellinger that # sklearn doesn't support if Y is None: Y = X if metric == "hellinger": return naive_hellinger(X, Y) + elif metric == "jensenshannon": + return scipy_pairwise_distances.cdist(X, Y, 'jensenshannon') + elif metric == "kldivergence": + return naive_kl_divergence_dist(X, Y) else: return sklearn_pairwise_distances(X, Y, metric) def prep_dense_array(array, metric, col_major=0): - if metric == "hellinger": + if metric in ['hellinger', 'jensenshannon', 'kldivergence']: norm_array = preprocessing.normalize(array, norm="l1") return np.asfortranarray(norm_array) if col_major else norm_array else: @@ -935,11 +947,12 @@ def test_pairwise_distances(metric: str, matrix_size, is_col_major): cp.testing.assert_array_almost_equal(S, S2, decimal=compare_precision) # Test sending an int type with convert_dtype=True - Y = prep_dense_array(rng.randint(10, size=Y.shape), - metric=metric, col_major=is_col_major) - S = pairwise_distances(X, Y, metric=metric, convert_dtype=True) - S2 = ref_dense_pairwise_dist(X, Y, metric=metric) - cp.testing.assert_array_almost_equal(S, S2, decimal=compare_precision) + if metric != 'kldivergence': + Y = prep_dense_array(rng.randint(10, size=Y.shape), + metric=metric, col_major=is_col_major) + S = pairwise_distances(X, Y, metric=metric, convert_dtype=True) + S2 = ref_dense_pairwise_dist(X, Y, metric=metric, convert_dtype=True) + cp.testing.assert_array_almost_equal(S, S2, decimal=compare_precision) # Test that uppercase on the metric name throws an error. with pytest.raises(ValueError):