diff --git a/cpp/include/cuml/metrics/metrics.hpp b/cpp/include/cuml/metrics/metrics.hpp
index f1f9a3d218..123359878f 100644
--- a/cpp/include/cuml/metrics/metrics.hpp
+++ b/cpp/include/cuml/metrics/metrics.hpp
@@ -295,6 +295,7 @@ double completeness_score(const raft::handle_t& handle,
  * @param n: Number of elements in y and y_hat
  * @param lower_class_range: the lowest value in the range of classes
  * @param upper_class_range: the highest value in the range of classes
+ * @param beta: Ratio of weight attributed to homogeneity vs completeness
  * @return: The v-measure
  */
 double v_measure(const raft::handle_t& handle,
@@ -302,7 +303,8 @@ double v_measure(const raft::handle_t& handle,
                  const int* y_hat,
                  const int n,
                  const int lower_class_range,
-                 const int upper_class_range);
+                 const int upper_class_range,
+                 double beta);
 
 /**
  * Calculates the "accuracy" between two input numpy arrays/ cudf series
diff --git a/cpp/src/metrics/v_measure.cu b/cpp/src/metrics/v_measure.cu
index f71091543a..0a79f8020a 100644
--- a/cpp/src/metrics/v_measure.cu
+++ b/cpp/src/metrics/v_measure.cu
@@ -1,6 +1,6 @@
 /*
- * Copyright (c) 2019-2021, NVIDIA CORPORATION.
+ * Copyright (c) 2019-2022, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -27,10 +27,11 @@ double v_measure(const raft::handle_t& handle,
                  const int* y_hat,
                  const int n,
                  const int lower_class_range,
-                 const int upper_class_range)
+                 const int upper_class_range,
+                 double beta)
 {
   return MLCommon::Metrics::v_measure(
-    y, y_hat, n, lower_class_range, upper_class_range, handle.get_stream());
+    y, y_hat, n, lower_class_range, upper_class_range, handle.get_stream(), beta);
 }
 }  // namespace Metrics
 }  // namespace ML
diff --git a/cpp/src_prims/metrics/v_measure.cuh b/cpp/src_prims/metrics/v_measure.cuh
index e0396c5702..9ce462271e 100644
--- a/cpp/src_prims/metrics/v_measure.cuh
+++ b/cpp/src_prims/metrics/v_measure.cuh
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2021, NVIDIA CORPORATION.
+ * Copyright (c) 2019-2022, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
diff --git a/docs/source/api.rst b/docs/source/api.rst
index f0ec1da5cf..f322f25a95 100644
--- a/docs/source/api.rst
+++ b/docs/source/api.rst
@@ -229,6 +229,9 @@ Metrics (clustering and manifold learning)
 .. automodule:: cuml.metrics.cluster.mutual_info_score
     :members:
 
+.. automodule:: cuml.metrics.cluster.v_measure
+    :members:
+
 Benchmarking
 -------------
 
diff --git a/python/cuml/metrics/__init__.py b/python/cuml/metrics/__init__.py
index 493f6c7c47..3d3791a8ee 100644
--- a/python/cuml/metrics/__init__.py
+++ b/python/cuml/metrics/__init__.py
@@ -40,6 +40,8 @@ from cuml.metrics.pairwise_kernels import PAIRWISE_KERNEL_FUNCTIONS
 from cuml.metrics.hinge_loss import hinge_loss
 from cuml.metrics.kl_divergence import kl_divergence
+from cuml.metrics.cluster.v_measure import \
+    cython_v_measure as v_measure_score
 
 
 __all__ = [
     "trustworthiness",
@@ -62,4 +64,5 @@
     "pairwise_kernels",
     "hinge_loss",
     "kl_divergence",
+    "v_measure_score"
 ]
diff --git a/python/cuml/metrics/cluster/__init__.py b/python/cuml/metrics/cluster/__init__.py
index 95a263ec6e..023e599789 100644
--- a/python/cuml/metrics/cluster/__init__.py
+++ b/python/cuml/metrics/cluster/__init__.py
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2021, NVIDIA CORPORATION.
+# Copyright (c) 2021-2022, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -26,3 +26,5 @@
     cython_silhouette_score as silhouette_score
 from cuml.metrics.cluster.silhouette_score import \
     cython_silhouette_samples as silhouette_samples
+from cuml.metrics.cluster.v_measure import \
+    cython_v_measure as v_measure_score
diff --git a/python/cuml/metrics/cluster/v_measure.pyx b/python/cuml/metrics/cluster/v_measure.pyx
new file mode 100644
index 0000000000..fa98cff026
--- /dev/null
+++ b/python/cuml/metrics/cluster/v_measure.pyx
@@ -0,0 +1,103 @@
+#
+# Copyright (c) 2022, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# distutils: language = c++
+
+import cuml.internals
+from raft.common.handle cimport handle_t
+from libc.stdint cimport uintptr_t
+from cuml.metrics.cluster.utils import prepare_cluster_metric_inputs
+from raft.common.handle import Handle
+
+
+cdef extern from "cuml/metrics/metrics.hpp" namespace "ML::Metrics":
+    double v_measure(const handle_t & handle,
+                     const int * y,
+                     const int * y_hat,
+                     const int n,
+                     const int lower_class_range,
+                     const int upper_class_range,
+                     const double beta) except +
+
+
+@cuml.internals.api_return_any()
+def cython_v_measure(labels_true, labels_pred, beta=1.0, handle=None) -> float:
+    """
+    V-measure metric of a cluster labeling given a ground truth.
+
+    The V-measure is the harmonic mean between homogeneity and completeness::
+
+        v = (1 + beta) * homogeneity * completeness
+             / (beta * homogeneity + completeness)
+
+    This metric is independent of the absolute values of the labels:
+    a permutation of the class or cluster label values won't change the
+    score value in any way.
+
+    This metric is symmetric for the default ``beta=1.0``: switching
+    ``labels_true`` with ``labels_pred`` will return the same score value.
+    This can be useful to measure the agreement of two independent label
+    assignment strategies on the same dataset when ground truth is unknown.
+
+    Parameters
+    ----------
+    labels_true : array-like (device or host) shape = (n_samples,)
+        The ground truth labels (ints) of the test dataset.
+        Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device
+        ndarray, cuda array interface compliant array like CuPy
+    labels_pred : array-like (device or host) shape = (n_samples,)
+        The labels predicted by the model for the test dataset.
+        Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device
+        ndarray, cuda array interface compliant array like CuPy
+    beta : float, default=1.0
+        Ratio of weight attributed to ``homogeneity`` vs ``completeness``.
+        If ``beta`` is greater than 1, ``completeness`` is weighted more
+        strongly in the calculation. If ``beta`` is less than 1,
+        ``homogeneity`` is weighted more strongly.
+    handle : cuml.Handle
+        Specifies the cuml.handle that holds internal CUDA state for
+        computations in this model. Most importantly, this specifies the CUDA
+        stream that will be used for the model's computations, so users can
+        run different models concurrently in different streams by creating
+        handles in several streams.
+        If it is None, a new one is created.
+
+    Returns
+    -------
+    v_measure_value : float
+        score between 0.0 and 1.0. 1.0 stands for perfectly complete labeling
+    """
+    handle = Handle() if handle is None else handle
+    cdef handle_t *handle_ = <handle_t*> <size_t> handle.getHandle()
+
+    (y_true, y_pred, n_rows,
+     lower_class_range, upper_class_range) = prepare_cluster_metric_inputs(
+        labels_true,
+        labels_pred
+    )
+
+    cdef uintptr_t ground_truth_ptr = y_true.ptr
+    cdef uintptr_t preds_ptr = y_pred.ptr
+
+    v_measure_value = v_measure(handle_[0],
+                                <int*> ground_truth_ptr,
+                                <int*> preds_ptr,
+                                n_rows,
+                                lower_class_range,
+                                upper_class_range,
+                                beta)
+
+    return v_measure_value
diff --git a/python/cuml/tests/test_metrics.py b/python/cuml/tests/test_metrics.py
index 7f03838342..4ded86b86b 100644
--- a/python/cuml/tests/test_metrics.py
+++ b/python/cuml/tests/test_metrics.py
@@ -81,6 +81,8 @@
 from sklearn.metrics import pairwise_distances as sklearn_pairwise_distances
 from scipy.spatial import distance as scipy_pairwise_distances
 from scipy.special import rel_entr as scipy_kl_divergence
+from sklearn.metrics.cluster import v_measure_score as sklearn_v_measure_score
+from cuml.metrics.cluster import v_measure_score
 
 
 @pytest.fixture(scope='module')
@@ -1485,3 +1487,12 @@ def test_mean_squared_error_cudf_series():
     err1 = mean_squared_error(a, b)
     err2 = mean_squared_error(a.values, b.values)
     assert err1 == err2
+
+
+@pytest.mark.parametrize("beta", [0.0, 0.5, 1.0, 2.0])
+def test_v_measure_score(beta):
+    labels_true = np.array([0, 0, 1, 1], dtype=np.int32)
+    labels_pred = np.array([1, 0, 1, 1], dtype=np.int32)
+    res = v_measure_score(labels_true, labels_pred, beta=beta)
+    ref = sklearn_v_measure_score(labels_true, labels_pred, beta=beta)
+    assert_almost_equal(res, ref)
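
Example usage of the new Python API: a minimal sketch, assuming a cuML build that includes this
change and a CUDA-capable GPU. scikit-learn's homogeneity_score and completeness_score are used
only to spell out the beta-weighted formula documented in the new docstring, mirroring the test
added above.

    import numpy as np

    from cuml.metrics.cluster import v_measure_score
    from sklearn.metrics import completeness_score, homogeneity_score

    labels_true = np.array([0, 0, 1, 1], dtype=np.int32)
    labels_pred = np.array([1, 0, 1, 1], dtype=np.int32)

    # Reference homogeneity/completeness values from scikit-learn.
    h = homogeneity_score(labels_true, labels_pred)
    c = completeness_score(labels_true, labels_pred)

    # beta > 1 weights completeness more strongly; beta < 1 weights homogeneity.
    for beta in (0.5, 1.0, 2.0):
        v = v_measure_score(labels_true, labels_pred, beta=beta)

        # Cross-check against the documented formula:
        #   v = (1 + beta) * h * c / (beta * h + c)
        expected = (1 + beta) * h * c / (beta * h + c)
        assert abs(v - expected) < 1e-6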