diff --git a/Makefile b/Makefile index 447b0aa4..88e643cf 100644 --- a/Makefile +++ b/Makefile @@ -3,7 +3,7 @@ install: build: python3 -m pip install --upgrade build - python3 -m build + python3 -m build upload_test: python3 -m twine upload --repository testpypi dist/* diff --git a/examples/worst_case.py b/examples/worst_case.py new file mode 100644 index 00000000..f6c8b04a --- /dev/null +++ b/examples/worst_case.py @@ -0,0 +1,318 @@ +""" +This task evaluates a set of metrics, mostly related to worst-class performance, as described in +(J. Bitterwolf et al., "Classifiers Should Do Well Even on Their Worst Classes", https://openreview.net/forum?id=QxIXCVYJ2WP). +It is motivated by +(R. Balestriero et al., "The Effects of Regularization and Data Augmentation are Class Dependent", https://arxiv.org/abs/2204.03632) + where the authors note that using only accuracy as a metric is not enough to evaluate + the performance of the classifier, as it must not be the same on all classes/groups. +""" +import argparse +import collections +import itertools +import time +import urllib +import requests +import dataclasses +import os +import pathlib + +import numpy as np + +from shifthappens.data import imagenet as sh_imagenet +from shifthappens import benchmark as sh_benchmark +from shifthappens.models import base as sh_models +from shifthappens.tasks.base import Task +from shifthappens.tasks.metrics import Metric +from shifthappens.tasks.task_result import TaskResult + + +@sh_benchmark.register_task( + name="Worst_case", relative_data_folder="worst_case", standalone=True +) +@dataclasses.dataclass +class WorstCase(Task): + resources = ( + ["worstcase", + "restricted_superclass.csv", + "https://anonymous.4open.science/r/worst_classes-B94C/restricted_superclass.csv", + None], + + ["worstcase", + "new_labels.csv", + "https://anonymous.4open.science/r/worst_classes-B94C/new_labels.csv", + None] + ) + + new_labels = None + new_labels_mask = None + superclasses = None + + verbose = True + labels_type = 'val' + n_retries = 5 + max_batch_size: int = 256 + + def download(self, url, data_folder, filename, md5): + + for _ in range(self.n_retries): + try: + r = requests.get(url) + pathlib.Path(data_folder).mkdir(parents=True, exist_ok=True) + open(os.path.join(data_folder, filename), 'wb').write(r.content) + break + except urllib.error.URLError: + print(f"Download of {url} failed; wait 5s and then try again.") + time.sleep(5) + def setup(self): + + # Download resources + for resource in self.resources: + folder_name, file_name, url, md5 = resource + dataset_folder = os.path.join(self.data_root, folder_name) + if not os.path.isfile(os.path.join(dataset_folder, file_name)): + self.download(url, dataset_folder, file_name, md5) + print(f'File {file_name} is in {dataset_folder}.') + # Set the cleaned labels to a property + new_labels = np.array([int(line) for line in open(os.path.join(dataset_folder, 'new_labels.csv'))]) + + if self.labels_type == 'val_clean': + gooduns = new_labels != -1 + self.new_labels = new_labels[gooduns] + elif self.labels_type == 'val': + gooduns = np.full(new_labels.shape, True) + self.new_labels = np.array(sh_imagenet.load_imagenet_targets()) + + self.new_labels_mask = gooduns + + # Set the superclasses to a property + superclass_list = np.array([int(line) for line in open(os.path.join(dataset_folder, 'restricted_superclass.csv'))]) + self.superclasses = [tuple(np.where(superclass_list == i)[0]) for i in range(0, 9)] + + def get_predictions(self) -> np.ndarray: + preds = { + 'predicted_classes': self.probs.argmax(axis=1), + 'class_probabilities': self.probs, + 'confidences_classifier': self.probs.max(axis=1), + } + preds['number_of_class_predictions'] = collections.Counter(preds['predicted_classes']) + return preds + + def standard_accuracy(self): + preds = self.get_predictions() + accuracy = (preds['predicted_classes'] == self.new_labels).mean() + return accuracy + + def classwise_accuracies(self): + preds = self.get_predictions() + clw_acc = {} + for i in set(self.new_labels): + clw_acc[i] = np.equal(preds['predicted_classes'][np.where(self.new_labels == i)], i).mean() + return clw_acc + + def classwise_sample_numbers(self): + clw_sn = {} + for i in set(self.new_labels): + clw_sn[i] = np.sum(self.new_labels == i) + return clw_sn + + def classwise_topk_accuracies(self, k): + preds = self.get_predictions() + clw_topk_acc = {} + for i in set(self.new_labels): + clw_topk_acc[i] = np.equal(i, np.argsort(preds['class_probabilities'][np.where(self.new_labels == i)], axis=1, kind='mergesort')[:, + -k:]).sum(axis=-1).mean() + return clw_topk_acc + + def standard_balanced_topk_accuracy(self, k): + clw_topk_acc = self.classwise_topk_accuracies(k) + return np.array(list(clw_topk_acc.values())).mean() + + def worst_class_accuracy(self): + cwa = self.classwise_accuracies() + worst_item = min(cwa.items(), key=lambda x: x[1]) + return worst_item[1] + + def worst_class_topk_accuracy(self, k): + clw_topk_acc = self.classwise_topk_accuracies(k) + worst_item = min(clw_topk_acc.items(), key=lambda x: x[1]) + return worst_item[1] + + def worst_balanced_n_classes_accuracy(self, n): + cwa = self.classwise_accuracies() + sorted_cwa = sorted(cwa.items(), key=lambda item: item[1]) + n_worst = sorted_cwa[:n] + return np.array([x[1] for x in n_worst]).mean() + + def worst_heuristic_n_classes_recall(self, n): + cwa = self.classwise_accuracies() + clw_sn = self.classwise_sample_numbers() + sorted_cwa = sorted(cwa.items(), key=lambda item: item[1]) + n_worst = sorted_cwa[:n] + nwc = np.array([v * clw_sn[c] for c, v in n_worst]).sum() / np.array([clw_sn[c] for c, v in n_worst]).sum() + return nwc + + def worst_balanced_n_classes_topk_accuracy(self, n, k): + clw_topk_acc = self.classwise_topk_accuracies(k) + sorted_clw_topk_acc = sorted(clw_topk_acc.items(), key=lambda item: item[1]) + n_worst = sorted_clw_topk_acc[:n] + return np.array([x[1] for x in n_worst]).mean() + + def worst_heuristic_n_classes_topk_recall(self, n, k): + clw_topk_acc = self.classwise_topk_accuracies(k) + clw_sn = self.classwise_sample_numbers() + sorted_clw_topk_acc = sorted(clw_topk_acc.items(), key=lambda item: item[1]) + n_worst = sorted_clw_topk_acc[:n] + nwc = np.array([v * clw_sn[c] for c, v in n_worst]).sum() / np.array([clw_sn[c] for c, v in n_worst]).sum() + return nwc + + def worst_balanced_two_class_binary_accuracy(self): + classes = list(set(self.new_labels)) + binary_accuracies = {} + for i, j in itertools.combinations(classes, 2): + i_labelled = self.probs[np.where(self.new_labels == i)] + j_labelled = self.probs[np.where(self.new_labels == j)] + i_correct = np.greater(i_labelled[:, i], i_labelled[:, j]).mean() + j_correct = np.greater(j_labelled[:, j], j_labelled[:, i]).mean() + binary_accuracies[(i, j)] = (i_correct + j_correct) / 2 + sorted_binary_accuracies = sorted(binary_accuracies.items(), key=lambda item: item[1]) + worst_item = sorted_binary_accuracies[0] + return worst_item[1] + + def worst_balanced_superclass_recall(self): + cwa = self.classwise_accuracies() + scwa = {i: np.array([cwa[c] for c in s]).mean() for i, s in enumerate(self.superclasses)} + worst_item = min(scwa.items(), key=lambda x: x[1]) + return worst_item[1] + + def worst_superclass_recall(self): + cwa = self.classwise_accuracies() + clw_sn = self.classwise_sample_numbers() + scwa = {i: np.array([cwa[c] * clw_sn[c] for c in s]).sum() / np.array([clw_sn[c] for c in s]).sum() for i, s in + enumerate(self.superclasses)} + worst_item = min(scwa.items(), key=lambda x: x[1]) + return worst_item[1] + + def intra_superclass_accuracies(self): + isa = {} + original_probs = self.probs.copy() + original_targets = self.new_labels.copy() + for i, s in enumerate(self.superclasses): + self.probs = original_probs.copy() + self.new_labels = original_targets.copy() + + internal_samples = np.isin(self.new_labels, s) + internal_targets = self.new_labels[internal_samples] + internal_probs = self.probs[internal_samples][:, s] + s_targets = np.vectorize(lambda x: s[x]) + self.probs = internal_probs + self.new_labels = internal_targets + internal_preds = s_targets(self.get_predictions()['predicted_classes']) + isa[i] = (internal_preds == internal_targets).mean() + + self.probs = original_probs + self.new_labels = original_targets + + return isa + + def worst_intra_superclass_accuracy(self): + isa = self.intra_superclass_accuracies() + worst_item = min(isa.items(), key=lambda x: x[1]) + return worst_item[1] + + def worst_class_precision(self): + preds = self.get_predictions() + classes = list(set(self.new_labels)) + sc = {} + for c in classes: + erroneous_c = (preds['predicted_classes'] == c) * (self.new_labels != c) + correct_c = (preds['predicted_classes'] == c) * (self.new_labels == c) + predicted_c = (preds['predicted_classes'] == c) + if predicted_c.sum(): + sc[c] = correct_c.sum() / predicted_c.sum() # 1-erroneous_c.sum()/predicted_c.sum() + else: + sc[c] = 1 + sorted_sc = sorted(sc.items(), key=lambda item: item[1]) + worst_item = sorted_sc[0] + return worst_item[1] + + def class_confusion(self): + preds = self.get_predictions() + classes = list(set(self.new_labels)) + confusion = np.zeros((len(classes), len(classes))) + for i, c in enumerate(self.new_labels): + confusion[c, preds['predicted_classes'][i]] += 1 + return confusion + + + def _evaluate(self, model: sh_models.Model, verbose=False) -> TaskResult: + + verbose = self.verbose + model.verbose = verbose + + if verbose: + print(f'new labels of type {self.labels_type} are', self.new_labels, len(self.new_labels)) + + self.probs = model.imagenet_validation_result.confidences[self.new_labels_mask, :] + + metrics = { + 'A': self.standard_accuracy, + 'WCA': self.worst_class_accuracy, + 'WCP': self.worst_class_precision, + 'WSupCA': self.worst_intra_superclass_accuracy, + 'WSupCR': self.worst_superclass_recall, + 'W10CR': lambda : self.worst_heuristic_n_classes_recall(10), + 'W100CR': lambda : self.worst_heuristic_n_classes_recall(100), + 'W2CA': self.worst_balanced_two_class_binary_accuracy, + 'WCAat5': lambda : self.worst_class_topk_accuracy(5), + 'W10CRat5': lambda : self.worst_heuristic_n_classes_topk_recall(10, 5), + 'W100CRat5': lambda : self.worst_heuristic_n_classes_topk_recall(100, 5), + } + + metrics_eval = {} + for metric_name, metric in metrics.items(): + if verbose: + print(f'Evaluating {metric_name}') + metrics_eval[metric_name] = metric() + if verbose: + print('metrics are', metrics_eval) + return TaskResult( + summary_metrics={Metric.Fairness: ("A", "WCA", "WCP", "WSupCA", "WSupCR", + "W10CR", "W100CR", "W2CA", "WCAat5", + "W10CRat5", "W100CRat5")}, + **metrics_eval + + ) + + +if __name__ == "__main__": + from shifthappens.models.torchvision import ResNet18 + import shifthappens + + parser = argparse.ArgumentParser() + + # Set the label type either to val (50000 labels) or + # val_clean (46044 labels) for the cleaned labels from + # (C. Northcutt et al., "Pervasive Label Errors in Test Sets Destabilize Machine Learning Benchmarks", https://arxiv.org/abs/2103.14749, https://github.com/cleanlab/label-errors) + parser.add_argument( + "--labels_type", type=str, help="The label type", default='val' + ) + parser.add_argument( + "--imagenet_val_folder", type=str, help="The folder for the imagenet val set", required=True + ) + parser.add_argument( + "--verbose", + help="Turn verbose mode on when set", + action="store_true", + ) + + args = parser.parse_args() + + shifthappens.data.imagenet.ImageNetValidationData=args.imagenet_val_folder + + tuple(sh_benchmark.__registered_tasks)[0].cls.verbose = args.verbose + + tuple(sh_benchmark.__registered_tasks)[0].cls.labels_type = args.labels_type + sh_benchmark.evaluate_model( + ResNet18(device="cuda:2", max_batch_size=500), + "data" + ) diff --git a/requirements.txt b/requirements.txt index ebcc420c..6d015e36 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ numpy torch torchvision -surgeon_pytorch \ No newline at end of file +surgeon_pytorch +tqdm diff --git a/shifthappens/models/base.py b/shifthappens/models/base.py index 25a1aa83..f706668e 100644 --- a/shifthappens/models/base.py +++ b/shifthappens/models/base.py @@ -11,8 +11,10 @@ import abc import dataclasses from typing import Iterator +import shifthappens.config import numpy as np +from tqdm import tqdm from shifthappens.data import imagenet as sh_imagenet from shifthappens.data.base import DataLoader @@ -109,6 +111,7 @@ class Model(abc.ABC): def __init__(self): self._imagenet_validation_result = None + self.verbose = False @property def imagenet_validation_result(self): @@ -205,7 +208,13 @@ def _predict_imagenet_val(self): score_type for score_type in score_types if score_types[score_type] ] } - for prediction in self._predict(imagenet_val_dataloader, targets): + + if shifthappens.config.verbose: + pred_loader = tqdm(self._predict(imagenet_val_dataloader, targets), desc='Predictions', total=int(len(imagenet_val_dataloader._dataset)/imagenet_val_dataloader.max_batch_size)) + else: + pred_loader = self._predict(imagenet_val_dataloader, targets) + + for prediction in pred_loader: for prediction_type in predictions_dict: prediction_score = prediction.__getattribute__(prediction_type) predictions_dict[prediction_type] = sum( diff --git a/shifthappens/tasks/raccoons_ood/raccoons_ood.py b/shifthappens/tasks/raccoons_ood/raccoons_ood.py index f6fd223c..e5f6fe5e 100644 --- a/shifthappens/tasks/raccoons_ood/raccoons_ood.py +++ b/shifthappens/tasks/raccoons_ood/raccoons_ood.py @@ -8,6 +8,7 @@ The original dataset was collected by Dat Tran for the object detection task and can be found at https://github.com/datitran/raccoon_dataset. """ + import dataclasses import os diff --git a/shifthappens/tasks/worst_case/README.rst b/shifthappens/tasks/worst_case/README.rst new file mode 100644 index 00000000..5021ffa5 --- /dev/null +++ b/shifthappens/tasks/worst_case/README.rst @@ -0,0 +1,39 @@ +Example for a Shift Happens task on ImageNet +============================================== +# Task Description +This task evaluates a set of metrics, mostly related to worst-class performance, as described in [1]. +It is motivated by [2], where the authors note that using only accuracy as a metric is not enough to evaluate + the performance of the classifier, as it must not be the same on all classes/groups. + +## How to start +in the icml-2022 folder, run + +``` +python shifthappens/tasks/worst_case/worst_case.py --imagenet_val_folder '/scratch/datasets/imagenet/val' --verbose --labels_type 'val' +``` + +for evaluating with original labels, and + +``` +python shifthappens/tasks/worst_case/worst_case.py --imagenet_val_folder '/scratch/datasets/imagenet/val' --verbose --labels_type 'val_clean' +``` + +for evaluating with cleaned ones from [3]. + + +## Evaluation Metrics +The evaluation metrics are "A", "WCA", "WCP", "WSupCA", "WSupCR", "W10CR", "W100CR", "W2CA", "WCAat5", "W10CRat5", "W100CRat5", and their relevance is described in (J. Bitterwolf et al., "Classifiers Should Do Well Even on Their Worst Classes", https://openreview.net/forum?id=QxIXCVYJ2WP). + +## Expected Insights/Relevance +To see the, how the model performs on its worst classes. The application examples are given in [1]. + + +1. Classifiers Should Do Well Even on Their Worst Classes. + J. Bitterwolf et al. 2022. + +2. The Effects of Regularization and Data Augmentation are Class Dependent. + R. Balestriero et al. 2022. + +3. Pervasive Label Errors in Test Sets Destabilize Machine Learning Benchmarks. + C. Northcutt et al. 2021. + diff --git a/shifthappens/tasks/worst_case/__init__.py b/shifthappens/tasks/worst_case/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/shifthappens/tasks/worst_case/worst_case.py b/shifthappens/tasks/worst_case/worst_case.py new file mode 100644 index 00000000..7995e687 --- /dev/null +++ b/shifthappens/tasks/worst_case/worst_case.py @@ -0,0 +1,385 @@ +"""Classifiers Should Do Well Even on Their Worst Classes""" +import argparse +import collections +import itertools +import time +import urllib +from typing import List, Union + +import requests +import dataclasses +import os +import pathlib +import torch +import numpy as np +import torch.nn as nn +from numpy.core._multiarray_umath import ndarray +import shifthappens.config +from numpy.core.multiarray import ndarray + +from shifthappens.data import imagenet as sh_imagenet +from shifthappens import benchmark as sh_benchmark +from shifthappens.models import base as sh_models +from shifthappens.tasks.base import Task +from shifthappens.tasks.metrics import Metric +from shifthappens.tasks.task_result import TaskResult + + +@sh_benchmark.register_task( + name="Worst_case", relative_data_folder="worst_case", standalone=True +) +@dataclasses.dataclass +class WorstCase(Task): + """This task evaluates a set of metrics, mostly related to worst-class performance, as described in [1]. + It is motivated by [2], where the authors note that using only accuracy as a metric is not enough to evaluate + the performance of the classifier, as it must not be the same on all classes/groups.""" + resources = ( + ["worstcase", + "restricted_superclass.csv", + "https://anonymous.4open.science/r/worst_classes-B94C/restricted_superclass.csv", + None], + + ["worstcase", + "new_labels.csv", + "https://anonymous.4open.science/r/worst_classes-B94C/new_labels.csv", + None] + ) + + new_labels = None + new_labels_mask: Union[ndarray, None, bool] = None + superclasses: List[tuple] = None + + verbose: bool = True + labels_type: str = 'val' + n_retries: int = 5 + max_batch_size: int = 256 + + def download(self, url, data_folder, filename, md5): + """Method to download the data given its' url, and the desired folder to stor int""" + for _ in range(self.n_retries): + try: + r = requests.get(url) + pathlib.Path(data_folder).mkdir(parents=True, exist_ok=True) + open(os.path.join(data_folder, filename), 'wb').write(r.content) + break + except urllib.error.URLError: + print(f"Download of {url} failed; wait 5s and then try again.") + time.sleep(5) + + def setup(self): + """Calls the download method to download the cleaned labels from [3], as well as superclasses used in [1]""" + # Download resources + for resource in self.resources: + folder_name, file_name, url, md5 = resource + dataset_folder = os.path.join(self.data_root, folder_name) + if not os.path.isfile(os.path.join(dataset_folder, file_name)): + self.download(url, dataset_folder, file_name, md5) + print(f'File {file_name} is in {dataset_folder}.') + # Set the cleaned labels to a property + new_labels: ndarray = np.array([int(line) for line in open(os.path.join(dataset_folder, 'new_labels.csv'))]) + + if self.labels_type == 'val_clean': + cleaned_labels = new_labels != -1 + self.new_labels = new_labels[cleaned_labels] + elif self.labels_type == 'val': + cleaned_labels = np.full(new_labels.shape, True) + self.new_labels = np.array(sh_imagenet.load_imagenet_targets()) + + self.new_labels_mask = cleaned_labels + + # Set the superclasses to a property + superclass_list: ndarray = np.array([int(line) for line in open(os.path.join(dataset_folder, 'restricted_superclass.csv'))]) + self.superclasses = [tuple(np.where(superclass_list == i)[0]) for i in range(0, 9)] + + def get_predictions(self) -> np.ndarray: + """Saves to a property as a dict the computed predictions and probabilities for the used model""" + preds = { + 'predicted_classes': self.probs.argmax(axis=1), + 'class_probabilities': self.probs, + 'confidences_classifier': self.probs.max(axis=1), + } + preds['number_of_class_predictions'] = collections.Counter(preds['predicted_classes']) + return preds + + def standard_accuracy(self) -> np.float: + """Computes standard accuracy""" + preds = self.get_predictions() + accuracy = (preds['predicted_classes'] == self.new_labels).mean() + return accuracy + + def classwise_accuracies(self) -> dict: + """Computes accuracies per each class""" + preds = self.get_predictions() + clw_acc = {} + for i in set(self.new_labels): + clw_acc[i] = np.equal(preds['predicted_classes'][np.where(self.new_labels == i)], i).mean() + return clw_acc + + def classwise_sample_numbers(self) -> dict: + """Computes number of samples per class""" + classwise_sample_number = {} + for i in set(self.new_labels): + classwise_sample_number[i] = np.sum(self.new_labels == i) + return classwise_sample_number + + def classwise_topk_accuracies(self, k) -> dict: + """Computes topk accuracies per class""" + preds = self.get_predictions() + classwise_topk_acc = {} + for i in set(self.new_labels): + classwise_topk_acc[i] = np.equal(i, np.argsort(preds['class_probabilities'][np.where(self.new_labels == i)], axis=1, kind='mergesort')[:, + -k:]).sum(axis=-1).mean() + return classwise_topk_acc + + def standard_balanced_topk_accuracy(self, k) -> np.array: + """Computes the balanced topk accuracy""" + classwise_topk_acc = self.classwise_topk_accuracies(k) + return np.array(list(classwise_topk_acc.values())).mean() + + def worst_class_accuracy(self) -> float: + """Computes the smallest accuracy among classes""" + classwise_accuracies = self.classwise_accuracies() + worst_item = min(classwise_accuracies.items(), key=lambda x: x[1]) + return worst_item[1] + + def worst_class_topk_accuracy(self, k) -> float: + """Computes the smallest topk accuracy among classes""" + classwise_topk_acc = self.classwise_topk_accuracies(k) + worst_item = min(classwise_topk_acc.items(), key=lambda x: x[1]) + return worst_item[1] + + def worst_balanced_n_classes_accuracy(self, n) -> np.array: + """Computes the ballanced accuracy among the worst n classes, based on their per-class accuracies""" + classwise_accuracies = self.classwise_accuracies() + sorted_classwise_accuracies = sorted(classwise_accuracies.items(), key=lambda item: item[1]) + n_worst = sorted_classwise_accuracies[:n] + return np.array([x[1] for x in n_worst]).mean() + + def worst_heuristic_n_classes_recall(self, n) -> np.float: + """Computes recall for n worst in terms of their per class accuracy""" + classwise_accuracies = self.classwise_accuracies() + classwise_accuracies_sample_numbers = self.classwise_sample_numbers() + sorted_classwise_accuracies = sorted(classwise_accuracies.items(), key=lambda item: item[1]) + n_worst = sorted_classwise_accuracies[:n] + n_worstclass_recall = np.array([v * classwise_accuracies_sample_numbers[c] for c, v in n_worst]).sum() / np.array([classwise_accuracies_sample_numbers[c] for c, v in n_worst]).sum() + return n_worstclass_recall + + def worst_balanced_n_classes_topk_accuracy(self, n, k) -> np.float: + """Computes the balanced accuracy for the worst n classes in therms of their per class topk accuracy""" + classwise_topk_accuracies = self.classwise_topk_accuracies(k) + sorted_clw_topk_acc = sorted(classwise_topk_accuracies.items(), key=lambda item: item[1]) + n_worst = sorted_clw_topk_acc[:n] + return np.array([x[1] for x in n_worst]).mean() + + def worst_heuristic_n_classes_topk_recall(self, n, k) -> np.float: + """Computes the recall for the worst n classes in therms of their per class topk accuracy""" + classwise_topk_accuracies = self.classwise_topk_accuracies(k) + clw_sn = self.classwise_sample_numbers() + sorted_clw_topk_acc = sorted(classwise_topk_accuracies.items(), key=lambda item: item[1]) + n_worst = sorted_clw_topk_acc[:n] + n_worstclass_recall = np.array([v * clw_sn[c] for c, v in n_worst]).sum() / np.array([clw_sn[c] for c, v in n_worst]).sum() + return n_worstclass_recall + + def worst_balanced_two_class_binary_accuracy(self) -> np.float: + """Computes the smallest two-class accuracy, when restricting the classifier to any two classes""" + classes = list(set(self.new_labels)) + binary_accuracies = {} + for i, j in itertools.combinations(classes, 2): + i_labelled = self.probs[np.where(self.new_labels == i)] + j_labelled = self.probs[np.where(self.new_labels == j)] + i_correct = np.greater(i_labelled[:, i], i_labelled[:, j]).mean() + j_correct = np.greater(j_labelled[:, j], j_labelled[:, i]).mean() + binary_accuracies[(i, j)] = (i_correct + j_correct) / 2 + sorted_binary_accuracies = sorted(binary_accuracies.items(), key=lambda item: item[1]) + worst_item = sorted_binary_accuracies[0] + return worst_item[1] + + def worst_balanced_superclass_recall(self) -> np.float: + """Computes the worst balanced recall among the superclasses""" + classwise_accuracies = self.classwise_accuracies() + superclass_classwise_accuracies = {i: np.array([classwise_accuracies[c] for c in s]).mean() for i, s in enumerate(self.superclasses)} + worst_item = min(superclass_classwise_accuracies.items(), key=lambda x: x[1]) + return worst_item[1] + + def worst_superclass_recall(self) -> np.float: + """Computes the worst not balanced recall among the superclasses""" + classwise_accuracies = self.classwise_accuracies() + classwise_sample_number = self.classwise_sample_numbers() + superclass_classwise_accuracies = {i: np.array([classwise_accuracies[c] * classwise_sample_number[c] for c in s]).sum() / np.array([classwise_sample_number[c] for c in s]).sum() for i, s in + enumerate(self.superclasses)} + worst_item = min(superclass_classwise_accuracies.items(), key=lambda x: x[1]) + return worst_item[1] + + def intra_superclass_accuracies(self) -> dict: + """Computes the accuracy for the images among one superclass, for each superclass""" + intra_superclass_accuracies = {} + original_probs = self.probs.copy() + original_targets = self.new_labels.copy() + for i, s in enumerate(self.superclasses): + self.probs = original_probs.copy() + self.new_labels = original_targets.copy() + + internal_samples = np.isin(self.new_labels, s) + internal_targets = self.new_labels[internal_samples] + internal_probs = self.probs[internal_samples][:, s] + s_targets = np.vectorize(lambda x: s[x]) + self.probs = internal_probs + self.new_labels = internal_targets + internal_preds = s_targets(self.get_predictions()['predicted_classes']) + intra_superclass_accuracies[i] = (internal_preds == internal_targets).mean() + + self.probs = original_probs + self.new_labels = original_targets + + return intra_superclass_accuracies + + def worst_intra_superclass_accuracy(self) -> np.float: + """Computes the worst superclass accuracy using intra_superclass_accuracies + + Output: the accuracy for the worst super class + """ + isa = self.intra_superclass_accuracies() + worst_item = min(isa.items(), key=lambda x: x[1]) + return worst_item[1] + + def worst_class_precision(self) -> np.float: + """Computes the precision for the worst class + + Returns: + Dict entry with the worst performing class + """ + preds = self.get_predictions() + classes = list(set(self.new_labels)) + per_class_precision = {} + for c in classes: + erroneous_c = (preds['predicted_classes'] == c) * (self.new_labels != c) + correct_c = (preds['predicted_classes'] == c) * (self.new_labels == c) + predicted_c = (preds['predicted_classes'] == c) + if predicted_c.sum(): + per_class_precision[c] = correct_c.sum() / predicted_c.sum() # 1-erroneous_c.sum()/predicted_c.sum() + else: + per_class_precision[c] = 1 + sorted_sc = sorted(per_class_precision.items(), key=lambda item: item[1]) + worst_item = sorted_sc[0] + return worst_item[1] + + def class_confusion(self) -> np.array: + """Computes the confision matrix + Returns: + confusion: confusion matrx + + """ + preds = self.get_predictions() + classes = list(set(self.new_labels)) + confusion = np.zeros((len(classes), len(classes))) + for i, c in enumerate(self.new_labels): + confusion[c, preds['predicted_classes'][i]] += 1 + return confusion + + + def _evaluate(self, model: sh_models.Model, verbose=False) -> TaskResult: + """The final method that uses all of the above to compute the metrics introduced in [1]""" + verbose = self.verbose + model.verbose = verbose + + if verbose: + print(f'new labels of type {self.labels_type} are', self.new_labels, len(self.new_labels)) + + self.probs = model.imagenet_validation_result.confidences[self.new_labels_mask, :] + + metrics = { + 'A': self.standard_accuracy, + 'WCA': self.worst_class_accuracy, + 'WCP': self.worst_class_precision, + 'WSupCA': self.worst_intra_superclass_accuracy, + 'WSupCR': self.worst_superclass_recall, + 'W10CR': lambda : self.worst_heuristic_n_classes_recall(10), + 'W100CR': lambda : self.worst_heuristic_n_classes_recall(100), + 'W2CA': self.worst_balanced_two_class_binary_accuracy, + 'WCAat5': lambda : self.worst_class_topk_accuracy(5), + 'W10CRat5': lambda : self.worst_heuristic_n_classes_topk_recall(10, 5), + 'W100CRat5': lambda : self.worst_heuristic_n_classes_topk_recall(100, 5), + } + + metrics_eval = {} + for metric_name, metric in metrics.items(): + if verbose: + print(f'Evaluating {metric_name}') + metrics_eval[metric_name] = metric() + if verbose: + print('metrics are', metrics_eval) + return TaskResult( + summary_metrics={Metric.Fairness: ("A", "WCA", "WCP", "WSupCA", "WSupCR", + "W10CR", "W100CR", "W2CA", "WCAat5", + "W10CRat5", "W100CRat5")}, + **metrics_eval + + ) + + +if __name__ == "__main__": + from shifthappens.models.torchvision import * + import shifthappens + + + available_models_dict = {'resnet18': ResNet18, + 'resnet50': ResNet50, + 'vgg16': VGG16} + + parser = argparse.ArgumentParser() + + # Set the label type either to val (50000 labels) or + # val_clean (46044 labels) for the cleaned labels from + # [3] + parser.add_argument( + "--labels_type", type=str, help="The label type", default='val' + ) + parser.add_argument( + "--imagenet_val_folder", type=str, help="The folder for the imagenet val set", required=True + ) + parser.add_argument( + "--model_name", type=str, default='resnet18', + help=f'The name of the model to test. Should be in {available_models_dict.keys()}' + ) + parser.add_argument('--gpu', '--list', nargs='+', default=[0], + help='GPU indices, if more than 1 parallel modules will be called') + parser.add_argument('--bs', type=int, default=500) + parser.add_argument( + "--verbose", + help="Turn verbose mode on when set", + action="store_true", + ) + + args = parser.parse_args() + + if len(args.gpu) == 0: + device_ids = None + device = torch.device('cpu') + print('Warning! Computing on CPU') + num_devices = 1 + elif len(args.gpu) == 1: + device_ids = [int(args.gpu[0])] + device = torch.device('cuda:' + str(args.gpu[0])) + num_devices = 1 + else: + device_ids = [int(i) for i in args.gpu] + device = torch.device('cuda:' + str(min(device_ids))) + num_devices = len(device_ids) + + shifthappens.config.imagenet_validation_path = args.imagenet_val_folder + + + shifthappens.config.verbose = args.verbose + + tuple(sh_benchmark.__registered_tasks)[0].cls.labels_type = args.labels_type + + assert args.model_name.lower() in available_models_dict, f"Selected model_name should be in {available_models_dict.keys()}" + + model = available_models_dict[args.model_name.lower()](device=device, max_batch_size=args.bs) + + if device_ids is not None and len(device_ids) > 1: + model = nn.DataParallel(model, device_ids=device_ids) + sh_benchmark.evaluate_model( + model, + "data" + )