"""Balanced accuracy score type.

The score is the mean of the per-class recalls read off the confusion
matrix, optionally rescaled so that chance level maps to 0 and a perfect
prediction maps to 1.
"""
import warnings

import numpy as np

from sklearn.metrics import confusion_matrix
# Private scikit-learn helper; nothing visible in this module references it.
from sklearn.metrics.classification import _check_targets  # noqa: F401

from .classifier_base import ClassifierBaseScoreType


def _balanced_accuracy_score(y_true, y_pred, sample_weight=None,
                             adjusted=True):
    """Compute the (optionally chance-adjusted) balanced accuracy.

    Port of the balanced accuracy implementation from scikit-learn 0.20.

    Parameters
    ----------
    y_true, y_pred
        True and predicted labels, passed unchanged to
        ``sklearn.metrics.confusion_matrix``.
    sample_weight : optional
        Forwarded to ``confusion_matrix``.
    adjusted : bool, default True
        If true, rescale the mean recall ``s`` to
        ``(s - 1/n) / (1 - 1/n)``, where ``n`` is the number of classes
        that remain after dropping those absent from ``y_true``.

    Returns
    -------
    score : float
        Mean per-class recall, rescaled when ``adjusted`` is true.

    Warns
    -----
    UserWarning
        When a row of the confusion matrix sums to zero, i.e. a class
        appears in ``y_pred`` but never in ``y_true``.  Such classes are
        excluded from the mean.
    """
    C = confusion_matrix(y_true, y_pred, sample_weight=sample_weight)
    # Rows with no true samples yield 0/0 = NaN; silence the NumPy warning
    # here and report the situation with our own message below.
    with np.errstate(divide='ignore', invalid='ignore'):
        # true_divide keeps this a float division even where `/` on
        # integer arrays would floor (Python 2 without __future__.division).
        per_class = np.true_divide(np.diag(C), C.sum(axis=1))
    undefined = np.isnan(per_class)
    if np.any(undefined):
        warnings.warn('y_pred contains classes not in y_true')
        per_class = per_class[~undefined]
    score = np.mean(per_class)
    if adjusted:
        # Float literal for the same reason as true_divide above: an integer
        # division would make `chance` 0 for any n_classes > 1.
        chance = 1.0 / len(per_class)
        # NOTE: with a single remaining class, 1 - chance is 0 and the
        # rescaled score is undefined.
        score = (score - chance) / (1 - chance)
    return score


class BalancedAccuracy(ClassifierBaseScoreType):
    """Score type wrapping ``_balanced_accuracy_score``.

    Parameters
    ----------
    name : str
        Stored as ``self.name``.
    precision : int
        Stored as ``self.precision``.
    adjusted : bool
        Forwarded to ``_balanced_accuracy_score`` on every call.
    """

    # NOTE(review): the patch this module was rebuilt from hides one
    # class-body line at this position (between its two hunks); restore it
    # from the repository before relying on this file.
    minimum = 0.0
    maximum = 1.0

    def __init__(self, name='balanced_accuracy', precision=2, adjusted=True):
        self.name = name
        self.precision = precision
        self.adjusted = adjusted

    def __call__(self, y_true_label_index, y_pred_label_index):
        """Return the balanced accuracy of the predicted label indices."""
        return _balanced_accuracy_score(y_true_label_index,
                                        y_pred_label_index,
                                        adjusted=self.adjusted)