From a52672eef3f7af7aa6b9428b05e709248e757e41 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade <36705640+mdoijade@users.noreply.github.com> Date: Thu, 8 Jul 2021 23:44:16 +0530 Subject: [PATCH] Use chebyshev, canberra, hellinger and minkowski distance metrics (#3990) This PR relies on RAFT PR https://github.com/rapidsai/raft/pull/276 which adds these distance metrics support. Authors: - Mahesh Doijade (https://github.com/mdoijade) Approvers: - Dante Gama Dessavre (https://github.com/dantegd) - Corey J. Nolet (https://github.com/cjnolet) - AJ Schmidt (https://github.com/ajschmidt8) URL: https://github.com/rapidsai/cuml/pull/3990 --- ci/checks/style.sh | 1 + cpp/CMakeLists.txt | 8 ++ cpp/include/cuml/metrics/metrics.hpp | 6 +- cpp/src/hierarchy/pw_dist_graph.cuh | 12 +-- cpp/src/kmeans/common.cuh | 9 +- cpp/src/metrics/pairwise_distance.cu | 97 ++++++++++++++++--- cpp/src/metrics/pairwise_distance_canberra.cu | 74 ++++++++++++++ .../metrics/pairwise_distance_canberra.cuh | 37 +++++++ .../metrics/pairwise_distance_chebyshev.cu | 63 ++++++++++++ .../metrics/pairwise_distance_chebyshev.cuh | 35 +++++++ cpp/src/metrics/pairwise_distance_cosine.cu | 64 ++++++++++++ cpp/src/metrics/pairwise_distance_cosine.cuh | 36 +++++++ .../metrics/pairwise_distance_euclidean.cu | 96 ++++++++++++++++++ .../metrics/pairwise_distance_euclidean.cuh | 34 +++++++ .../metrics/pairwise_distance_hellinger.cu | 64 ++++++++++++ .../metrics/pairwise_distance_hellinger.cuh | 35 +++++++ cpp/src/metrics/pairwise_distance_l1.cu | 64 ++++++++++++ cpp/src/metrics/pairwise_distance_l1.cuh | 35 +++++++ .../metrics/pairwise_distance_minkowski.cu | 66 +++++++++++++ .../metrics/pairwise_distance_minkowski.cuh | 36 +++++++ cpp/src/metrics/silhouette_score.cu | 4 +- .../metrics/batched/silhouette_score.cuh | 14 +-- cpp/src_prims/metrics/silhouette_score.cuh | 11 ++- .../metrics/trustworthiness_score.cuh | 10 +- cpp/test/prims/silhouette_score.cu | 11 +-- cpp/test/sg/dbscan_test.cu | 10 +- cpp/test/sg/rproj_test.cu | 18 ++-- python/cuml/metrics/pairwise_distances.pyx | 20 ++-- python/cuml/test/test_metrics.py | 89 +++++++++++------ 29 files changed, 944 insertions(+), 115 deletions(-) create mode 100644 cpp/src/metrics/pairwise_distance_canberra.cu create mode 100644 cpp/src/metrics/pairwise_distance_canberra.cuh create mode 100644 cpp/src/metrics/pairwise_distance_chebyshev.cu create mode 100644 cpp/src/metrics/pairwise_distance_chebyshev.cuh create mode 100644 cpp/src/metrics/pairwise_distance_cosine.cu create mode 100644 cpp/src/metrics/pairwise_distance_cosine.cuh create mode 100644 cpp/src/metrics/pairwise_distance_euclidean.cu create mode 100644 cpp/src/metrics/pairwise_distance_euclidean.cuh create mode 100644 cpp/src/metrics/pairwise_distance_hellinger.cu create mode 100644 cpp/src/metrics/pairwise_distance_hellinger.cuh create mode 100644 cpp/src/metrics/pairwise_distance_l1.cu create mode 100644 cpp/src/metrics/pairwise_distance_l1.cuh create mode 100644 cpp/src/metrics/pairwise_distance_minkowski.cu create mode 100644 cpp/src/metrics/pairwise_distance_minkowski.cuh diff --git a/ci/checks/style.sh b/ci/checks/style.sh index fbe78632d1..9dd4e58c09 100644 --- a/ci/checks/style.sh +++ b/ci/checks/style.sh @@ -14,6 +14,7 @@ cd $WORKSPACE export GIT_DESCRIBE_TAG=`git describe --tags` export MINOR_VERSION=`echo $GIT_DESCRIBE_TAG | grep -o -E '([0-9]+\.[0-9]+)'` conda install "ucx-py=0.21.*" "ucx-proc=*=gpu" +conda install -c conda-forge clang=8.0.1 clang-tools=8.0.1 # Run flake8 and get results/return code FLAKE=`flake8 --config=python/setup.cfg` diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 578a1c866a..60e30aa876 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -252,6 +252,13 @@ if(BUILD_CUML_CPP_LIBRARY) src/metrics/kl_divergence.cu src/metrics/mutual_info_score.cu src/metrics/pairwise_distance.cu + src/metrics/pairwise_distance_canberra.cu + src/metrics/pairwise_distance_chebyshev.cu + src/metrics/pairwise_distance_cosine.cu + src/metrics/pairwise_distance_euclidean.cu + src/metrics/pairwise_distance_hellinger.cu + src/metrics/pairwise_distance_l1.cu + src/metrics/pairwise_distance_minkowski.cu src/metrics/r2_score.cu src/metrics/rand_index.cu src/metrics/silhouette_score.cu @@ -323,6 +330,7 @@ if(BUILD_CUML_CPP_LIBRARY) $:${cumlprims_mg_INCLUDE_DIRS}>> PRIVATE $ + $ $ $<$,$>:${NCCL_INCLUDE_DIRS}> $<$:${MPI_CXX_INCLUDE_PATH}> diff --git a/cpp/include/cuml/metrics/metrics.hpp b/cpp/include/cuml/metrics/metrics.hpp index 5afa93dc5d..2f692c48c8 100644 --- a/cpp/include/cuml/metrics/metrics.hpp +++ b/cpp/include/cuml/metrics/metrics.hpp @@ -300,11 +300,12 @@ float accuracy_score_py(const raft::handle_t &handle, const int *predictions, * @param metric the distance metric to use for the calculation * @param isRowMajor specifies whether the x and y data pointers are row (C * type array) or col (F type array) major + * @param metric_arg the value of `p` for Minkowski (l-p) distances. */ void pairwise_distance(const raft::handle_t &handle, const double *x, const double *y, double *dist, int m, int n, int k, raft::distance::DistanceType metric, - bool isRowMajor = true); + bool isRowMajor = true, double metric_arg = 2.0); /** * @brief Calculates the ij pairwise distances between two input arrays of float type @@ -320,11 +321,12 @@ void pairwise_distance(const raft::handle_t &handle, const double *x, * @param metric the distance metric to use for the calculation * @param isRowMajor specifies whether the x and y data pointers are row (C * type array) or col (F type array) major + * @param metric_arg the value of `p` for Minkowski (l-p) distances. */ void pairwise_distance(const raft::handle_t &handle, const float *x, const float *y, float *dist, int m, int n, int k, raft::distance::DistanceType metric, - bool isRowMajor = true); + bool isRowMajor = true, float metric_arg = 2.0f); void pairwiseDistance_sparse(const raft::handle_t &handle, double *x, double *y, double *dist, int x_nrows, int y_nrows, int n_cols, diff --git a/cpp/src/hierarchy/pw_dist_graph.cuh b/cpp/src/hierarchy/pw_dist_graph.cuh index 0e16d9b02c..a1f6dff9ae 100644 --- a/cpp/src/hierarchy/pw_dist_graph.cuh +++ b/cpp/src/hierarchy/pw_dist_graph.cuh @@ -21,7 +21,7 @@ #include -#include +#include #include #include @@ -70,7 +70,6 @@ template void pairwise_distances(const raft::handle_t &handle, const value_t *X, size_t m, size_t n, raft::distance::DistanceType metric, value_idx *indptr, value_idx *indices, value_t *data) { - auto d_alloc = handle.get_device_allocator(); auto stream = handle.get_stream(); auto exec_policy = rmm::exec_policy(stream); @@ -83,16 +82,10 @@ void pairwise_distances(const raft::handle_t &handle, const value_t *X, raft::update_device(indptr + m, &nnz, 1, stream); - // TODO: Keeping raft device buffer here for now until our - // dense pairwise distances API is finished being refactored - raft::mr::device::buffer workspace(d_alloc, stream, (size_t)0); - // TODO: It would ultimately be nice if the MST could accept // dense inputs directly so we don't need to double the memory // usage to hand it a sparse array here. - raft::distance::pairwise_distance( - X, X, data, m, m, n, workspace, metric, stream); - + ML::Metrics::pairwise_distance(handle, X, X, data, m, m, n, metric); // self-loops get max distance auto transform_in = thrust::make_zip_iterator( thrust::make_tuple(thrust::make_counting_iterator(0), data)); @@ -120,7 +113,6 @@ struct distance_graph_impl &indptr, rmm::device_uvector &indices, rmm::device_uvector &data, int c) { - auto d_alloc = handle.get_device_allocator(); auto stream = handle.get_stream(); size_t nnz = m * m; diff --git a/cpp/src/kmeans/common.cuh b/cpp/src/kmeans/common.cuh index 1ef1c08d35..2d9ff29296 100644 --- a/cpp/src/kmeans/common.cuh +++ b/cpp/src/kmeans/common.cuh @@ -15,10 +15,10 @@ */ #pragma once +#include #include #include #include -#include #include #include @@ -256,9 +256,10 @@ void pairwise_distance(const raft::handle_t &handle, ASSERT(X.getSize(1) == centroids.getSize(1), "# features in dataset and centroids are different (must be same)"); - raft::distance::pairwise_distance( - X.data(), centroids.data(), pairwiseDistance.data(), n_samples, n_clusters, - n_features, workspace, metric, stream); + + ML::Metrics::pairwise_distance(handle, X.data(), centroids.data(), + pairwiseDistance.data(), n_samples, n_clusters, + n_features, metric); } // Calculates a pair for every sample in input 'X' where key is an diff --git a/cpp/src/metrics/pairwise_distance.cu b/cpp/src/metrics/pairwise_distance.cu index 5850122b9a..aac676633a 100644 --- a/cpp/src/metrics/pairwise_distance.cu +++ b/cpp/src/metrics/pairwise_distance.cu @@ -20,32 +20,97 @@ #include #include #include +#include "pairwise_distance_canberra.cuh" +#include "pairwise_distance_chebyshev.cuh" +#include "pairwise_distance_cosine.cuh" +#include "pairwise_distance_euclidean.cuh" +#include "pairwise_distance_hellinger.cuh" +#include "pairwise_distance_l1.cuh" +#include "pairwise_distance_minkowski.cuh" namespace ML { namespace Metrics { void pairwise_distance(const raft::handle_t &handle, const double *x, const double *y, double *dist, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor) { - //Allocate workspace - raft::mr::device::buffer workspace(handle.get_device_allocator(), - handle.get_stream(), 1); - - //Call the distance function - raft::distance::pairwise_distance(x, y, dist, m, n, k, workspace, metric, - handle.get_stream(), isRowMajor); + raft::distance::DistanceType metric, bool isRowMajor, + double metric_arg) { + switch (metric) { + case raft::distance::DistanceType::L2Expanded: + case raft::distance::DistanceType::L2SqrtExpanded: + case raft::distance::DistanceType::L2Unexpanded: + case raft::distance::DistanceType::L2SqrtUnexpanded: + pairwise_distance_euclidean(handle, x, y, dist, m, n, k, metric, + isRowMajor, metric_arg); + break; + case raft::distance::DistanceType::CosineExpanded: + pairwise_distance_cosine(handle, x, y, dist, m, n, k, metric, isRowMajor, + metric_arg); + break; + case raft::distance::DistanceType::L1: + pairwise_distance_l1(handle, x, y, dist, m, n, k, metric, isRowMajor, + metric_arg); + break; + case raft::distance::DistanceType::Linf: + pairwise_distance_chebyshev(handle, x, y, dist, m, n, k, metric, + isRowMajor, metric_arg); + break; + case raft::distance::DistanceType::HellingerExpanded: + pairwise_distance_hellinger(handle, x, y, dist, m, n, k, metric, + isRowMajor, metric_arg); + break; + case raft::distance::DistanceType::LpUnexpanded: + pairwise_distance_minkowski(handle, x, y, dist, m, n, k, metric, + isRowMajor, metric_arg); + break; + case raft::distance::DistanceType::Canberra: + pairwise_distance_canberra(handle, x, y, dist, m, n, k, metric, + isRowMajor, metric_arg); + break; + default: + THROW("Unknown or unsupported distance metric '%d'!", (int)metric); + }; } void pairwise_distance(const raft::handle_t &handle, const float *x, const float *y, float *dist, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor) { - //Allocate workspace - raft::mr::device::buffer workspace(handle.get_device_allocator(), - handle.get_stream(), 1); - - //Call the distance function - raft::distance::pairwise_distance(x, y, dist, m, n, k, workspace, metric, - handle.get_stream(), isRowMajor); + raft::distance::DistanceType metric, bool isRowMajor, + float metric_arg) { + switch (metric) { + case raft::distance::DistanceType::L2Expanded: + case raft::distance::DistanceType::L2SqrtExpanded: + case raft::distance::DistanceType::L2Unexpanded: + case raft::distance::DistanceType::L2SqrtUnexpanded: + pairwise_distance_euclidean(handle, x, y, dist, m, n, k, metric, + isRowMajor, metric_arg); + break; + case raft::distance::DistanceType::CosineExpanded: + pairwise_distance_cosine(handle, x, y, dist, m, n, k, metric, isRowMajor, + metric_arg); + break; + case raft::distance::DistanceType::L1: + pairwise_distance_l1(handle, x, y, dist, m, n, k, metric, isRowMajor, + metric_arg); + break; + case raft::distance::DistanceType::Linf: + pairwise_distance_chebyshev(handle, x, y, dist, m, n, k, metric, + isRowMajor, metric_arg); + break; + case raft::distance::DistanceType::HellingerExpanded: + pairwise_distance_hellinger(handle, x, y, dist, m, n, k, metric, + isRowMajor, metric_arg); + break; + case raft::distance::DistanceType::LpUnexpanded: + pairwise_distance_minkowski(handle, x, y, dist, m, n, k, metric, + isRowMajor, metric_arg); + break; + case raft::distance::DistanceType::Canberra: + pairwise_distance_canberra(handle, x, y, dist, m, n, k, metric, + isRowMajor, metric_arg); + break; + default: + THROW("Unknown or unsupported distance metric '%d'!", (int)metric); + }; } template diff --git a/cpp/src/metrics/pairwise_distance_canberra.cu b/cpp/src/metrics/pairwise_distance_canberra.cu new file mode 100644 index 0000000000..6d600f46fe --- /dev/null +++ b/cpp/src/metrics/pairwise_distance_canberra.cu @@ -0,0 +1,74 @@ + +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +//#include +#include +#include + +namespace ML { + +namespace Metrics { +void pairwise_distance_canberra(const raft::handle_t &handle, const double *x, + const double *y, double *dist, int m, int n, + int k, raft::distance::DistanceType metric, + bool isRowMajor, double metric_arg) { + //Allocate workspace + raft::mr::device::buffer workspace(handle.get_device_allocator(), + handle.get_stream(), 1); + + //Call the distance function + /* raft::distance::pairwise_distance(x, y, dist, m, n, k, workspace, metric, + handle.get_stream(), isRowMajor, + metric_arg);*/ + + switch (metric) { + case raft::distance::DistanceType::Canberra: + raft::distance::pairwise_distance_impl< + double, int, raft::distance::DistanceType::Canberra>( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + break; + default: + THROW("Unknown or unsupported distance metric '%d'!", (int)metric); + } +} + +void pairwise_distance_canberra(const raft::handle_t &handle, const float *x, + const float *y, float *dist, int m, int n, + int k, raft::distance::DistanceType metric, + bool isRowMajor, float metric_arg) { + //Allocate workspace + raft::mr::device::buffer workspace(handle.get_device_allocator(), + handle.get_stream(), 1); + + //Call the distance function + /* raft::distance::pairwise_distance(x, y, dist, m, n, k, workspace, metric, + handle.get_stream(), isRowMajor, + metric_arg);*/ + + switch (metric) { + case raft::distance::DistanceType::Canberra: + raft::distance::pairwise_distance_impl< + float, int, raft::distance::DistanceType::Canberra>( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + break; + default: + THROW("Unknown or unsupported distance metric '%d'!", (int)metric); + } +} + +} // namespace Metrics +} // namespace ML diff --git a/cpp/src/metrics/pairwise_distance_canberra.cuh b/cpp/src/metrics/pairwise_distance_canberra.cuh new file mode 100644 index 0000000000..390be9af85 --- /dev/null +++ b/cpp/src/metrics/pairwise_distance_canberra.cuh @@ -0,0 +1,37 @@ + +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace ML { + +namespace Metrics { +void pairwise_distance_canberra(const raft::handle_t &handle, const double *x, + const double *y, double *dist, int m, int n, + int k, raft::distance::DistanceType metric, + bool isRowMajor, double metric_arg); + +void pairwise_distance_canberra(const raft::handle_t &handle, const float *x, + const float *y, float *dist, int m, int n, + int k, raft::distance::DistanceType metric, + bool isRowMajor, float metric_arg); + +} // namespace Metrics +} // namespace ML diff --git a/cpp/src/metrics/pairwise_distance_chebyshev.cu b/cpp/src/metrics/pairwise_distance_chebyshev.cu new file mode 100644 index 0000000000..cdcea7e185 --- /dev/null +++ b/cpp/src/metrics/pairwise_distance_chebyshev.cu @@ -0,0 +1,63 @@ + +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "pairwise_distance_chebyshev.cuh" +namespace ML { + +namespace Metrics { +void pairwise_distance_chebyshev(const raft::handle_t &handle, const double *x, + const double *y, double *dist, int m, int n, + int k, raft::distance::DistanceType metric, + bool isRowMajor, double metric_arg) { + //Allocate workspace + raft::mr::device::buffer workspace(handle.get_device_allocator(), + handle.get_stream(), 1); + //Call the distance function + switch (metric) { + case raft::distance::DistanceType::Linf: + raft::distance::pairwise_distance_impl< + double, int, raft::distance::DistanceType::Linf>( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + break; + default: + THROW("Unknown or unsupported distance metric '%d'!", (int)metric); + } +} + +void pairwise_distance_chebyshev(const raft::handle_t &handle, const float *x, + const float *y, float *dist, int m, int n, + int k, raft::distance::DistanceType metric, + bool isRowMajor, float metric_arg) { + //Allocate workspace + raft::mr::device::buffer workspace(handle.get_device_allocator(), + handle.get_stream(), 1); + //Call the distance function + switch (metric) { + case raft::distance::DistanceType::Linf: + raft::distance::pairwise_distance_impl< + float, int, raft::distance::DistanceType::Linf>( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + break; + default: + THROW("Unknown or unsupported distance metric '%d'!", (int)metric); + } +} + +} // namespace Metrics +} // namespace ML diff --git a/cpp/src/metrics/pairwise_distance_chebyshev.cuh b/cpp/src/metrics/pairwise_distance_chebyshev.cuh new file mode 100644 index 0000000000..cd45f2d721 --- /dev/null +++ b/cpp/src/metrics/pairwise_distance_chebyshev.cuh @@ -0,0 +1,35 @@ + +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include +#include + +namespace ML { + +namespace Metrics { +void pairwise_distance_chebyshev(const raft::handle_t &handle, const double *x, + const double *y, double *dist, int m, int n, + int k, raft::distance::DistanceType metric, + bool isRowMajor, double metric_arg); + +void pairwise_distance_chebyshev(const raft::handle_t &handle, const float *x, + const float *y, float *dist, int m, int n, + int k, raft::distance::DistanceType metric, + bool isRowMajor, float metric_arg); + +} // namespace Metrics +} // namespace ML diff --git a/cpp/src/metrics/pairwise_distance_cosine.cu b/cpp/src/metrics/pairwise_distance_cosine.cu new file mode 100644 index 0000000000..b9fbca1ef5 --- /dev/null +++ b/cpp/src/metrics/pairwise_distance_cosine.cu @@ -0,0 +1,64 @@ + +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "pairwise_distance_cosine.cuh" + +namespace ML { + +namespace Metrics { +void pairwise_distance_cosine(const raft::handle_t &handle, const double *x, + const double *y, double *dist, int m, int n, + int k, raft::distance::DistanceType metric, + bool isRowMajor, double metric_arg) { + //Allocate workspace + raft::mr::device::buffer workspace(handle.get_device_allocator(), + handle.get_stream(), 1); + + //Call the distance function + switch (metric) { + case raft::distance::DistanceType::CosineExpanded: + raft::distance::pairwise_distance_impl< + double, int, raft::distance::DistanceType::CosineExpanded>( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + break; + default: + THROW("Unknown or unsupported distance metric '%d'!", (int)metric); + } +} + +void pairwise_distance_cosine(const raft::handle_t &handle, const float *x, + const float *y, float *dist, int m, int n, int k, + raft::distance::DistanceType metric, + bool isRowMajor, float metric_arg) { + //Allocate workspace + raft::mr::device::buffer workspace(handle.get_device_allocator(), + handle.get_stream(), 1); + switch (metric) { + case raft::distance::DistanceType::CosineExpanded: + raft::distance::pairwise_distance_impl< + float, int, raft::distance::DistanceType::CosineExpanded>( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + break; + default: + THROW("Unknown or unsupported distance metric '%d'!", (int)metric); + } +} + +} // namespace Metrics +} // namespace ML diff --git a/cpp/src/metrics/pairwise_distance_cosine.cuh b/cpp/src/metrics/pairwise_distance_cosine.cuh new file mode 100644 index 0000000000..ad5a2fbf62 --- /dev/null +++ b/cpp/src/metrics/pairwise_distance_cosine.cuh @@ -0,0 +1,36 @@ + +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include + +namespace ML { + +namespace Metrics { +void pairwise_distance_cosine(const raft::handle_t &handle, const double *x, + const double *y, double *dist, int m, int n, + int k, raft::distance::DistanceType metric, + bool isRowMajor, double metric_arg); + +void pairwise_distance_cosine(const raft::handle_t &handle, const float *x, + const float *y, float *dist, int m, int n, int k, + raft::distance::DistanceType metric, + bool isRowMajor, float metric_arg); + +} // namespace Metrics +} // namespace ML diff --git a/cpp/src/metrics/pairwise_distance_euclidean.cu b/cpp/src/metrics/pairwise_distance_euclidean.cu new file mode 100644 index 0000000000..d03af7b93e --- /dev/null +++ b/cpp/src/metrics/pairwise_distance_euclidean.cu @@ -0,0 +1,96 @@ + +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "pairwise_distance_euclidean.cuh" + +namespace ML { + +namespace Metrics { +void pairwise_distance_euclidean(const raft::handle_t &handle, const double *x, + const double *y, double *dist, int m, int n, + int k, raft::distance::DistanceType metric, + bool isRowMajor, double metric_arg) { + //Allocate workspace + raft::mr::device::buffer workspace(handle.get_device_allocator(), + handle.get_stream(), 1); + + //Call the distance function + switch (metric) { + case raft::distance::DistanceType::L2Expanded: + raft::distance::pairwise_distance_impl< + double, int, raft::distance::DistanceType::L2Expanded>( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + break; + case raft::distance::DistanceType::L2SqrtExpanded: + raft::distance::pairwise_distance_impl< + double, int, raft::distance::DistanceType::L2SqrtExpanded>( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + break; + case raft::distance::DistanceType::L2Unexpanded: + raft::distance::pairwise_distance_impl< + double, int, raft::distance::DistanceType::L2Unexpanded>( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + break; + case raft::distance::DistanceType::L2SqrtUnexpanded: + raft::distance::pairwise_distance_impl< + double, int, raft::distance::DistanceType::L2SqrtUnexpanded>( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + break; + default: + THROW("Unknown or unsupported distance metric '%d'!", (int)metric); + } +} + +void pairwise_distance_euclidean(const raft::handle_t &handle, const float *x, + const float *y, float *dist, int m, int n, + int k, raft::distance::DistanceType metric, + bool isRowMajor, float metric_arg) { + //Allocate workspace + raft::mr::device::buffer workspace(handle.get_device_allocator(), + handle.get_stream(), 1); + + //Call the distance function + switch (metric) { + case raft::distance::DistanceType::L2Expanded: + raft::distance::pairwise_distance_impl< + float, int, raft::distance::DistanceType::L2Expanded>( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + break; + case raft::distance::DistanceType::L2SqrtExpanded: + raft::distance::pairwise_distance_impl< + float, int, raft::distance::DistanceType::L2SqrtExpanded>( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + break; + case raft::distance::DistanceType::L2Unexpanded: + raft::distance::pairwise_distance_impl< + float, int, raft::distance::DistanceType::L2Unexpanded>( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + break; + case raft::distance::DistanceType::L2SqrtUnexpanded: + raft::distance::pairwise_distance_impl< + float, int, raft::distance::DistanceType::L2SqrtUnexpanded>( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + break; + default: + THROW("Unknown or unsupported distance metric '%d'!", (int)metric); + } +} + +} // namespace Metrics +} // namespace ML diff --git a/cpp/src/metrics/pairwise_distance_euclidean.cuh b/cpp/src/metrics/pairwise_distance_euclidean.cuh new file mode 100644 index 0000000000..447445e726 --- /dev/null +++ b/cpp/src/metrics/pairwise_distance_euclidean.cuh @@ -0,0 +1,34 @@ + +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include +#include + +namespace ML { + +namespace Metrics { +void pairwise_distance_euclidean(const raft::handle_t &handle, const double *x, + const double *y, double *dist, int m, int n, + int k, raft::distance::DistanceType metric, + bool isRowMajor, double metric_arg); + +void pairwise_distance_euclidean(const raft::handle_t &handle, const float *x, + const float *y, float *dist, int m, int n, + int k, raft::distance::DistanceType metric, + bool isRowMajor, float metric_arg); +} // namespace Metrics +} // namespace ML diff --git a/cpp/src/metrics/pairwise_distance_hellinger.cu b/cpp/src/metrics/pairwise_distance_hellinger.cu new file mode 100644 index 0000000000..a3c26699f0 --- /dev/null +++ b/cpp/src/metrics/pairwise_distance_hellinger.cu @@ -0,0 +1,64 @@ + +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "pairwise_distance_hellinger.cuh" + +namespace ML { + +namespace Metrics { +void pairwise_distance_hellinger(const raft::handle_t &handle, const double *x, + const double *y, double *dist, int m, int n, + int k, raft::distance::DistanceType metric, + bool isRowMajor, double metric_arg) { + //Allocate workspace + raft::mr::device::buffer workspace(handle.get_device_allocator(), + handle.get_stream(), 1); + //Call the distance function + switch (metric) { + case raft::distance::DistanceType::HellingerExpanded: + raft::distance::pairwise_distance_impl< + double, int, raft::distance::DistanceType::HellingerExpanded>( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + break; + default: + THROW("Unknown or unsupported distance metric '%d'!", (int)metric); + } +} + +void pairwise_distance_hellinger(const raft::handle_t &handle, const float *x, + const float *y, float *dist, int m, int n, + int k, raft::distance::DistanceType metric, + bool isRowMajor, float metric_arg) { + //Allocate workspace + raft::mr::device::buffer workspace(handle.get_device_allocator(), + handle.get_stream(), 1); + //Call the distance function + switch (metric) { + case raft::distance::DistanceType::HellingerExpanded: + raft::distance::pairwise_distance_impl< + float, int, raft::distance::DistanceType::HellingerExpanded>( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + break; + default: + THROW("Unknown or unsupported distance metric '%d'!", (int)metric); + } +} + +} // namespace Metrics +} // namespace ML diff --git a/cpp/src/metrics/pairwise_distance_hellinger.cuh b/cpp/src/metrics/pairwise_distance_hellinger.cuh new file mode 100644 index 0000000000..0359993bc8 --- /dev/null +++ b/cpp/src/metrics/pairwise_distance_hellinger.cuh @@ -0,0 +1,35 @@ + +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include + +namespace ML { + +namespace Metrics { +void pairwise_distance_hellinger(const raft::handle_t &handle, const double *x, + const double *y, double *dist, int m, int n, + int k, raft::distance::DistanceType metric, + bool isRowMajor, double metric_arg); + +void pairwise_distance_hellinger(const raft::handle_t &handle, const float *x, + const float *y, float *dist, int m, int n, + int k, raft::distance::DistanceType metric, + bool isRowMajor, float metric_arg); +} // namespace Metrics +} // namespace ML diff --git a/cpp/src/metrics/pairwise_distance_l1.cu b/cpp/src/metrics/pairwise_distance_l1.cu new file mode 100644 index 0000000000..1179ce9283 --- /dev/null +++ b/cpp/src/metrics/pairwise_distance_l1.cu @@ -0,0 +1,64 @@ + +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "pairwise_distance_l1.cuh" + +namespace ML { + +namespace Metrics { +void pairwise_distance_l1(const raft::handle_t &handle, const double *x, + const double *y, double *dist, int m, int n, int k, + raft::distance::DistanceType metric, bool isRowMajor, + double metric_arg) { + //Allocate workspace + raft::mr::device::buffer workspace(handle.get_device_allocator(), + handle.get_stream(), 1); + //Call the distance function + switch (metric) { + case raft::distance::DistanceType::L1: + raft::distance::pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + break; + default: + THROW("Unknown or unsupported distance metric '%d'!", (int)metric); + } +} + +void pairwise_distance_l1(const raft::handle_t &handle, const float *x, + const float *y, float *dist, int m, int n, int k, + raft::distance::DistanceType metric, bool isRowMajor, + float metric_arg) { + //Allocate workspace + raft::mr::device::buffer workspace(handle.get_device_allocator(), + handle.get_stream(), 1); + //Call the distance function + switch (metric) { + case raft::distance::DistanceType::L1: + raft::distance::pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + break; + default: + THROW("Unknown or unsupported distance metric '%d'!", (int)metric); + } +} + +} // namespace Metrics +} // namespace ML diff --git a/cpp/src/metrics/pairwise_distance_l1.cuh b/cpp/src/metrics/pairwise_distance_l1.cuh new file mode 100644 index 0000000000..f1470cb6ed --- /dev/null +++ b/cpp/src/metrics/pairwise_distance_l1.cuh @@ -0,0 +1,35 @@ + +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include +#include + +namespace ML { + +namespace Metrics { +void pairwise_distance_l1(const raft::handle_t &handle, const double *x, + const double *y, double *dist, int m, int n, int k, + raft::distance::DistanceType metric, bool isRowMajor, + double metric_arg); + +void pairwise_distance_l1(const raft::handle_t &handle, const float *x, + const float *y, float *dist, int m, int n, int k, + raft::distance::DistanceType metric, bool isRowMajor, + float metric_arg); + +} // namespace Metrics +} // namespace ML diff --git a/cpp/src/metrics/pairwise_distance_minkowski.cu b/cpp/src/metrics/pairwise_distance_minkowski.cu new file mode 100644 index 0000000000..af7938d618 --- /dev/null +++ b/cpp/src/metrics/pairwise_distance_minkowski.cu @@ -0,0 +1,66 @@ + +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "pairwise_distance_minkowski.cuh" + +namespace ML { + +namespace Metrics { +void pairwise_distance_minkowski(const raft::handle_t &handle, const double *x, + const double *y, double *dist, int m, int n, + int k, raft::distance::DistanceType metric, + bool isRowMajor, double metric_arg) { + //Allocate workspace + raft::mr::device::buffer workspace(handle.get_device_allocator(), + handle.get_stream(), 1); + //Call the distance function + switch (metric) { + case raft::distance::DistanceType::LpUnexpanded: + raft::distance::pairwise_distance_impl< + double, int, raft::distance::DistanceType::LpUnexpanded>( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor, + metric_arg); + break; + default: + THROW("Unknown or unsupported distance metric '%d'!", (int)metric); + } +} + +void pairwise_distance_minkowski(const raft::handle_t &handle, const float *x, + const float *y, float *dist, int m, int n, + int k, raft::distance::DistanceType metric, + bool isRowMajor, float metric_arg) { + //Allocate workspace + raft::mr::device::buffer workspace(handle.get_device_allocator(), + handle.get_stream(), 1); + //Call the distance function + switch (metric) { + case raft::distance::DistanceType::LpUnexpanded: + raft::distance::pairwise_distance_impl< + float, int, raft::distance::DistanceType::LpUnexpanded>( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor, + metric_arg); + break; + default: + THROW("Unknown or unsupported distance metric '%d'!", (int)metric); + } +} + +} // namespace Metrics +} // namespace ML diff --git a/cpp/src/metrics/pairwise_distance_minkowski.cuh b/cpp/src/metrics/pairwise_distance_minkowski.cuh new file mode 100644 index 0000000000..3a0a06c1df --- /dev/null +++ b/cpp/src/metrics/pairwise_distance_minkowski.cuh @@ -0,0 +1,36 @@ + +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include + +namespace ML { + +namespace Metrics { +void pairwise_distance_minkowski(const raft::handle_t &handle, const double *x, + const double *y, double *dist, int m, int n, + int k, raft::distance::DistanceType metric, + bool isRowMajor, double metric_arg); + +void pairwise_distance_minkowski(const raft::handle_t &handle, const float *x, + const float *y, float *dist, int m, int n, + int k, raft::distance::DistanceType metric, + bool isRowMajor, float metric_arg); + +} // namespace Metrics +} // namespace ML diff --git a/cpp/src/metrics/silhouette_score.cu b/cpp/src/metrics/silhouette_score.cu index b89720dc91..b6d2c80cdc 100644 --- a/cpp/src/metrics/silhouette_score.cu +++ b/cpp/src/metrics/silhouette_score.cu @@ -27,8 +27,8 @@ double silhouette_score(const raft::handle_t &handle, double *y, int nRows, int nCols, int *labels, int nLabels, double *silScores, raft::distance::DistanceType metric) { return MLCommon::Metrics::silhouette_score( - y, nRows, nCols, labels, nLabels, silScores, handle.get_device_allocator(), - handle.get_stream(), metric); + handle, y, nRows, nCols, labels, nLabels, silScores, + handle.get_device_allocator(), handle.get_stream(), metric); } namespace Batched { diff --git a/cpp/src_prims/metrics/batched/silhouette_score.cuh b/cpp/src_prims/metrics/batched/silhouette_score.cuh index 879f68b1b6..cfdb258b2d 100644 --- a/cpp/src_prims/metrics/batched/silhouette_score.cuh +++ b/cpp/src_prims/metrics/batched/silhouette_score.cuh @@ -16,9 +16,9 @@ #pragma once -#include "../silhouette_score.cuh" - #include +#include +#include "../silhouette_score.cuh" #include #include @@ -134,15 +134,11 @@ rmm::device_uvector get_pairwise_distance( const raft::handle_t &handle, value_t *left_begin, value_t *right_begin, value_idx &n_left_rows, value_idx &n_right_rows, value_idx &n_cols, raft::distance::DistanceType metric, cudaStream_t stream) { - auto allocator = handle.get_device_allocator(); - - MLCommon::device_buffer workspace(allocator, stream, 1); - rmm::device_uvector distances(n_left_rows * n_right_rows, stream); - raft::distance::pairwise_distance(left_begin, right_begin, distances.data(), - n_left_rows, n_right_rows, n_cols, - workspace, metric, stream); + ML::Metrics::pairwise_distance(handle, left_begin, right_begin, + distances.data(), n_left_rows, n_right_rows, + n_cols, metric); return distances; } diff --git a/cpp/src_prims/metrics/silhouette_score.cuh b/cpp/src_prims/metrics/silhouette_score.cuh index 1191a6cd74..47c711e3f0 100644 --- a/cpp/src_prims/metrics/silhouette_score.cuh +++ b/cpp/src_prims/metrics/silhouette_score.cuh @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include @@ -176,8 +177,9 @@ struct MinOp { * @param metric: the numerical value that maps to the type of distance metric to be used in the calculations */ template -DataT silhouette_score(DataT *X_in, int nRows, int nCols, LabelT *labels, - int nLabels, DataT *silhouette_scorePerSample, +DataT silhouette_score(const raft::handle_t &handle, DataT *X_in, int nRows, + int nCols, LabelT *labels, int nLabels, + DataT *silhouette_scorePerSample, std::shared_ptr allocator, cudaStream_t stream, raft::distance::DistanceType metric = @@ -190,9 +192,8 @@ DataT silhouette_score(DataT *X_in, int nRows, int nCols, LabelT *labels, nRows * nRows); MLCommon::device_buffer workspace(allocator, stream, 1); - raft::distance::pairwise_distance( - X_in, X_in, distanceMatrix.data(), nRows, nRows, nCols, workspace, - static_cast(metric), stream); + ML::Metrics::pairwise_distance(handle, X_in, X_in, distanceMatrix.data(), + nRows, nRows, nCols, metric); //deciding on the array of silhouette scores for each dataPoint MLCommon::device_buffer silhouette_scoreSamples(allocator, stream, 0); diff --git a/cpp/src_prims/metrics/trustworthiness_score.cuh b/cpp/src_prims/metrics/trustworthiness_score.cuh index 9c27730f6e..030f4b8fa8 100644 --- a/cpp/src_prims/metrics/trustworthiness_score.cuh +++ b/cpp/src_prims/metrics/trustworthiness_score.cuh @@ -14,7 +14,7 @@ * limitations under the License. */ -#include +#include #include #include #include @@ -132,12 +132,8 @@ double trustworthiness_score(const raft::handle_t &h, const math_t *X, int curBatchSize = min(toDo, batchSize); // Takes at most batchSize vectors at a time - - size_t workspaceSize = 0; - - raft::distance::distance( - &X[(n - toDo) * m], X, X_dist.data(), curBatchSize, n, m, (void *)nullptr, - workspaceSize, stream); + ML::Metrics::pairwise_distance(h, &X[(n - toDo) * m], X, X_dist.data(), + curBatchSize, n, m, distance_type); size_t colSortWorkspaceSize = 0; bool bAllocWorkspace = false; diff --git a/cpp/test/prims/silhouette_score.cu b/cpp/test/prims/silhouette_score.cu index fe0a2f7923..9b78204a49 100644 --- a/cpp/test/prims/silhouette_score.cu +++ b/cpp/test/prims/silhouette_score.cu @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -67,13 +68,11 @@ class silhouetteScoreTest //finding the distance matrix device_buffer d_distanceMatrix(allocator, stream, nRows * nRows); - device_buffer workspace(allocator, stream, 1); double *h_distanceMatrix = (double *)malloc(nRows * nRows * sizeof(double *)); - raft::distance::pairwise_distance(d_X, d_X, d_distanceMatrix.data(), nRows, - nRows, nCols, workspace, params.metric, - stream); + ML::Metrics::pairwise_distance(handle, d_X, d_X, d_distanceMatrix.data(), + nRows, nRows, nCols, params.metric); CUDA_CHECK(cudaStreamSynchronize(stream)); @@ -171,8 +170,8 @@ class silhouetteScoreTest //calling the silhouette_score CUDA implementation computedSilhouetteScore = MLCommon::Metrics::silhouette_score( - d_X, nRows, nCols, d_labels, nLabels, sampleSilScore, allocator, stream, - params.metric); + handle, d_X, nRows, nCols, d_labels, nLabels, sampleSilScore, allocator, + stream, params.metric); batchedSilhouetteScore = Batched::silhouette_score(handle, d_X, nRows, nCols, d_labels, nLabels, diff --git a/cpp/test/sg/dbscan_test.cu b/cpp/test/sg/dbscan_test.cu index 197b1b33b3..adc9b808b1 100644 --- a/cpp/test/sg/dbscan_test.cu +++ b/cpp/test/sg/dbscan_test.cu @@ -88,13 +88,9 @@ class DbscanTest : public ::testing::TestWithParam> { true, -10.0f, 10.0f, params.seed); if (params.metric == raft::distance::Precomputed) { - device_buffer workspace(handle.get_device_allocator(), - handle.get_stream(), 0); - - raft::distance::pairwise_distance_impl( - out.data(), out.data(), dist.data(), params.n_row, params.n_row, - params.n_col, workspace, handle.get_stream(), true); + ML::Metrics::pairwise_distance( + handle, out.data(), out.data(), dist.data(), params.n_row, params.n_row, + params.n_col, raft::distance::L2SqrtUnexpanded); } raft::allocate(labels, params.n_row); diff --git a/cpp/test/sg/rproj_test.cu b/cpp/test/sg/rproj_test.cu index c35e333ac5..9448ca4f9f 100644 --- a/cpp/test/sg/rproj_test.cu +++ b/cpp/test/sg/rproj_test.cu @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -144,14 +145,11 @@ class RPROJTest : public ::testing::Test { constexpr auto distance_type = raft::distance::DistanceType::L2SqrtUnexpanded; - size_t workspaceSize = 0; T* d_pdist; raft::allocate(d_pdist, N * N); - - raft::distance::distance( - d_input, d_input, d_pdist, N, N, M, (void*)nullptr, workspaceSize, - h.get_stream()); + ML::Metrics::pairwise_distance(h, d_input, d_input, d_pdist, N, N, M, + distance_type); CUDA_CHECK(cudaPeekAtLastError()); T* h_pdist = new T[N * N]; @@ -160,9 +158,8 @@ class RPROJTest : public ::testing::Test { T* d_pdist1; raft::allocate(d_pdist1, N * N); - raft::distance::distance( - d_output1, d_output1, d_pdist1, N, N, D, (void*)nullptr, workspaceSize, - h.get_stream()); + ML::Metrics::pairwise_distance(h, d_output1, d_output1, d_pdist1, N, N, D, + distance_type); CUDA_CHECK(cudaPeekAtLastError()); T* h_pdist1 = new T[N * N]; @@ -171,9 +168,8 @@ class RPROJTest : public ::testing::Test { T* d_pdist2; raft::allocate(d_pdist2, N * N); - raft::distance::distance( - d_output2, d_output2, d_pdist2, N, N, D, (void*)nullptr, workspaceSize, - h.get_stream()); + ML::Metrics::pairwise_distance(h, d_output2, d_output2, d_pdist2, N, N, D, + distance_type); CUDA_CHECK(cudaPeekAtLastError()); T* h_pdist2 = new T[N * N]; diff --git a/python/cuml/metrics/pairwise_distances.pyx b/python/cuml/metrics/pairwise_distances.pyx index 0ca579f9bb..d43fcb4329 100644 --- a/python/cuml/metrics/pairwise_distances.pyx +++ b/python/cuml/metrics/pairwise_distances.pyx @@ -38,10 +38,12 @@ from cuml.metrics.distance_type cimport DistanceType cdef extern from "cuml/metrics/metrics.hpp" namespace "ML::Metrics": void pairwise_distance(const handle_t &handle, const double *x, const double *y, double *dist, int m, int n, int k, - DistanceType metric, bool isRowMajor) except + + DistanceType metric, bool isRowMajor, + double metric_arg) except + void pairwise_distance(const handle_t &handle, const float *x, const float *y, float *dist, int m, int n, int k, - DistanceType metric, bool isRowMajor) except + + DistanceType metric, bool isRowMajor, + float metric_arg) except + void pairwiseDistance_sparse(const handle_t &handle, float *x, float *y, float *dist, int x_nrows, int y_nrows, int n_cols, int x_nnz, int y_nnz, @@ -67,7 +69,11 @@ PAIRWISE_DISTANCE_METRICS = { "l1": DistanceType.L1, "l2": DistanceType.L2SqrtUnexpanded, "manhattan": DistanceType.L1, - "sqeuclidean": DistanceType.L2Expanded + "sqeuclidean": DistanceType.L2Expanded, + "canberra": DistanceType.Canberra, + "chebyshev": DistanceType.Linf, + "minkowski": DistanceType.LpUnexpanded, + "hellinger": DistanceType.HellingerExpanded } PAIRWISE_DISTANCE_SPARSE_METRICS = { @@ -127,7 +133,7 @@ def _determine_metric(metric_str, is_sparse=False): @cuml.internals.api_return_array(get_output_type=True) def pairwise_distances(X, Y=None, metric="euclidean", handle=None, - convert_dtype=True, **kwds): + convert_dtype=True, metric_arg=2, **kwds): """ Compute the distance matrix from a vector array `X` and optional `Y`. @@ -275,7 +281,8 @@ def pairwise_distances(X, Y=None, metric="euclidean", handle=None, n_samples_y, n_features_x, metric_val, - is_row_major) + is_row_major, + metric_arg) elif (dtype_x == np.float64): pairwise_distance(handle_[0], d_X_ptr, @@ -285,7 +292,8 @@ def pairwise_distances(X, Y=None, metric="euclidean", handle=None, n_samples_y, n_features_x, metric_val, - is_row_major) + is_row_major, + metric_arg) else: raise NotImplementedError("Unsupported dtype: {}".format(dtype_x)) diff --git a/python/cuml/test/test_metrics.py b/python/cuml/test/test_metrics.py index 800b2c9771..8d642b72fd 100644 --- a/python/cuml/test/test_metrics.py +++ b/python/cuml/test/test_metrics.py @@ -48,6 +48,7 @@ from sklearn.metrics.cluster import silhouette_score as sk_silhouette_score from sklearn.metrics.cluster import silhouette_samples as sk_silhouette_samples from sklearn.preprocessing import StandardScaler +from sklearn import preprocessing from cuml import LogisticRegression as cu_log from cuml.metrics import hinge_loss as cuml_hinge @@ -861,6 +862,25 @@ def test_log_loss_at_limits(): log_loss(y_true, y_pred) +def ref_dense_pairwise_dist(X, Y=None, metric=None): + # Select sklearn except for Hellinger that + # sklearn doesn't support + if Y is None: + Y = X + if metric == "hellinger": + return naive_hellinger(X, Y) + else: + return sklearn_pairwise_distances(X, Y, metric) + + +def prep_dense_array(array, metric, col_major=0): + if metric == "hellinger": + norm_array = preprocessing.normalize(array, norm="l1") + return np.asfortranarray(norm_array) if col_major else norm_array + else: + return np.asfortranarray(array) if col_major else array + + @pytest.mark.parametrize("metric", PAIRWISE_DISTANCE_METRICS.keys()) @pytest.mark.parametrize("matrix_size", [(5, 4), (1000, 3), (2, 10), (500, 400)]) @@ -869,22 +889,20 @@ def test_pairwise_distances(metric: str, matrix_size, is_col_major): # Test the pairwise_distance helper function. rng = np.random.RandomState(0) - def prep_array(array): - return np.asfortranarray(array) if is_col_major else array - # For fp64, compare at 13 decimals, (2 places less than the ~15 max) - compare_precision = 10 + compare_precision = 6 # Compare to sklearn, single input - X = prep_array(rng.random_sample(matrix_size)) + X = prep_dense_array(rng.random_sample(matrix_size), + metric=metric, col_major=is_col_major) S = pairwise_distances(X, metric=metric) - S2 = sklearn_pairwise_distances(X, metric=metric) + S2 = ref_dense_pairwise_dist(X, metric=metric) cp.testing.assert_array_almost_equal(S, S2, decimal=compare_precision) # Compare to sklearn, double input with same dimensions Y = X S = pairwise_distances(X, Y, metric=metric) - S2 = sklearn_pairwise_distances(X, Y, metric=metric) + S2 = ref_dense_pairwise_dist(X, Y, metric=metric) cp.testing.assert_array_almost_equal(S, S2, decimal=compare_precision) # Compare single and double inputs to eachother @@ -893,15 +911,17 @@ def prep_array(array): cp.testing.assert_array_almost_equal(S, S2, decimal=compare_precision) # Compare to sklearn, with Y dim != X dim - Y = prep_array(rng.random_sample((2, matrix_size[1]))) + Y = prep_dense_array(rng.random_sample((2, matrix_size[1])), + metric=metric, + col_major=is_col_major) S = pairwise_distances(X, Y, metric=metric) - S2 = sklearn_pairwise_distances(X, Y, metric=metric) + S2 = ref_dense_pairwise_dist(X, Y, metric=metric) cp.testing.assert_array_almost_equal(S, S2, decimal=compare_precision) # Change precision of one parameter Y = np.asfarray(Y, dtype=np.float32) S = pairwise_distances(X, Y, metric=metric) - S2 = sklearn_pairwise_distances(X, Y, metric=metric) + S2 = ref_dense_pairwise_dist(X, Y, metric=metric) cp.testing.assert_array_almost_equal(S, S2, decimal=compare_precision) # For fp32, compare at 5 decimals, (2 places less than the ~7 max) @@ -911,13 +931,14 @@ def prep_array(array): X = np.asfarray(X, dtype=np.float32) Y = np.asfarray(Y, dtype=np.float32) S = pairwise_distances(X, Y, metric=metric) - S2 = sklearn_pairwise_distances(X, Y, metric=metric) + S2 = ref_dense_pairwise_dist(X, Y, metric=metric) cp.testing.assert_array_almost_equal(S, S2, decimal=compare_precision) # Test sending an int type with convert_dtype=True - Y = prep_array(rng.randint(10, size=Y.shape)) + Y = prep_dense_array(rng.randint(10, size=Y.shape), + metric=metric, col_major=is_col_major) S = pairwise_distances(X, Y, metric=metric, convert_dtype=True) - S2 = sklearn_pairwise_distances(X, Y, metric=metric) + S2 = ref_dense_pairwise_dist(X, Y, metric=metric) cp.testing.assert_array_almost_equal(S, S2, decimal=compare_precision) # Test that uppercase on the metric name throws an error. @@ -936,8 +957,10 @@ def test_pairwise_distances_sklearn_comparison(metric: str, matrix_size): element_count = matrix_size[0] * matrix_size[1] - X = rng.random_sample(matrix_size) - Y = rng.random_sample(matrix_size) + X = prep_dense_array(rng.random_sample(matrix_size), + metric=metric, col_major=0) + Y = prep_dense_array(rng.random_sample(matrix_size), + metric=metric, col_major=0) # For fp64, compare at 10 decimals, (5 places less than the ~15 max) compare_precision = 10 @@ -946,7 +969,7 @@ def test_pairwise_distances_sklearn_comparison(metric: str, matrix_size): S = pairwise_distances(X, Y, metric=metric) if (element_count <= 2000000): - S2 = sklearn_pairwise_distances(X, Y, metric=metric) + S2 = ref_dense_pairwise_dist(X, Y, metric=metric) cp.testing.assert_array_almost_equal(S, S2, decimal=compare_precision) # For fp32, compare at 4 decimals, (3 places less than the ~7 max) @@ -959,7 +982,7 @@ def test_pairwise_distances_sklearn_comparison(metric: str, matrix_size): S = pairwise_distances(X, Y, metric=metric) if (element_count <= 2000000): - S2 = sklearn_pairwise_distances(X, Y, metric=metric) + S2 = ref_dense_pairwise_dist(X, Y, metric=metric) cp.testing.assert_array_almost_equal(S, S2, decimal=compare_precision) @@ -969,8 +992,12 @@ def test_pairwise_distances_one_dimension_order(metric: str): # can break down when using a size of 1 for either dimension rng = np.random.RandomState(2) - Xc = rng.random_sample((1, 4)) - Yc = rng.random_sample((10, 4)) + Xc = prep_dense_array(rng.random_sample((1, 4)), + metric=metric, + col_major=0) + Yc = prep_dense_array(rng.random_sample((10, 4)), + metric=metric, + col_major=0) Xf = np.asfortranarray(Xc) Yf = np.asfortranarray(Yc) @@ -979,52 +1006,54 @@ def test_pairwise_distances_one_dimension_order(metric: str): # Compare to sklearn, C/C order S = pairwise_distances(Xc, Yc, metric=metric) - S2 = sklearn_pairwise_distances(Xc, Yc, metric=metric) + S2 = ref_dense_pairwise_dist(Xc, Yc, metric=metric) cp.testing.assert_array_almost_equal(S, S2, decimal=compare_precision) # Compare to sklearn, C/F order S = pairwise_distances(Xc, Yf, metric=metric) - S2 = sklearn_pairwise_distances(Xc, Yf, metric=metric) + S2 = ref_dense_pairwise_dist(Xc, Yf, metric=metric) cp.testing.assert_array_almost_equal(S, S2, decimal=compare_precision) # Compare to sklearn, F/C order S = pairwise_distances(Xf, Yc, metric=metric) - S2 = sklearn_pairwise_distances(Xf, Yc, metric=metric) + S2 = ref_dense_pairwise_dist(Xf, Yc, metric=metric) cp.testing.assert_array_almost_equal(S, S2, decimal=compare_precision) # Compare to sklearn, F/F order S = pairwise_distances(Xf, Yf, metric=metric) - S2 = sklearn_pairwise_distances(Xf, Yf, metric=metric) + S2 = ref_dense_pairwise_dist(Xf, Yf, metric=metric) cp.testing.assert_array_almost_equal(S, S2, decimal=compare_precision) # Switch which input has single dimension - Xc = rng.random_sample((1, 4)) - Yc = rng.random_sample((10, 4)) + Xc = prep_dense_array(rng.random_sample((1, 4)), + metric=metric, col_major=0) + Yc = prep_dense_array(rng.random_sample((10, 4)), + metric=metric, col_major=0) Xf = np.asfortranarray(Xc) Yf = np.asfortranarray(Yc) # Compare to sklearn, C/C order S = pairwise_distances(Xc, Yc, metric=metric) - S2 = sklearn_pairwise_distances(Xc, Yc, metric=metric) + S2 = ref_dense_pairwise_dist(Xc, Yc, metric=metric) cp.testing.assert_array_almost_equal(S, S2, decimal=compare_precision) # Compare to sklearn, C/F order S = pairwise_distances(Xc, Yf, metric=metric) - S2 = sklearn_pairwise_distances(Xc, Yf, metric=metric) + S2 = ref_dense_pairwise_dist(Xc, Yf, metric=metric) cp.testing.assert_array_almost_equal(S, S2, decimal=compare_precision) # Compare to sklearn, F/C order S = pairwise_distances(Xf, Yc, metric=metric) - S2 = sklearn_pairwise_distances(Xf, Yc, metric=metric) + S2 = ref_dense_pairwise_dist(Xf, Yc, metric=metric) cp.testing.assert_array_almost_equal(S, S2, decimal=compare_precision) # Compare to sklearn, F/F order S = pairwise_distances(Xf, Yf, metric=metric) - S2 = sklearn_pairwise_distances(Xf, Yf, metric=metric) + S2 = ref_dense_pairwise_dist(Xf, Yf, metric=metric) cp.testing.assert_array_almost_equal(S, S2, decimal=compare_precision) -@pytest.mark.parametrize("metric", ["haversine", "nan_euclidean", "canberra"]) +@pytest.mark.parametrize("metric", ["haversine", "nan_euclidean"]) def test_pairwise_distances_unsuppored_metrics(metric): rng = np.random.RandomState(3)