Skip to content

Commit

Permalink
convert_dtype in ConfusionMatrix (rapidsai#3754)
Browse files: browse the repository at this point in the history
closes rapidsai#3567

Authors:
  - Divye Gala (https://github.com/divyegala)

Approvers:
  - Victor Lafargue (https://github.com/viclafargue)
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: https://github.com/rapidsai/cuml/pull/3754
  • Loading branch information
divyegala authored Apr 16, 2021
1 parent b8476fb commit f84b91c
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 9 deletions.
24 changes: 17 additions & 7 deletions python/cuml/metrics/confusion_matrix.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2020, NVIDIA CORPORATION.
# Copyright (c) 2020-2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -27,11 +27,12 @@
from cuml.prims.label import make_monotonic


@cuml.internals.api_return_array(get_output_type=True)
@cuml.internals.api_return_any()
def confusion_matrix(y_true, y_pred,
labels=None,
sample_weight=None,
normalize=None) -> CumlArray:
normalize=None,
convert_dtype=False) -> CumlArray:
"""Compute confusion matrix to evaluate the accuracy of a classification.
Parameters
Expand All @@ -52,25 +53,34 @@ def confusion_matrix(y_true, y_pred,
Normalizes confusion matrix over the true (rows), predicted (columns)
conditions or all the population. If None, confusion matrix will not be
normalized.
convert_dtype : bool, optional (default = False)
When set to True, the confusion matrix method will automatically
convert the predictions, ground truth, and labels arrays to np.int32.
Returns
-------
C : array-like (device or host) shape = (n_classes, n_classes)
Confusion matrix.
"""
y_true, n_rows, n_cols, dtype = \
input_to_cuml_array(y_true, check_dtype=[cp.int32, cp.int64])
input_to_cuml_array(y_true, check_dtype=[cp.int32, cp.int64],
convert_to_dtype=(cp.int32 if convert_dtype
else None))

y_pred, _, _, _ = \
input_to_cuml_array(y_pred, check_dtype=dtype,
check_rows=n_rows, check_cols=n_cols)
input_to_cuml_array(y_pred, check_dtype=[cp.int32, cp.int64],
check_rows=n_rows, check_cols=n_cols,
convert_to_dtype=(cp.int32 if convert_dtype
else None))

if labels is None:
labels = sorted_unique_labels(y_true, y_pred)
n_labels = len(labels)
else:
labels, n_labels, _, _ = \
input_to_cupy_array(labels, check_dtype=dtype, check_cols=1)
input_to_cupy_array(labels, check_dtype=[cp.int32, cp.int64],
convert_to_dtype=(cp.int32 if convert_dtype
else None), check_cols=1)
if sample_weight is None:
sample_weight = cp.ones(n_rows, dtype=dtype)
else:
Expand Down
5 changes: 3 additions & 2 deletions python/cuml/test/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,14 +653,15 @@ def test_confusion_matrix_binary():


@pytest.mark.parametrize('n_samples', [50, 3000, stress_param(500000)])
@pytest.mark.parametrize('dtype', [np.int32, np.int64])
@pytest.mark.parametrize('dtype', [np.int32, np.int64, np.float32])
@pytest.mark.parametrize('problem_type', ['binary', 'multiclass'])
def test_confusion_matrix_random(n_samples, dtype, problem_type):
upper_range = 2 if problem_type == 'binary' else 1000

y_true, y_pred, _, _ = generate_random_labels(
lambda rng: rng.randint(0, upper_range, n_samples).astype(dtype))
cm = confusion_matrix(y_true, y_pred)
convert_dtype = True if dtype == np.float32 else False
cm = confusion_matrix(y_true, y_pred, convert_dtype=convert_dtype)
ref = sk_confusion_matrix(y_true, y_pred)
cp.testing.assert_array_almost_equal(ref, cm, decimal=4)

Expand Down

0 comments on commit f84b91c

Please sign in to comment.