From f84b91c94940dbba983fe3d3fe8e1f241414394e Mon Sep 17 00:00:00 2001 From: Divye Gala Date: Fri, 16 Apr 2021 11:09:33 -0700 Subject: [PATCH] `convert_dtype` in ConfusionMatrix (#3754) closes #3567 Authors: - Divye Gala (https://github.com/divyegala) Approvers: - Victor Lafargue (https://github.com/viclafargue) - Dante Gama Dessavre (https://github.com/dantegd) URL: https://github.com/rapidsai/cuml/pull/3754 --- python/cuml/metrics/confusion_matrix.py | 24 +++++++++++++++++------- python/cuml/test/test_metrics.py | 5 +++-- 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/python/cuml/metrics/confusion_matrix.py b/python/cuml/metrics/confusion_matrix.py index b0cc952784..5dfb92339e 100644 --- a/python/cuml/metrics/confusion_matrix.py +++ b/python/cuml/metrics/confusion_matrix.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2020, NVIDIA CORPORATION. +# Copyright (c) 2020-2021, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -27,11 +27,12 @@ from cuml.prims.label import make_monotonic -@cuml.internals.api_return_array(get_output_type=True) +@cuml.internals.api_return_any() def confusion_matrix(y_true, y_pred, labels=None, sample_weight=None, - normalize=None) -> CumlArray: + normalize=None, + convert_dtype=False) -> CumlArray: """Compute confusion matrix to evaluate the accuracy of a classification. Parameters @@ -52,6 +53,9 @@ def confusion_matrix(y_true, y_pred, Normalizes confusion matrix over the true (rows), predicted (columns) conditions or all the population. If None, confusion matrix will not be normalized. + convert_dtype : bool, optional (default = False) + When set to True, the confusion matrix method will automatically + convert the predictions, ground truth, and labels arrays to np.int32. Returns ------- @@ -59,18 +63,24 @@ def confusion_matrix(y_true, y_pred, Confusion matrix. """ y_true, n_rows, n_cols, dtype = \ - input_to_cuml_array(y_true, check_dtype=[cp.int32, cp.int64]) + input_to_cuml_array(y_true, check_dtype=[cp.int32, cp.int64], + convert_to_dtype=(cp.int32 if convert_dtype + else None)) y_pred, _, _, _ = \ - input_to_cuml_array(y_pred, check_dtype=dtype, - check_rows=n_rows, check_cols=n_cols) + input_to_cuml_array(y_pred, check_dtype=[cp.int32, cp.int64], + check_rows=n_rows, check_cols=n_cols, + convert_to_dtype=(cp.int32 if convert_dtype + else None)) if labels is None: labels = sorted_unique_labels(y_true, y_pred) n_labels = len(labels) else: labels, n_labels, _, _ = \ - input_to_cupy_array(labels, check_dtype=dtype, check_cols=1) + input_to_cupy_array(labels, check_dtype=[cp.int32, cp.int64], + convert_to_dtype=(cp.int32 if convert_dtype + else None), check_cols=1) if sample_weight is None: sample_weight = cp.ones(n_rows, dtype=dtype) else: diff --git a/python/cuml/test/test_metrics.py b/python/cuml/test/test_metrics.py index c5c9ff935e..99da0343a1 100644 --- a/python/cuml/test/test_metrics.py +++ b/python/cuml/test/test_metrics.py @@ -653,14 +653,15 @@ def test_confusion_matrix_binary(): @pytest.mark.parametrize('n_samples', [50, 3000, stress_param(500000)]) -@pytest.mark.parametrize('dtype', [np.int32, np.int64]) +@pytest.mark.parametrize('dtype', [np.int32, np.int64, np.float32]) @pytest.mark.parametrize('problem_type', ['binary', 'multiclass']) def test_confusion_matrix_random(n_samples, dtype, problem_type): upper_range = 2 if problem_type == 'binary' else 1000 y_true, y_pred, _, _ = generate_random_labels( lambda rng: rng.randint(0, upper_range, n_samples).astype(dtype)) - cm = confusion_matrix(y_true, y_pred) + convert_dtype = True if dtype == np.float32 else False + cm = confusion_matrix(y_true, y_pred, convert_dtype=convert_dtype) ref = sk_confusion_matrix(y_true, y_pred) cp.testing.assert_array_almost_equal(ref, cm, decimal=4)