From ec36db11dc3d4221588e03b05a3f9c98964ada8c Mon Sep 17 00:00:00 2001 From: valentyn Date: Mon, 27 Jun 2022 18:13:39 +0200 Subject: [PATCH 01/20] Proposed task in worst_case.py --- examples/worst_case.py | 320 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 320 insertions(+) create mode 100644 examples/worst_case.py diff --git a/examples/worst_case.py b/examples/worst_case.py new file mode 100644 index 00000000..3ace7b5f --- /dev/null +++ b/examples/worst_case.py @@ -0,0 +1,320 @@ +""" +This task evaluates a set of metrics, mostly related to worst-class performance, as described in +(J. Bitterwolf et al., "Classifiers Should Do Well Even on Their Worst Classes", https://openreview.net/forum?id=QxIXCVYJ2WP). +It is motivated by +(R. Balestriero et al., "The Effects of Regularization and Data Augmentation are Class Dependent", https://arxiv.org/abs/2204.03632) + where the authors note that using only accuracy as a metric is not enough to evaluate + the performance of the classifier, as it must not be the same on all classes/groups. +""" +import argparse +import collections +import itertools +import time +import urllib +import requests +import dataclasses +import os +import pathlib + +import numpy as np + +from shifthappens.data import imagenet as sh_imagenet +from shifthappens import benchmark as sh_benchmark +from shifthappens.models import base as sh_models +from shifthappens.tasks.base import Task +from shifthappens.tasks.metrics import Metric +from shifthappens.tasks.task_result import TaskResult + + +@sh_benchmark.register_task( + name="Worst_case", relative_data_folder="worst_case", standalone=True +) +@dataclasses.dataclass +class WorstCase(Task): + resources = ( + ["worstcase", + "restricted_superclass.csv", + "https://anonymous.4open.science/r/worst_classes-B94C/restricted_superclass.csv", + None], + + ["worstcase", + "new_labels.csv", + "https://anonymous.4open.science/r/worst_classes-B94C/new_labels.csv", + None] + ) + + new_labels = None + new_labels_mask = None + superclasses = None + + verbose = True + labels_type = 'val' + n_retries = 5 + max_batch_size: int = 256 + + def download(self, url, data_folder, filename, md5): + + for _ in range(self.n_retries): + try: + r = requests.get(url) + pathlib.Path(data_folder).mkdir(parents=True, exist_ok=True) + open(os.path.join(data_folder, filename), 'wb').write(r.content) + break + except urllib.error.URLError: + print(f"Download of {url} failed; wait 5s and then try again.") + time.sleep(5) + def setup(self): + + # Download resources + for resource in self.resources: + folder_name, file_name, url, md5 = resource + dataset_folder = os.path.join(self.data_root, folder_name) + if not os.path.isfile(os.path.join(dataset_folder, file_name)): + self.download(url, dataset_folder, file_name, md5) + print(f'File {file_name} is in {dataset_folder}.') + # Set the cleaned labels to a property + new_labels = np.array([int(line) for line in open(os.path.join(dataset_folder, 'new_labels.csv'))]) + + if self.labels_type == 'val_clean': + gooduns = new_labels != -1 + self.new_labels = new_labels[gooduns] + elif self.labels_type == 'val': + gooduns = np.full(new_labels.shape, True) + self.new_labels = np.array(sh_imagenet.load_imagenet_targets()) + + self.new_labels_mask = gooduns + + # Set the superclasses to a property + superclass_list = np.array([int(line) for line in open(os.path.join(dataset_folder, 'restricted_superclass.csv'))]) + self.superclasses = [tuple(np.where(superclass_list == i)[0]) for i in range(0, 9)] + + def get_predictions(self) -> np.ndarray: + preds = { + 'predicted_classes': self.probs.argmax(axis=1), + 'class_probabilities': self.probs, + 'confidences_classifier': self.probs.max(axis=1), + } + preds['number_of_class_predictions'] = collections.Counter(preds['predicted_classes']) + return preds + + def standard_accuracy(self): + preds = self.get_predictions() + accuracy = (preds['predicted_classes'] == self.new_labels).mean() + return accuracy + + def classwise_accuracies(self): + preds = self.get_predictions() + clw_acc = {} + for i in set(self.new_labels): + clw_acc[i] = np.equal(preds['predicted_classes'][np.where(self.new_labels == i)], i).mean() + return clw_acc + + def classwise_sample_numbers(self): + clw_sn = {} + for i in set(self.new_labels): + clw_sn[i] = np.sum(self.new_labels == i) + return clw_sn + + def classwise_topk_accuracies(self, k): + preds = self.get_predictions() + clw_topk_acc = {} + for i in set(self.new_labels): + clw_topk_acc[i] = np.equal(i, np.argsort(preds['class_probabilities'][np.where(self.new_labels == i)], axis=1, kind='mergesort')[:, + -k:]).sum(axis=-1).mean() + return clw_topk_acc + + def standard_balanced_topk_accuracy(self, k): + clw_topk_acc = self.classwise_topk_accuracies(k) + return np.array(list(clw_topk_acc.values())).mean() + + def worst_class_accuracy(self): + cwa = self.classwise_accuracies() + worst_item = min(cwa.items(), key=lambda x: x[1]) + return worst_item[1] + + def worst_class_topk_accuracy(self, k): + clw_topk_acc = self.classwise_topk_accuracies(k) + worst_item = min(clw_topk_acc.items(), key=lambda x: x[1]) + return worst_item[1] + + def worst_balanced_n_classes_accuracy(self, n): + cwa = self.classwise_accuracies() + sorted_cwa = sorted(cwa.items(), key=lambda item: item[1]) + n_worst = sorted_cwa[:n] + return np.array([x[1] for x in n_worst]).mean() + + def worst_heuristic_n_classes_recall(self, n): + cwa = self.classwise_accuracies() + clw_sn = self.classwise_sample_numbers() + sorted_cwa = sorted(cwa.items(), key=lambda item: item[1]) + n_worst = sorted_cwa[:n] + nwc = np.array([v * clw_sn[c] for c, v in n_worst]).sum() / np.array([clw_sn[c] for c, v in n_worst]).sum() + return nwc + + def worst_balanced_n_classes_topk_accuracy(self, n, k): + clw_topk_acc = self.classwise_topk_accuracies(k) + sorted_clw_topk_acc = sorted(clw_topk_acc.items(), key=lambda item: item[1]) + n_worst = sorted_clw_topk_acc[:n] + return np.array([x[1] for x in n_worst]).mean() + + def worst_heuristic_n_classes_topk_recall(self, n, k): + clw_topk_acc = self.classwise_topk_accuracies(k) + clw_sn = self.classwise_sample_numbers() + sorted_clw_topk_acc = sorted(clw_topk_acc.items(), key=lambda item: item[1]) + n_worst = sorted_clw_topk_acc[:n] + nwc = np.array([v * clw_sn[c] for c, v in n_worst]).sum() / np.array([clw_sn[c] for c, v in n_worst]).sum() + return nwc + + def worst_balanced_two_class_binary_accuracy(self): + classes = list(set(self.new_labels)) + binary_accuracies = {} + for i, j in itertools.combinations(classes, 2): + i_labelled = self.probs[np.where(self.new_labels == i)] + j_labelled = self.probs[np.where(self.new_labels == j)] + i_correct = np.greater(i_labelled[:, i], i_labelled[:, j]).mean() + j_correct = np.greater(j_labelled[:, j], j_labelled[:, i]).mean() + binary_accuracies[(i, j)] = (i_correct + j_correct) / 2 + sorted_binary_accuracies = sorted(binary_accuracies.items(), key=lambda item: item[1]) + worst_item = sorted_binary_accuracies[0] + return worst_item[1] + + def worst_balanced_superclass_recall(self): + cwa = self.classwise_accuracies() + scwa = {i: np.array([cwa[c] for c in s]).mean() for i, s in enumerate(self.superclasses)} + worst_item = min(scwa.items(), key=lambda x: x[1]) + return worst_item[1] + + def worst_superclass_recall(self): + cwa = self.classwise_accuracies() + clw_sn = self.classwise_sample_numbers() + scwa = {i: np.array([cwa[c] * clw_sn[c] for c in s]).sum() / np.array([clw_sn[c] for c in s]).sum() for i, s in + enumerate(self.superclasses)} + worst_item = min(scwa.items(), key=lambda x: x[1]) + return worst_item[1] + + def intra_superclass_accuracies(self): + isa = {} + original_probs = self.probs.copy() + original_targets = self.new_labels.copy() + for i, s in enumerate(self.superclasses): + self.probs = original_probs.copy() + self.new_labels = original_targets.copy() + + internal_samples = np.isin(self.new_labels, s) + internal_targets = self.new_labels[internal_samples] + internal_probs = self.probs[internal_samples][:, s] + s_targets = np.vectorize(lambda x: s[x]) + self.probs = internal_probs + self.new_labels = internal_targets + internal_preds = s_targets(self.get_predictions()['predicted_classes']) + isa[i] = (internal_preds == internal_targets).mean() + + self.probs = original_probs + self.new_labels = original_targets + + return isa + + def worst_intra_superclass_accuracy(self): + isa = self.intra_superclass_accuracies() + worst_item = min(isa.items(), key=lambda x: x[1]) + return worst_item[1] + + def worst_class_precision(self): + preds = self.get_predictions() + classes = list(set(self.new_labels)) + sc = {} + for c in classes: + erroneous_c = (preds['predicted_classes'] == c) * (self.new_labels != c) + correct_c = (preds['predicted_classes'] == c) * (self.new_labels == c) + predicted_c = (preds['predicted_classes'] == c) + if predicted_c.sum(): + sc[c] = correct_c.sum() / predicted_c.sum() # 1-erroneous_c.sum()/predicted_c.sum() + else: + sc[c] = 1 + sorted_sc = sorted(sc.items(), key=lambda item: item[1]) + worst_item = sorted_sc[0] + return worst_item[1] + + def class_confusion(self): + preds = self.get_predictions() + classes = list(set(self.new_labels)) + confusion = np.zeros((len(classes), len(classes))) + for i, c in enumerate(self.new_labels): + confusion[c, preds['predicted_classes'][i]] += 1 + return confusion + + + def _evaluate(self, model: sh_models.Model, verbose=False) -> TaskResult: + + verbose = self.verbose + model.verbose = verbose + + if verbose: + print(f'new labels of type {self.labels_type} are', self.new_labels, len(self.new_labels)) + + self.probs = model.imagenet_validation_result.confidences[self.new_labels_mask, :] + + metrics = { + 'A': self.standard_accuracy, + 'WCA': self.worst_class_accuracy, + 'WCP': self.worst_class_precision, + 'WSupCA': self.worst_intra_superclass_accuracy, + 'WSupCR': self.worst_superclass_recall, + 'W10CR': lambda : self.worst_heuristic_n_classes_recall(10), + 'W100CR': lambda : self.worst_heuristic_n_classes_recall(100), + 'W2CA': self.worst_balanced_two_class_binary_accuracy, + 'WCAat5': lambda : self.worst_class_topk_accuracy(5), + 'W10CRat5': lambda : self.worst_heuristic_n_classes_topk_recall(10, 5), + 'W100CRat5': lambda : self.worst_heuristic_n_classes_topk_recall(100, 5), + } + + metrics_eval = {} + for metric_name, metric in metrics.items(): + if verbose: + print(f'Evaluating {metric_name}') + metrics_eval[metric_name] = metric() + if verbose: + print('metrics are', metrics_eval) + return TaskResult( + summary_metrics={Metric.Fairness: ("A", "WCA", "WCP", "WSupCA", "WSupCR", + "W10CR", "W100CR", "W2CA", "WCAat5", + "W10CRat5", "W100CRat5")}, + **metrics_eval + + ) + + +if __name__ == "__main__": + from shifthappens.models.torchvision import ResNet18 + import shifthappens + + parser = argparse.ArgumentParser() + + # Set the label type either to val (50000 labels) or + # val_clean (46044 labels) for the cleaned labels from + # (C. Northcutt et al., "Pervasive Label Errors in Test Sets Destabilize Machine Learning Benchmarks", https://arxiv.org/abs/2103.14749, https://github.com/cleanlab/label-errors) + parser.add_argument( + "--labels_type", type=str, help="The label type", default='val' + ) + parser.add_argument( + "--imagenet_val_folder", type=str, help="The folder for the imagenet val set", required=True + ) + parser.add_argument( + "--verbose", + help="Turn verbose mode on when set", + action="store_true", + ) + + args = parser.parse_args() + + shifthappens.data.imagenet.ImageNetValidationData=args.imagenet_val_folder + + + tuple(sh_benchmark.__registered_tasks)[0].cls.verbose = args.verbose + + tuple(sh_benchmark.__registered_tasks)[0].cls.labels_type = args.labels_type + sh_benchmark.evaluate_model( + ResNet18(device="cuda:2", max_batch_size=500), + "data", + verbose=args.verbose + ) From d2a0ea6b4fb5f708b8bc3c74d9e7ec660d7c7794 Mon Sep 17 00:00:00 2001 From: valentyn Date: Mon, 27 Jun 2022 18:15:04 +0200 Subject: [PATCH 02/20] Proposed task in worst_case.py --- examples/worst_case.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/worst_case.py b/examples/worst_case.py index 3ace7b5f..38291e95 100644 --- a/examples/worst_case.py +++ b/examples/worst_case.py @@ -315,6 +315,5 @@ def _evaluate(self, model: sh_models.Model, verbose=False) -> TaskResult: tuple(sh_benchmark.__registered_tasks)[0].cls.labels_type = args.labels_type sh_benchmark.evaluate_model( ResNet18(device="cuda:2", max_batch_size=500), - "data", - verbose=args.verbose + "data" ) From a5547e9fee1370257e60fd43ac42eb68258d8c2d Mon Sep 17 00:00:00 2001 From: valentyn Date: Mon, 27 Jun 2022 18:23:58 +0200 Subject: [PATCH 03/20] Proposed task in worst_case.py --- requirements.txt | 4 +++- setup.cfg | 4 +--- shifthappens/benchmark.py | 4 ++-- shifthappens/data/imagenet.py | 2 +- shifthappens/models/base.py | 10 +++++++++- shifthappens/tasks/base.py | 6 +++--- 6 files changed, 19 insertions(+), 11 deletions(-) diff --git a/requirements.txt b/requirements.txt index ebcc420c..b50cb51e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,6 @@ numpy torch torchvision -surgeon_pytorch \ No newline at end of file +surgeon_pytorch +sklearn +tqdm diff --git a/setup.cfg b/setup.cfg index f92b7796..6e101e5a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -42,9 +42,7 @@ license_file = LICENSE license = Apache 2.0 [options] -package_dir = - = . -py_modules = shifthappens +packages = find: python_requires = >=3.8 [options.extras_require] diff --git a/shifthappens/benchmark.py b/shifthappens/benchmark.py index 4e655de8..dfe04bd5 100644 --- a/shifthappens/benchmark.py +++ b/shifthappens/benchmark.py @@ -114,7 +114,7 @@ def unregister_task(cls: Type[Task]): def evaluate_model( - model: Model, data_root: str + model: Model, data_root: str, verbose=False ) -> Dict[TaskRegistration, Optional[TaskResult]]: """ Runs all registered tasks of the benchmark @@ -149,5 +149,5 @@ def evaluate_model( ): task.setup() flavored_task_metadata = getattr(task, task_metadata._TASK_METADATA_FIELD) - results[flavored_task_metadata] = task.evaluate(model) + results[flavored_task_metadata] = task.evaluate(model, verbose) return results diff --git a/shifthappens/data/imagenet.py b/shifthappens/data/imagenet.py index 2e9dd06d..c70a3cf5 100644 --- a/shifthappens/data/imagenet.py +++ b/shifthappens/data/imagenet.py @@ -35,7 +35,7 @@ def _check_imagenet_folder(): ) assert ( - len(os.listdir(ImageNetValidationData)) == 1000 + len(os.listdir(ImageNetValidationData)) >= 1000 ), "ImageNetValidationData folder contains less or more folders than ImageNet classes." diff --git a/shifthappens/models/base.py b/shifthappens/models/base.py index 231bb31f..d49765cb 100644 --- a/shifthappens/models/base.py +++ b/shifthappens/models/base.py @@ -13,6 +13,7 @@ from typing import Iterator import numpy as np +from tqdm import tqdm from shifthappens.data import imagenet as sh_imagenet from shifthappens.data.base import DataLoader @@ -109,6 +110,7 @@ class Model(abc.ABC): def __init__(self): self._imagenet_validation_result = None + self.verbose = False @property def imagenet_validation_result(self): @@ -206,7 +208,13 @@ def _predict_imagenet_val(self): score_type for score_type in score_types if score_types[score_type] ] } - for prediction in self._predict(imagenet_val_dataloader, targets): + + if self.verbose: + pred_loader = tqdm(self._predict(imagenet_val_dataloader, targets), desc='Predictions', total=int(len(imagenet_val_dataloader._dataset)/imagenet_val_dataloader.max_batch_size)) + else: + pred_loader = self._predict(imagenet_val_dataloader, targets) + + for prediction in pred_loader: for prediction_type in predictions_dict: prediction_score = prediction.__getattribute__(prediction_type) predictions_dict[prediction_type] = sum( diff --git a/shifthappens/tasks/base.py b/shifthappens/tasks/base.py index 4625e0f5..c8d195de 100644 --- a/shifthappens/tasks/base.py +++ b/shifthappens/tasks/base.py @@ -212,7 +212,7 @@ def setup(self): """ pass - def evaluate(self, model: sh_models.Model) -> Optional[TaskResult]: + def evaluate(self, model: sh_models.Model, verbose=False) -> Optional[TaskResult]: """Validates that the model is compatible with the task and then evaluates the model's performance using the :py:meth:`_evaluate` function of this class. @@ -238,7 +238,7 @@ def evaluate(self, model: sh_models.Model) -> Optional[TaskResult]: dataloader = self._prepare_dataloader() if dataloader is not None: model.prepare(dataloader) - return self._evaluate(model) + return self._evaluate(model, verbose) def _prepare_dataloader(self) -> Optional[DataLoader]: """Prepare a :py:class:`shifthappens.data.base.DataLoader` based on just the *unlabeled* images which will be passed to the model @@ -263,7 +263,7 @@ def _prepare_dataloader(self) -> Optional[DataLoader]: return None @abstractmethod - def _evaluate(self, model: sh_models.Model) -> TaskResult: + def _evaluate(self, model: sh_models.Model, verbose=False) -> TaskResult: """Evaluate the task and return a dictionary with the calculated metrics. Args: From 7be8dafea2a78b53b3446393576a685b93f82ca3 Mon Sep 17 00:00:00 2001 From: valentyn Date: Mon, 27 Jun 2022 18:32:36 +0200 Subject: [PATCH 04/20] Proposed task in worst_case.py --- shifthappens/benchmark.py | 4 ++-- shifthappens/tasks/base.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/shifthappens/benchmark.py b/shifthappens/benchmark.py index dfe04bd5..4e655de8 100644 --- a/shifthappens/benchmark.py +++ b/shifthappens/benchmark.py @@ -114,7 +114,7 @@ def unregister_task(cls: Type[Task]): def evaluate_model( - model: Model, data_root: str, verbose=False + model: Model, data_root: str ) -> Dict[TaskRegistration, Optional[TaskResult]]: """ Runs all registered tasks of the benchmark @@ -149,5 +149,5 @@ def evaluate_model( ): task.setup() flavored_task_metadata = getattr(task, task_metadata._TASK_METADATA_FIELD) - results[flavored_task_metadata] = task.evaluate(model, verbose) + results[flavored_task_metadata] = task.evaluate(model) return results diff --git a/shifthappens/tasks/base.py b/shifthappens/tasks/base.py index c8d195de..4625e0f5 100644 --- a/shifthappens/tasks/base.py +++ b/shifthappens/tasks/base.py @@ -212,7 +212,7 @@ def setup(self): """ pass - def evaluate(self, model: sh_models.Model, verbose=False) -> Optional[TaskResult]: + def evaluate(self, model: sh_models.Model) -> Optional[TaskResult]: """Validates that the model is compatible with the task and then evaluates the model's performance using the :py:meth:`_evaluate` function of this class. @@ -238,7 +238,7 @@ def evaluate(self, model: sh_models.Model, verbose=False) -> Optional[TaskResult dataloader = self._prepare_dataloader() if dataloader is not None: model.prepare(dataloader) - return self._evaluate(model, verbose) + return self._evaluate(model) def _prepare_dataloader(self) -> Optional[DataLoader]: """Prepare a :py:class:`shifthappens.data.base.DataLoader` based on just the *unlabeled* images which will be passed to the model @@ -263,7 +263,7 @@ def _prepare_dataloader(self) -> Optional[DataLoader]: return None @abstractmethod - def _evaluate(self, model: sh_models.Model, verbose=False) -> TaskResult: + def _evaluate(self, model: sh_models.Model) -> TaskResult: """Evaluate the task and return a dictionary with the calculated metrics. Args: From b6babbd619c356652c3e30f7c024751fedc467bf Mon Sep 17 00:00:00 2001 From: valentyn Date: Mon, 27 Jun 2022 18:37:51 +0200 Subject: [PATCH 05/20] Proposed task in worst_case.py --- requirements.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index b50cb51e..6d015e36 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,5 +2,4 @@ numpy torch torchvision surgeon_pytorch -sklearn tqdm From 5ecc71835f7a369054ba63543cbe2fee34b0d8f4 Mon Sep 17 00:00:00 2001 From: valentyn Date: Tue, 19 Jul 2022 03:27:44 +0200 Subject: [PATCH 06/20] Proposed task in worst_case.py --- Makefile | 2 +- examples/raccoons_OOD.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 447b0aa4..88e643cf 100644 --- a/Makefile +++ b/Makefile @@ -3,7 +3,7 @@ install: build: python3 -m pip install --upgrade build - python3 -m build + python3 -m build upload_test: python3 -m twine upload --repository testpypi dist/* diff --git a/examples/raccoons_OOD.py b/examples/raccoons_OOD.py index 0b3f2781..91574966 100644 --- a/examples/raccoons_OOD.py +++ b/examples/raccoons_OOD.py @@ -8,6 +8,7 @@ The original dataset was collected by Dat Tran for the object detection task and can be found at https://github.com/datitran/raccoon_dataset. """ + import dataclasses import os From 3f0ee0c67416a8b1101bc1b66e64be4129869e4b Mon Sep 17 00:00:00 2001 From: valentyn Date: Tue, 19 Jul 2022 03:28:43 +0200 Subject: [PATCH 07/20] Proposed task in worst_case.py --- examples/worst_case.py | 318 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 318 insertions(+) create mode 100644 examples/worst_case.py diff --git a/examples/worst_case.py b/examples/worst_case.py new file mode 100644 index 00000000..f6c8b04a --- /dev/null +++ b/examples/worst_case.py @@ -0,0 +1,318 @@ +""" +This task evaluates a set of metrics, mostly related to worst-class performance, as described in +(J. Bitterwolf et al., "Classifiers Should Do Well Even on Their Worst Classes", https://openreview.net/forum?id=QxIXCVYJ2WP). +It is motivated by +(R. Balestriero et al., "The Effects of Regularization and Data Augmentation are Class Dependent", https://arxiv.org/abs/2204.03632) + where the authors note that using only accuracy as a metric is not enough to evaluate + the performance of the classifier, as it must not be the same on all classes/groups. +""" +import argparse +import collections +import itertools +import time +import urllib +import requests +import dataclasses +import os +import pathlib + +import numpy as np + +from shifthappens.data import imagenet as sh_imagenet +from shifthappens import benchmark as sh_benchmark +from shifthappens.models import base as sh_models +from shifthappens.tasks.base import Task +from shifthappens.tasks.metrics import Metric +from shifthappens.tasks.task_result import TaskResult + + +@sh_benchmark.register_task( + name="Worst_case", relative_data_folder="worst_case", standalone=True +) +@dataclasses.dataclass +class WorstCase(Task): + resources = ( + ["worstcase", + "restricted_superclass.csv", + "https://anonymous.4open.science/r/worst_classes-B94C/restricted_superclass.csv", + None], + + ["worstcase", + "new_labels.csv", + "https://anonymous.4open.science/r/worst_classes-B94C/new_labels.csv", + None] + ) + + new_labels = None + new_labels_mask = None + superclasses = None + + verbose = True + labels_type = 'val' + n_retries = 5 + max_batch_size: int = 256 + + def download(self, url, data_folder, filename, md5): + + for _ in range(self.n_retries): + try: + r = requests.get(url) + pathlib.Path(data_folder).mkdir(parents=True, exist_ok=True) + open(os.path.join(data_folder, filename), 'wb').write(r.content) + break + except urllib.error.URLError: + print(f"Download of {url} failed; wait 5s and then try again.") + time.sleep(5) + def setup(self): + + # Download resources + for resource in self.resources: + folder_name, file_name, url, md5 = resource + dataset_folder = os.path.join(self.data_root, folder_name) + if not os.path.isfile(os.path.join(dataset_folder, file_name)): + self.download(url, dataset_folder, file_name, md5) + print(f'File {file_name} is in {dataset_folder}.') + # Set the cleaned labels to a property + new_labels = np.array([int(line) for line in open(os.path.join(dataset_folder, 'new_labels.csv'))]) + + if self.labels_type == 'val_clean': + gooduns = new_labels != -1 + self.new_labels = new_labels[gooduns] + elif self.labels_type == 'val': + gooduns = np.full(new_labels.shape, True) + self.new_labels = np.array(sh_imagenet.load_imagenet_targets()) + + self.new_labels_mask = gooduns + + # Set the superclasses to a property + superclass_list = np.array([int(line) for line in open(os.path.join(dataset_folder, 'restricted_superclass.csv'))]) + self.superclasses = [tuple(np.where(superclass_list == i)[0]) for i in range(0, 9)] + + def get_predictions(self) -> np.ndarray: + preds = { + 'predicted_classes': self.probs.argmax(axis=1), + 'class_probabilities': self.probs, + 'confidences_classifier': self.probs.max(axis=1), + } + preds['number_of_class_predictions'] = collections.Counter(preds['predicted_classes']) + return preds + + def standard_accuracy(self): + preds = self.get_predictions() + accuracy = (preds['predicted_classes'] == self.new_labels).mean() + return accuracy + + def classwise_accuracies(self): + preds = self.get_predictions() + clw_acc = {} + for i in set(self.new_labels): + clw_acc[i] = np.equal(preds['predicted_classes'][np.where(self.new_labels == i)], i).mean() + return clw_acc + + def classwise_sample_numbers(self): + clw_sn = {} + for i in set(self.new_labels): + clw_sn[i] = np.sum(self.new_labels == i) + return clw_sn + + def classwise_topk_accuracies(self, k): + preds = self.get_predictions() + clw_topk_acc = {} + for i in set(self.new_labels): + clw_topk_acc[i] = np.equal(i, np.argsort(preds['class_probabilities'][np.where(self.new_labels == i)], axis=1, kind='mergesort')[:, + -k:]).sum(axis=-1).mean() + return clw_topk_acc + + def standard_balanced_topk_accuracy(self, k): + clw_topk_acc = self.classwise_topk_accuracies(k) + return np.array(list(clw_topk_acc.values())).mean() + + def worst_class_accuracy(self): + cwa = self.classwise_accuracies() + worst_item = min(cwa.items(), key=lambda x: x[1]) + return worst_item[1] + + def worst_class_topk_accuracy(self, k): + clw_topk_acc = self.classwise_topk_accuracies(k) + worst_item = min(clw_topk_acc.items(), key=lambda x: x[1]) + return worst_item[1] + + def worst_balanced_n_classes_accuracy(self, n): + cwa = self.classwise_accuracies() + sorted_cwa = sorted(cwa.items(), key=lambda item: item[1]) + n_worst = sorted_cwa[:n] + return np.array([x[1] for x in n_worst]).mean() + + def worst_heuristic_n_classes_recall(self, n): + cwa = self.classwise_accuracies() + clw_sn = self.classwise_sample_numbers() + sorted_cwa = sorted(cwa.items(), key=lambda item: item[1]) + n_worst = sorted_cwa[:n] + nwc = np.array([v * clw_sn[c] for c, v in n_worst]).sum() / np.array([clw_sn[c] for c, v in n_worst]).sum() + return nwc + + def worst_balanced_n_classes_topk_accuracy(self, n, k): + clw_topk_acc = self.classwise_topk_accuracies(k) + sorted_clw_topk_acc = sorted(clw_topk_acc.items(), key=lambda item: item[1]) + n_worst = sorted_clw_topk_acc[:n] + return np.array([x[1] for x in n_worst]).mean() + + def worst_heuristic_n_classes_topk_recall(self, n, k): + clw_topk_acc = self.classwise_topk_accuracies(k) + clw_sn = self.classwise_sample_numbers() + sorted_clw_topk_acc = sorted(clw_topk_acc.items(), key=lambda item: item[1]) + n_worst = sorted_clw_topk_acc[:n] + nwc = np.array([v * clw_sn[c] for c, v in n_worst]).sum() / np.array([clw_sn[c] for c, v in n_worst]).sum() + return nwc + + def worst_balanced_two_class_binary_accuracy(self): + classes = list(set(self.new_labels)) + binary_accuracies = {} + for i, j in itertools.combinations(classes, 2): + i_labelled = self.probs[np.where(self.new_labels == i)] + j_labelled = self.probs[np.where(self.new_labels == j)] + i_correct = np.greater(i_labelled[:, i], i_labelled[:, j]).mean() + j_correct = np.greater(j_labelled[:, j], j_labelled[:, i]).mean() + binary_accuracies[(i, j)] = (i_correct + j_correct) / 2 + sorted_binary_accuracies = sorted(binary_accuracies.items(), key=lambda item: item[1]) + worst_item = sorted_binary_accuracies[0] + return worst_item[1] + + def worst_balanced_superclass_recall(self): + cwa = self.classwise_accuracies() + scwa = {i: np.array([cwa[c] for c in s]).mean() for i, s in enumerate(self.superclasses)} + worst_item = min(scwa.items(), key=lambda x: x[1]) + return worst_item[1] + + def worst_superclass_recall(self): + cwa = self.classwise_accuracies() + clw_sn = self.classwise_sample_numbers() + scwa = {i: np.array([cwa[c] * clw_sn[c] for c in s]).sum() / np.array([clw_sn[c] for c in s]).sum() for i, s in + enumerate(self.superclasses)} + worst_item = min(scwa.items(), key=lambda x: x[1]) + return worst_item[1] + + def intra_superclass_accuracies(self): + isa = {} + original_probs = self.probs.copy() + original_targets = self.new_labels.copy() + for i, s in enumerate(self.superclasses): + self.probs = original_probs.copy() + self.new_labels = original_targets.copy() + + internal_samples = np.isin(self.new_labels, s) + internal_targets = self.new_labels[internal_samples] + internal_probs = self.probs[internal_samples][:, s] + s_targets = np.vectorize(lambda x: s[x]) + self.probs = internal_probs + self.new_labels = internal_targets + internal_preds = s_targets(self.get_predictions()['predicted_classes']) + isa[i] = (internal_preds == internal_targets).mean() + + self.probs = original_probs + self.new_labels = original_targets + + return isa + + def worst_intra_superclass_accuracy(self): + isa = self.intra_superclass_accuracies() + worst_item = min(isa.items(), key=lambda x: x[1]) + return worst_item[1] + + def worst_class_precision(self): + preds = self.get_predictions() + classes = list(set(self.new_labels)) + sc = {} + for c in classes: + erroneous_c = (preds['predicted_classes'] == c) * (self.new_labels != c) + correct_c = (preds['predicted_classes'] == c) * (self.new_labels == c) + predicted_c = (preds['predicted_classes'] == c) + if predicted_c.sum(): + sc[c] = correct_c.sum() / predicted_c.sum() # 1-erroneous_c.sum()/predicted_c.sum() + else: + sc[c] = 1 + sorted_sc = sorted(sc.items(), key=lambda item: item[1]) + worst_item = sorted_sc[0] + return worst_item[1] + + def class_confusion(self): + preds = self.get_predictions() + classes = list(set(self.new_labels)) + confusion = np.zeros((len(classes), len(classes))) + for i, c in enumerate(self.new_labels): + confusion[c, preds['predicted_classes'][i]] += 1 + return confusion + + + def _evaluate(self, model: sh_models.Model, verbose=False) -> TaskResult: + + verbose = self.verbose + model.verbose = verbose + + if verbose: + print(f'new labels of type {self.labels_type} are', self.new_labels, len(self.new_labels)) + + self.probs = model.imagenet_validation_result.confidences[self.new_labels_mask, :] + + metrics = { + 'A': self.standard_accuracy, + 'WCA': self.worst_class_accuracy, + 'WCP': self.worst_class_precision, + 'WSupCA': self.worst_intra_superclass_accuracy, + 'WSupCR': self.worst_superclass_recall, + 'W10CR': lambda : self.worst_heuristic_n_classes_recall(10), + 'W100CR': lambda : self.worst_heuristic_n_classes_recall(100), + 'W2CA': self.worst_balanced_two_class_binary_accuracy, + 'WCAat5': lambda : self.worst_class_topk_accuracy(5), + 'W10CRat5': lambda : self.worst_heuristic_n_classes_topk_recall(10, 5), + 'W100CRat5': lambda : self.worst_heuristic_n_classes_topk_recall(100, 5), + } + + metrics_eval = {} + for metric_name, metric in metrics.items(): + if verbose: + print(f'Evaluating {metric_name}') + metrics_eval[metric_name] = metric() + if verbose: + print('metrics are', metrics_eval) + return TaskResult( + summary_metrics={Metric.Fairness: ("A", "WCA", "WCP", "WSupCA", "WSupCR", + "W10CR", "W100CR", "W2CA", "WCAat5", + "W10CRat5", "W100CRat5")}, + **metrics_eval + + ) + + +if __name__ == "__main__": + from shifthappens.models.torchvision import ResNet18 + import shifthappens + + parser = argparse.ArgumentParser() + + # Set the label type either to val (50000 labels) or + # val_clean (46044 labels) for the cleaned labels from + # (C. Northcutt et al., "Pervasive Label Errors in Test Sets Destabilize Machine Learning Benchmarks", https://arxiv.org/abs/2103.14749, https://github.com/cleanlab/label-errors) + parser.add_argument( + "--labels_type", type=str, help="The label type", default='val' + ) + parser.add_argument( + "--imagenet_val_folder", type=str, help="The folder for the imagenet val set", required=True + ) + parser.add_argument( + "--verbose", + help="Turn verbose mode on when set", + action="store_true", + ) + + args = parser.parse_args() + + shifthappens.data.imagenet.ImageNetValidationData=args.imagenet_val_folder + + tuple(sh_benchmark.__registered_tasks)[0].cls.verbose = args.verbose + + tuple(sh_benchmark.__registered_tasks)[0].cls.labels_type = args.labels_type + sh_benchmark.evaluate_model( + ResNet18(device="cuda:2", max_batch_size=500), + "data" + ) From 29ed2ae97bfe0cc6dfcd5b7d14a13a7538579c3b Mon Sep 17 00:00:00 2001 From: valentyn Date: Tue, 19 Jul 2022 04:28:53 +0200 Subject: [PATCH 08/20] Proposed task in worst_case.py --- shifthappens/tasks/worst_case/README.rst | 39 +++++++++++++++++++ shifthappens/tasks/worst_case/__init__.py | 0 .../tasks/worst_case}/worst_case.py | 0 3 files changed, 39 insertions(+) create mode 100644 shifthappens/tasks/worst_case/README.rst create mode 100644 shifthappens/tasks/worst_case/__init__.py rename {examples => shifthappens/tasks/worst_case}/worst_case.py (100%) diff --git a/shifthappens/tasks/worst_case/README.rst b/shifthappens/tasks/worst_case/README.rst new file mode 100644 index 00000000..4adb183e --- /dev/null +++ b/shifthappens/tasks/worst_case/README.rst @@ -0,0 +1,39 @@ +Example for a Shift Happens task on ImageNet +============================================== +# Task Description +This task evaluates a set of metrics, mostly related to worst-class performance, as described in [1]. +It is motivated by [2], where the authors note that using only accuracy as a metric is not enough to evaluate + the performance of the classifier, as it must not be the same on all classes/groups. + +## How to start +run + +``` +python tasks/worst_case/worst_case.py --imagenet_val_folder '/scratch/datasets/imagenet/val' --verbose --labels_type 'val' +``` + +for evaluating with original labels, and + +``` +python tasks/worst_case/worst_case.py --imagenet_val_folder '/scratch/datasets/imagenet/val' --verbose --labels_type 'val_clean' +``` + +for evaluating with cleaned ones from [3]. + + +## Evaluation Metrics +The evaluation metrics are "A", "WCA", "WCP", "WSupCA", "WSupCR", "W10CR", "W100CR", "W2CA", "WCAat5", "W10CRat5", "W100CRat5", and their relevance is described in (J. Bitterwolf et al., "Classifiers Should Do Well Even on Their Worst Classes", https://openreview.net/forum?id=QxIXCVYJ2WP). + +## Expected Insights/Relevance +To see the, how the model performs on its worst classes. The application examples are given in [1]. + + +1. Classifiers Should Do Well Even on Their Worst Classes. + J. Bitterwolf et al. 2022. + +2. The Effects of Regularization and Data Augmentation are Class Dependent. + R. Balestriero et al. 2022. + +3. Pervasive Label Errors in Test Sets Destabilize Machine Learning Benchmarks. + C. Northcutt et al. 2021. + diff --git a/shifthappens/tasks/worst_case/__init__.py b/shifthappens/tasks/worst_case/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/worst_case.py b/shifthappens/tasks/worst_case/worst_case.py similarity index 100% rename from examples/worst_case.py rename to shifthappens/tasks/worst_case/worst_case.py From 3c8fcbe79a22ece8719bc69c14ce03febe27465f Mon Sep 17 00:00:00 2001 From: valentyn Date: Tue, 19 Jul 2022 04:36:57 +0200 Subject: [PATCH 09/20] Proposed task in worst_case.py --- shifthappens/tasks/worst_case/worst_case.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/shifthappens/tasks/worst_case/worst_case.py b/shifthappens/tasks/worst_case/worst_case.py index 38291e95..be3d19cf 100644 --- a/shifthappens/tasks/worst_case/worst_case.py +++ b/shifthappens/tasks/worst_case/worst_case.py @@ -307,10 +307,10 @@ def _evaluate(self, model: sh_models.Model, verbose=False) -> TaskResult: args = parser.parse_args() - shifthappens.data.imagenet.ImageNetValidationData=args.imagenet_val_folder + shifthappens.config.imagenet_validation_path = args.imagenet_val_folder - tuple(sh_benchmark.__registered_tasks)[0].cls.verbose = args.verbose + shifthappens.config.verbose = args.verbose tuple(sh_benchmark.__registered_tasks)[0].cls.labels_type = args.labels_type sh_benchmark.evaluate_model( From a7a967309a762bf350ee51939a397c813daab524 Mon Sep 17 00:00:00 2001 From: valentyn Date: Tue, 19 Jul 2022 04:57:47 +0200 Subject: [PATCH 10/20] Proposed task in worst_case.py --- shifthappens/tasks/worst_case/README.rst | 6 ++--- shifthappens/tasks/worst_case/worst_case.py | 28 +++++++++++++++++++-- 2 files changed, 29 insertions(+), 5 deletions(-) diff --git a/shifthappens/tasks/worst_case/README.rst b/shifthappens/tasks/worst_case/README.rst index 4adb183e..5021ffa5 100644 --- a/shifthappens/tasks/worst_case/README.rst +++ b/shifthappens/tasks/worst_case/README.rst @@ -6,16 +6,16 @@ It is motivated by [2], where the authors note that using only accuracy as a met the performance of the classifier, as it must not be the same on all classes/groups. ## How to start -run +in the icml-2022 folder, run ``` -python tasks/worst_case/worst_case.py --imagenet_val_folder '/scratch/datasets/imagenet/val' --verbose --labels_type 'val' +python shifthappens/tasks/worst_case/worst_case.py --imagenet_val_folder '/scratch/datasets/imagenet/val' --verbose --labels_type 'val' ``` for evaluating with original labels, and ``` -python tasks/worst_case/worst_case.py --imagenet_val_folder '/scratch/datasets/imagenet/val' --verbose --labels_type 'val_clean' +python shifthappens/tasks/worst_case/worst_case.py --imagenet_val_folder '/scratch/datasets/imagenet/val' --verbose --labels_type 'val_clean' ``` for evaluating with cleaned ones from [3]. diff --git a/shifthappens/tasks/worst_case/worst_case.py b/shifthappens/tasks/worst_case/worst_case.py index be3d19cf..6f315e5e 100644 --- a/shifthappens/tasks/worst_case/worst_case.py +++ b/shifthappens/tasks/worst_case/worst_case.py @@ -15,8 +15,9 @@ import dataclasses import os import pathlib - +import torch import numpy as np +import torch.nn as nn from shifthappens.data import imagenet as sh_imagenet from shifthappens import benchmark as sh_benchmark @@ -299,21 +300,44 @@ def _evaluate(self, model: sh_models.Model, verbose=False) -> TaskResult: parser.add_argument( "--imagenet_val_folder", type=str, help="The folder for the imagenet val set", required=True ) + parser.add_argument('--gpu', '--list', nargs='+', default=[0], + help='GPU indices, if more than 1 parallel modules will be called') + parser.add_argument('--bs', type=int, default=500) parser.add_argument( "--verbose", help="Turn verbose mode on when set", action="store_true", ) + args = parser.parse_args() + if len(args.gpu) == 0: + device_ids = None + device = torch.device('cpu') + print('Warning! Computing on CPU') + num_devices = 1 + elif len(args.gpu) == 1: + device_ids = [int(args.gpu[0])] + device = torch.device('cuda:' + str(args.gpu[0])) + num_devices = 1 + else: + device_ids = [int(i) for i in args.gpu] + device = torch.device('cuda:' + str(min(device_ids))) + num_devices = len(device_ids) + shifthappens.config.imagenet_validation_path = args.imagenet_val_folder shifthappens.config.verbose = args.verbose tuple(sh_benchmark.__registered_tasks)[0].cls.labels_type = args.labels_type + + model = ResNet18(device=device, max_batch_size=args.bs) + + if device_ids is not None and len(device_ids) > 1: + model = nn.DataParallel(model, device_ids=device_ids) sh_benchmark.evaluate_model( - ResNet18(device="cuda:2", max_batch_size=500), + model, "data" ) From 9255b75e433546f60a1f6a2233f550ee316cf2c3 Mon Sep 17 00:00:00 2001 From: valentyn Date: Tue, 19 Jul 2022 05:44:02 +0200 Subject: [PATCH 11/20] Proposed task in worst_case.py --- shifthappens/tasks/worst_case/worst_case.py | 132 +++++++++++--------- 1 file changed, 76 insertions(+), 56 deletions(-) diff --git a/shifthappens/tasks/worst_case/worst_case.py b/shifthappens/tasks/worst_case/worst_case.py index 6f315e5e..205106de 100644 --- a/shifthappens/tasks/worst_case/worst_case.py +++ b/shifthappens/tasks/worst_case/worst_case.py @@ -1,16 +1,15 @@ -""" -This task evaluates a set of metrics, mostly related to worst-class performance, as described in -(J. Bitterwolf et al., "Classifiers Should Do Well Even on Their Worst Classes", https://openreview.net/forum?id=QxIXCVYJ2WP). -It is motivated by -(R. Balestriero et al., "The Effects of Regularization and Data Augmentation are Class Dependent", https://arxiv.org/abs/2204.03632) - where the authors note that using only accuracy as a metric is not enough to evaluate - the performance of the classifier, as it must not be the same on all classes/groups. +"""Classifiers Should Do Well Even on Their Worst Classes +This task evaluates a set of metrics, mostly related to worst-class performance, as described in [1]. +It is motivated by [2], where the authors note that using only accuracy as a metric is not enough to evaluate +the performance of the classifier, as it must not be the same on all classes/groups. """ import argparse import collections import itertools import time import urllib +from typing import List, Union + import requests import dataclasses import os @@ -18,6 +17,8 @@ import torch import numpy as np import torch.nn as nn +from numpy.core._multiarray_umath import ndarray +from numpy.core.multiarray import ndarray from shifthappens.data import imagenet as sh_imagenet from shifthappens import benchmark as sh_benchmark @@ -45,16 +46,15 @@ class WorstCase(Task): ) new_labels = None - new_labels_mask = None - superclasses = None + new_labels_mask: Union[ndarray, None, bool] = None + superclasses: List[tuple] = None - verbose = True - labels_type = 'val' - n_retries = 5 + verbose: bool = True + labels_type: str = 'val' + n_retries: int = 5 max_batch_size: int = 256 def download(self, url, data_folder, filename, md5): - for _ in range(self.n_retries): try: r = requests.get(url) @@ -64,8 +64,9 @@ def download(self, url, data_folder, filename, md5): except urllib.error.URLError: print(f"Download of {url} failed; wait 5s and then try again.") time.sleep(5) + def setup(self): - + """Downloads the cleaned labels from [3], as well as superclasses used in [1]""" # Download resources for resource in self.resources: folder_name, file_name, url, md5 = resource @@ -74,22 +75,23 @@ def setup(self): self.download(url, dataset_folder, file_name, md5) print(f'File {file_name} is in {dataset_folder}.') # Set the cleaned labels to a property - new_labels = np.array([int(line) for line in open(os.path.join(dataset_folder, 'new_labels.csv'))]) + new_labels: ndarray = np.array([int(line) for line in open(os.path.join(dataset_folder, 'new_labels.csv'))]) if self.labels_type == 'val_clean': - gooduns = new_labels != -1 - self.new_labels = new_labels[gooduns] + cleaned_labels = new_labels != -1 + self.new_labels = new_labels[cleaned_labels] elif self.labels_type == 'val': - gooduns = np.full(new_labels.shape, True) + cleaned_labels = np.full(new_labels.shape, True) self.new_labels = np.array(sh_imagenet.load_imagenet_targets()) - self.new_labels_mask = gooduns + self.new_labels_mask = cleaned_labels # Set the superclasses to a property - superclass_list = np.array([int(line) for line in open(os.path.join(dataset_folder, 'restricted_superclass.csv'))]) + superclass_list: ndarray = np.array([int(line) for line in open(os.path.join(dataset_folder, 'restricted_superclass.csv'))]) self.superclasses = [tuple(np.where(superclass_list == i)[0]) for i in range(0, 9)] def get_predictions(self) -> np.ndarray: + """Saves to a property as a dict the computed predictions and probabilities for the used model""" preds = { 'predicted_classes': self.probs.argmax(axis=1), 'class_probabilities': self.probs, @@ -99,11 +101,13 @@ def get_predictions(self) -> np.ndarray: return preds def standard_accuracy(self): + """Computes standard accuracy""" preds = self.get_predictions() accuracy = (preds['predicted_classes'] == self.new_labels).mean() return accuracy def classwise_accuracies(self): + """Computes accuracies per each class""" preds = self.get_predictions() clw_acc = {} for i in set(self.new_labels): @@ -111,62 +115,72 @@ def classwise_accuracies(self): return clw_acc def classwise_sample_numbers(self): - clw_sn = {} + """Computes number of samples per class""" + classwise_sample_number = {} for i in set(self.new_labels): - clw_sn[i] = np.sum(self.new_labels == i) - return clw_sn + classwise_sample_number[i] = np.sum(self.new_labels == i) + return classwise_sample_number def classwise_topk_accuracies(self, k): + """Computes topk accuracies per class""" preds = self.get_predictions() - clw_topk_acc = {} + classwise_topk_acc = {} for i in set(self.new_labels): - clw_topk_acc[i] = np.equal(i, np.argsort(preds['class_probabilities'][np.where(self.new_labels == i)], axis=1, kind='mergesort')[:, + classwise_topk_acc[i] = np.equal(i, np.argsort(preds['class_probabilities'][np.where(self.new_labels == i)], axis=1, kind='mergesort')[:, -k:]).sum(axis=-1).mean() - return clw_topk_acc + return classwise_topk_acc def standard_balanced_topk_accuracy(self, k): - clw_topk_acc = self.classwise_topk_accuracies(k) - return np.array(list(clw_topk_acc.values())).mean() + """Computes the balanced topk accuracy""" + classwise_topk_acc = self.classwise_topk_accuracies(k) + return np.array(list(classwise_topk_acc.values())).mean() def worst_class_accuracy(self): - cwa = self.classwise_accuracies() - worst_item = min(cwa.items(), key=lambda x: x[1]) + """Computes the smallest accuracy among classes""" + classwise_accuracies = self.classwise_accuracies() + worst_item = min(classwise_accuracies.items(), key=lambda x: x[1]) return worst_item[1] def worst_class_topk_accuracy(self, k): - clw_topk_acc = self.classwise_topk_accuracies(k) - worst_item = min(clw_topk_acc.items(), key=lambda x: x[1]) + """Computes the smallest topk accuracy among classes""" + classwise_topk_acc = self.classwise_topk_accuracies(k) + worst_item = min(classwise_topk_acc.items(), key=lambda x: x[1]) return worst_item[1] def worst_balanced_n_classes_accuracy(self, n): - cwa = self.classwise_accuracies() - sorted_cwa = sorted(cwa.items(), key=lambda item: item[1]) - n_worst = sorted_cwa[:n] + """Computes the ballanced accuracy among the worst n classes, based on their per-class accuracies""" + classwise_accuracies = self.classwise_accuracies() + sorted_classwise_accuracies = sorted(classwise_accuracies.items(), key=lambda item: item[1]) + n_worst = sorted_classwise_accuracies[:n] return np.array([x[1] for x in n_worst]).mean() def worst_heuristic_n_classes_recall(self, n): - cwa = self.classwise_accuracies() - clw_sn = self.classwise_sample_numbers() - sorted_cwa = sorted(cwa.items(), key=lambda item: item[1]) - n_worst = sorted_cwa[:n] - nwc = np.array([v * clw_sn[c] for c, v in n_worst]).sum() / np.array([clw_sn[c] for c, v in n_worst]).sum() - return nwc + """Computes recall for n worst in terms of their per class accuracy""" + classwise_accuracies = self.classwise_accuracies() + classwise_accuracies_sample_numbers = self.classwise_sample_numbers() + sorted_classwise_accuracies = sorted(classwise_accuracies.items(), key=lambda item: item[1]) + n_worst = sorted_classwise_accuracies[:n] + n_worstclass_recall = np.array([v * classwise_accuracies_sample_numbers[c] for c, v in n_worst]).sum() / np.array([classwise_accuracies_sample_numbers[c] for c, v in n_worst]).sum() + return n_worstclass_recall def worst_balanced_n_classes_topk_accuracy(self, n, k): - clw_topk_acc = self.classwise_topk_accuracies(k) - sorted_clw_topk_acc = sorted(clw_topk_acc.items(), key=lambda item: item[1]) + """Computes the balanced accuracy for the worst n classes in therms of their per class topk accuracy""" + classwise_topk_accuracies = self.classwise_topk_accuracies(k) + sorted_clw_topk_acc = sorted(classwise_topk_accuracies.items(), key=lambda item: item[1]) n_worst = sorted_clw_topk_acc[:n] return np.array([x[1] for x in n_worst]).mean() def worst_heuristic_n_classes_topk_recall(self, n, k): - clw_topk_acc = self.classwise_topk_accuracies(k) + """Computes the recall for the worst n classes in therms of their per class topk accuracy""" + classwise_topk_accuracies = self.classwise_topk_accuracies(k) clw_sn = self.classwise_sample_numbers() - sorted_clw_topk_acc = sorted(clw_topk_acc.items(), key=lambda item: item[1]) + sorted_clw_topk_acc = sorted(classwise_topk_accuracies.items(), key=lambda item: item[1]) n_worst = sorted_clw_topk_acc[:n] - nwc = np.array([v * clw_sn[c] for c, v in n_worst]).sum() / np.array([clw_sn[c] for c, v in n_worst]).sum() - return nwc + n_worstclass_recall = np.array([v * clw_sn[c] for c, v in n_worst]).sum() / np.array([clw_sn[c] for c, v in n_worst]).sum() + return n_worstclass_recall def worst_balanced_two_class_binary_accuracy(self): + """Computes the smallest two-class accuracy, when restricting the classifier to any two classes""" classes = list(set(self.new_labels)) binary_accuracies = {} for i, j in itertools.combinations(classes, 2): @@ -180,20 +194,23 @@ def worst_balanced_two_class_binary_accuracy(self): return worst_item[1] def worst_balanced_superclass_recall(self): - cwa = self.classwise_accuracies() - scwa = {i: np.array([cwa[c] for c in s]).mean() for i, s in enumerate(self.superclasses)} - worst_item = min(scwa.items(), key=lambda x: x[1]) + """Computes the worst balanced recall among the superclasses""" + classwise_accuracies = self.classwise_accuracies() + superclass_classwise_accuracies = {i: np.array([classwise_accuracies[c] for c in s]).mean() for i, s in enumerate(self.superclasses)} + worst_item = min(superclass_classwise_accuracies.items(), key=lambda x: x[1]) return worst_item[1] def worst_superclass_recall(self): - cwa = self.classwise_accuracies() - clw_sn = self.classwise_sample_numbers() - scwa = {i: np.array([cwa[c] * clw_sn[c] for c in s]).sum() / np.array([clw_sn[c] for c in s]).sum() for i, s in + """Computes the worst not balanced recall among the superclasses""" + classwise_accuracies = self.classwise_accuracies() + classwise_sample_number = self.classwise_sample_numbers() + superclass_classwise_accuracies = {i: np.array([classwise_accuracies[c] * classwise_sample_number[c] for c in s]).sum() / np.array([classwise_sample_number[c] for c in s]).sum() for i, s in enumerate(self.superclasses)} - worst_item = min(scwa.items(), key=lambda x: x[1]) + worst_item = min(superclass_classwise_accuracies.items(), key=lambda x: x[1]) return worst_item[1] def intra_superclass_accuracies(self): + """Computes the accuracy for the images among one superclass, for each superclass""" isa = {} original_probs = self.probs.copy() original_targets = self.new_labels.copy() @@ -216,11 +233,13 @@ def intra_superclass_accuracies(self): return isa def worst_intra_superclass_accuracy(self): + """Computes the worst superclass accuracy using intra_superclass_accuracies""" isa = self.intra_superclass_accuracies() worst_item = min(isa.items(), key=lambda x: x[1]) return worst_item[1] def worst_class_precision(self): + """Computes the precision for the worst class""" preds = self.get_predictions() classes = list(set(self.new_labels)) sc = {} @@ -237,6 +256,7 @@ def worst_class_precision(self): return worst_item[1] def class_confusion(self): + """Computes the confision matrix""" preds = self.get_predictions() classes = list(set(self.new_labels)) confusion = np.zeros((len(classes), len(classes))) @@ -246,7 +266,7 @@ def class_confusion(self): def _evaluate(self, model: sh_models.Model, verbose=False) -> TaskResult: - + """The final method that uses all of the above to compute the metrics introduced in [1]""" verbose = self.verbose model.verbose = verbose @@ -293,7 +313,7 @@ def _evaluate(self, model: sh_models.Model, verbose=False) -> TaskResult: # Set the label type either to val (50000 labels) or # val_clean (46044 labels) for the cleaned labels from - # (C. Northcutt et al., "Pervasive Label Errors in Test Sets Destabilize Machine Learning Benchmarks", https://arxiv.org/abs/2103.14749, https://github.com/cleanlab/label-errors) + # [3] parser.add_argument( "--labels_type", type=str, help="The label type", default='val' ) From 216763754671defc9fe151ebb2932b5940022c82 Mon Sep 17 00:00:00 2001 From: valentyn Date: Tue, 19 Jul 2022 05:53:38 +0200 Subject: [PATCH 12/20] Proposed task in worst_case.py --- shifthappens/tasks/worst_case/worst_case.py | 27 +++++++++++++++------ 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/shifthappens/tasks/worst_case/worst_case.py b/shifthappens/tasks/worst_case/worst_case.py index 205106de..02ea5b03 100644 --- a/shifthappens/tasks/worst_case/worst_case.py +++ b/shifthappens/tasks/worst_case/worst_case.py @@ -233,30 +233,41 @@ def intra_superclass_accuracies(self): return isa def worst_intra_superclass_accuracy(self): - """Computes the worst superclass accuracy using intra_superclass_accuracies""" + """Computes the worst superclass accuracy using intra_superclass_accuracies + + Output: the accuracy for the worst super class + """ isa = self.intra_superclass_accuracies() worst_item = min(isa.items(), key=lambda x: x[1]) return worst_item[1] def worst_class_precision(self): - """Computes the precision for the worst class""" + """Computes the precision for the worst class + + Returns: + Dict entry with the worst performing class + """ preds = self.get_predictions() classes = list(set(self.new_labels)) - sc = {} + per_class_precision = {} for c in classes: erroneous_c = (preds['predicted_classes'] == c) * (self.new_labels != c) correct_c = (preds['predicted_classes'] == c) * (self.new_labels == c) predicted_c = (preds['predicted_classes'] == c) if predicted_c.sum(): - sc[c] = correct_c.sum() / predicted_c.sum() # 1-erroneous_c.sum()/predicted_c.sum() + per_class_precision[c] = correct_c.sum() / predicted_c.sum() # 1-erroneous_c.sum()/predicted_c.sum() else: - sc[c] = 1 - sorted_sc = sorted(sc.items(), key=lambda item: item[1]) + per_class_precision[c] = 1 + sorted_sc = sorted(per_class_precision.items(), key=lambda item: item[1]) worst_item = sorted_sc[0] return worst_item[1] - def class_confusion(self): - """Computes the confision matrix""" + def class_confusion(self) -> np.array: + """Computes the confision matrix + Returns: + confusion: confusion matrx + + """ preds = self.get_predictions() classes = list(set(self.new_labels)) confusion = np.zeros((len(classes), len(classes))) From 19217eaa69ad2c8443e82acc35dc7a9dc0411de5 Mon Sep 17 00:00:00 2001 From: valentyn Date: Tue, 19 Jul 2022 05:58:00 +0200 Subject: [PATCH 13/20] Proposed task in worst_case.py --- shifthappens/tasks/worst_case/worst_case.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/shifthappens/tasks/worst_case/worst_case.py b/shifthappens/tasks/worst_case/worst_case.py index 02ea5b03..b293491c 100644 --- a/shifthappens/tasks/worst_case/worst_case.py +++ b/shifthappens/tasks/worst_case/worst_case.py @@ -106,7 +106,7 @@ def standard_accuracy(self): accuracy = (preds['predicted_classes'] == self.new_labels).mean() return accuracy - def classwise_accuracies(self): + def classwise_accuracies(self) -> dict: """Computes accuracies per each class""" preds = self.get_predictions() clw_acc = {} @@ -114,14 +114,14 @@ def classwise_accuracies(self): clw_acc[i] = np.equal(preds['predicted_classes'][np.where(self.new_labels == i)], i).mean() return clw_acc - def classwise_sample_numbers(self): + def classwise_sample_numbers(self) -> dict: """Computes number of samples per class""" classwise_sample_number = {} for i in set(self.new_labels): classwise_sample_number[i] = np.sum(self.new_labels == i) return classwise_sample_number - def classwise_topk_accuracies(self, k): + def classwise_topk_accuracies(self, k) -> dict: """Computes topk accuracies per class""" preds = self.get_predictions() classwise_topk_acc = {} @@ -130,24 +130,24 @@ def classwise_topk_accuracies(self, k): -k:]).sum(axis=-1).mean() return classwise_topk_acc - def standard_balanced_topk_accuracy(self, k): + def standard_balanced_topk_accuracy(self, k) -> np.array: """Computes the balanced topk accuracy""" classwise_topk_acc = self.classwise_topk_accuracies(k) return np.array(list(classwise_topk_acc.values())).mean() - def worst_class_accuracy(self): + def worst_class_accuracy(self) -> float: """Computes the smallest accuracy among classes""" classwise_accuracies = self.classwise_accuracies() worst_item = min(classwise_accuracies.items(), key=lambda x: x[1]) return worst_item[1] - def worst_class_topk_accuracy(self, k): + def worst_class_topk_accuracy(self, k) -> float: """Computes the smallest topk accuracy among classes""" classwise_topk_acc = self.classwise_topk_accuracies(k) worst_item = min(classwise_topk_acc.items(), key=lambda x: x[1]) return worst_item[1] - def worst_balanced_n_classes_accuracy(self, n): + def worst_balanced_n_classes_accuracy(self, n) -> np.array: """Computes the ballanced accuracy among the worst n classes, based on their per-class accuracies""" classwise_accuracies = self.classwise_accuracies() sorted_classwise_accuracies = sorted(classwise_accuracies.items(), key=lambda item: item[1]) From 8d46e340836e4452136ff04ea2a93bcd7f455917 Mon Sep 17 00:00:00 2001 From: valentyn Date: Tue, 19 Jul 2022 06:01:33 +0200 Subject: [PATCH 14/20] Proposed task in worst_case.py --- shifthappens/tasks/worst_case/worst_case.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/shifthappens/tasks/worst_case/worst_case.py b/shifthappens/tasks/worst_case/worst_case.py index b293491c..abaa8504 100644 --- a/shifthappens/tasks/worst_case/worst_case.py +++ b/shifthappens/tasks/worst_case/worst_case.py @@ -55,6 +55,7 @@ class WorstCase(Task): max_batch_size: int = 256 def download(self, url, data_folder, filename, md5): + """Method to download the data given its' url, and the desired folder to stor int""" for _ in range(self.n_retries): try: r = requests.get(url) @@ -66,7 +67,7 @@ def download(self, url, data_folder, filename, md5): time.sleep(5) def setup(self): - """Downloads the cleaned labels from [3], as well as superclasses used in [1]""" + """Calls the download method to download the cleaned labels from [3], as well as superclasses used in [1]""" # Download resources for resource in self.resources: folder_name, file_name, url, md5 = resource From 0b856c147996da29ecec75303115511455d761ae Mon Sep 17 00:00:00 2001 From: valentyn Date: Tue, 19 Jul 2022 06:02:41 +0200 Subject: [PATCH 15/20] Proposed task in worst_case.py --- shifthappens/tasks/worst_case/worst_case.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/shifthappens/tasks/worst_case/worst_case.py b/shifthappens/tasks/worst_case/worst_case.py index abaa8504..76b33d84 100644 --- a/shifthappens/tasks/worst_case/worst_case.py +++ b/shifthappens/tasks/worst_case/worst_case.py @@ -1,8 +1,4 @@ -"""Classifiers Should Do Well Even on Their Worst Classes -This task evaluates a set of metrics, mostly related to worst-class performance, as described in [1]. -It is motivated by [2], where the authors note that using only accuracy as a metric is not enough to evaluate -the performance of the classifier, as it must not be the same on all classes/groups. -""" +"""Classifiers Should Do Well Even on Their Worst Classes""" import argparse import collections import itertools @@ -33,6 +29,9 @@ ) @dataclasses.dataclass class WorstCase(Task): + """This task evaluates a set of metrics, mostly related to worst-class performance, as described in [1]. + It is motivated by [2], where the authors note that using only accuracy as a metric is not enough to evaluate + the performance of the classifier, as it must not be the same on all classes/groups.""" resources = ( ["worstcase", "restricted_superclass.csv", From 5c32a7a3d39a8d83cc4e54fe72b04d0f1db20a02 Mon Sep 17 00:00:00 2001 From: valentyn Date: Tue, 19 Jul 2022 06:05:18 +0200 Subject: [PATCH 16/20] Proposed task in worst_case.py --- shifthappens/tasks/worst_case/worst_case.py | 26 ++++++++++----------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/shifthappens/tasks/worst_case/worst_case.py b/shifthappens/tasks/worst_case/worst_case.py index 76b33d84..2dc97925 100644 --- a/shifthappens/tasks/worst_case/worst_case.py +++ b/shifthappens/tasks/worst_case/worst_case.py @@ -100,7 +100,7 @@ def get_predictions(self) -> np.ndarray: preds['number_of_class_predictions'] = collections.Counter(preds['predicted_classes']) return preds - def standard_accuracy(self): + def standard_accuracy(self) -> np.float: """Computes standard accuracy""" preds = self.get_predictions() accuracy = (preds['predicted_classes'] == self.new_labels).mean() @@ -154,7 +154,7 @@ def worst_balanced_n_classes_accuracy(self, n) -> np.array: n_worst = sorted_classwise_accuracies[:n] return np.array([x[1] for x in n_worst]).mean() - def worst_heuristic_n_classes_recall(self, n): + def worst_heuristic_n_classes_recall(self, n) -> np.float: """Computes recall for n worst in terms of their per class accuracy""" classwise_accuracies = self.classwise_accuracies() classwise_accuracies_sample_numbers = self.classwise_sample_numbers() @@ -163,14 +163,14 @@ def worst_heuristic_n_classes_recall(self, n): n_worstclass_recall = np.array([v * classwise_accuracies_sample_numbers[c] for c, v in n_worst]).sum() / np.array([classwise_accuracies_sample_numbers[c] for c, v in n_worst]).sum() return n_worstclass_recall - def worst_balanced_n_classes_topk_accuracy(self, n, k): + def worst_balanced_n_classes_topk_accuracy(self, n, k) -> np.float: """Computes the balanced accuracy for the worst n classes in therms of their per class topk accuracy""" classwise_topk_accuracies = self.classwise_topk_accuracies(k) sorted_clw_topk_acc = sorted(classwise_topk_accuracies.items(), key=lambda item: item[1]) n_worst = sorted_clw_topk_acc[:n] return np.array([x[1] for x in n_worst]).mean() - def worst_heuristic_n_classes_topk_recall(self, n, k): + def worst_heuristic_n_classes_topk_recall(self, n, k) -> np.float: """Computes the recall for the worst n classes in therms of their per class topk accuracy""" classwise_topk_accuracies = self.classwise_topk_accuracies(k) clw_sn = self.classwise_sample_numbers() @@ -179,7 +179,7 @@ def worst_heuristic_n_classes_topk_recall(self, n, k): n_worstclass_recall = np.array([v * clw_sn[c] for c, v in n_worst]).sum() / np.array([clw_sn[c] for c, v in n_worst]).sum() return n_worstclass_recall - def worst_balanced_two_class_binary_accuracy(self): + def worst_balanced_two_class_binary_accuracy(self) -> np.float: """Computes the smallest two-class accuracy, when restricting the classifier to any two classes""" classes = list(set(self.new_labels)) binary_accuracies = {} @@ -193,14 +193,14 @@ def worst_balanced_two_class_binary_accuracy(self): worst_item = sorted_binary_accuracies[0] return worst_item[1] - def worst_balanced_superclass_recall(self): + def worst_balanced_superclass_recall(self) -> np.float: """Computes the worst balanced recall among the superclasses""" classwise_accuracies = self.classwise_accuracies() superclass_classwise_accuracies = {i: np.array([classwise_accuracies[c] for c in s]).mean() for i, s in enumerate(self.superclasses)} worst_item = min(superclass_classwise_accuracies.items(), key=lambda x: x[1]) return worst_item[1] - def worst_superclass_recall(self): + def worst_superclass_recall(self) -> np.float: """Computes the worst not balanced recall among the superclasses""" classwise_accuracies = self.classwise_accuracies() classwise_sample_number = self.classwise_sample_numbers() @@ -209,9 +209,9 @@ def worst_superclass_recall(self): worst_item = min(superclass_classwise_accuracies.items(), key=lambda x: x[1]) return worst_item[1] - def intra_superclass_accuracies(self): + def intra_superclass_accuracies(self) -> dict: """Computes the accuracy for the images among one superclass, for each superclass""" - isa = {} + intra_superclass_accuracies = {} original_probs = self.probs.copy() original_targets = self.new_labels.copy() for i, s in enumerate(self.superclasses): @@ -225,14 +225,14 @@ def intra_superclass_accuracies(self): self.probs = internal_probs self.new_labels = internal_targets internal_preds = s_targets(self.get_predictions()['predicted_classes']) - isa[i] = (internal_preds == internal_targets).mean() + intra_superclass_accuracies[i] = (internal_preds == internal_targets).mean() self.probs = original_probs self.new_labels = original_targets - return isa + return intra_superclass_accuracies - def worst_intra_superclass_accuracy(self): + def worst_intra_superclass_accuracy(self) -> np.float: """Computes the worst superclass accuracy using intra_superclass_accuracies Output: the accuracy for the worst super class @@ -241,7 +241,7 @@ def worst_intra_superclass_accuracy(self): worst_item = min(isa.items(), key=lambda x: x[1]) return worst_item[1] - def worst_class_precision(self): + def worst_class_precision(self) -> np.float: """Computes the precision for the worst class Returns: From 673fec57c834d3045bf46de154987e1f77dfd309 Mon Sep 17 00:00:00 2001 From: valentyn Date: Tue, 19 Jul 2022 10:49:24 +0200 Subject: [PATCH 17/20] Proposed task in worst_case.py --- shifthappens/tasks/worst_case/worst_case.py | 1 + 1 file changed, 1 insertion(+) diff --git a/shifthappens/tasks/worst_case/worst_case.py b/shifthappens/tasks/worst_case/worst_case.py index 2dc97925..aee82b65 100644 --- a/shifthappens/tasks/worst_case/worst_case.py +++ b/shifthappens/tasks/worst_case/worst_case.py @@ -14,6 +14,7 @@ import numpy as np import torch.nn as nn from numpy.core._multiarray_umath import ndarray +import shifthappens.config from numpy.core.multiarray import ndarray from shifthappens.data import imagenet as sh_imagenet From b8659be2b7c89695734ad88f89fe4494f498f3cc Mon Sep 17 00:00:00 2001 From: valentyn Date: Tue, 19 Jul 2022 10:58:16 +0200 Subject: [PATCH 18/20] Proposed task in worst_case.py --- shifthappens/tasks/worst_case/worst_case.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/shifthappens/tasks/worst_case/worst_case.py b/shifthappens/tasks/worst_case/worst_case.py index aee82b65..7995e687 100644 --- a/shifthappens/tasks/worst_case/worst_case.py +++ b/shifthappens/tasks/worst_case/worst_case.py @@ -318,9 +318,14 @@ def _evaluate(self, model: sh_models.Model, verbose=False) -> TaskResult: if __name__ == "__main__": - from shifthappens.models.torchvision import ResNet18 + from shifthappens.models.torchvision import * import shifthappens + + available_models_dict = {'resnet18': ResNet18, + 'resnet50': ResNet50, + 'vgg16': VGG16} + parser = argparse.ArgumentParser() # Set the label type either to val (50000 labels) or @@ -332,6 +337,10 @@ def _evaluate(self, model: sh_models.Model, verbose=False) -> TaskResult: parser.add_argument( "--imagenet_val_folder", type=str, help="The folder for the imagenet val set", required=True ) + parser.add_argument( + "--model_name", type=str, default='resnet18', + help=f'The name of the model to test. Should be in {available_models_dict.keys()}' + ) parser.add_argument('--gpu', '--list', nargs='+', default=[0], help='GPU indices, if more than 1 parallel modules will be called') parser.add_argument('--bs', type=int, default=500) @@ -341,7 +350,6 @@ def _evaluate(self, model: sh_models.Model, verbose=False) -> TaskResult: action="store_true", ) - args = parser.parse_args() if len(args.gpu) == 0: @@ -365,7 +373,9 @@ def _evaluate(self, model: sh_models.Model, verbose=False) -> TaskResult: tuple(sh_benchmark.__registered_tasks)[0].cls.labels_type = args.labels_type - model = ResNet18(device=device, max_batch_size=args.bs) + assert args.model_name.lower() in available_models_dict, f"Selected model_name should be in {available_models_dict.keys()}" + + model = available_models_dict[args.model_name.lower()](device=device, max_batch_size=args.bs) if device_ids is not None and len(device_ids) > 1: model = nn.DataParallel(model, device_ids=device_ids) From 2efdde82116526e1b26c38e93eac7cddfc850f59 Mon Sep 17 00:00:00 2001 From: valentyn Date: Tue, 19 Jul 2022 11:16:22 +0200 Subject: [PATCH 19/20] Latest. --- shifthappens/models/base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/shifthappens/models/base.py b/shifthappens/models/base.py index 8ec1635c..fa300a5d 100644 --- a/shifthappens/models/base.py +++ b/shifthappens/models/base.py @@ -11,6 +11,7 @@ import abc import dataclasses from typing import Iterator +import shifthappens.config import numpy as np from tqdm import tqdm From ed1ec54369273b8640d66af84ef77a34fce39ed6 Mon Sep 17 00:00:00 2001 From: valentyn Date: Tue, 19 Jul 2022 11:19:00 +0200 Subject: [PATCH 20/20] Added verbose. --- shifthappens/models/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/shifthappens/models/base.py b/shifthappens/models/base.py index fa300a5d..f706668e 100644 --- a/shifthappens/models/base.py +++ b/shifthappens/models/base.py @@ -209,7 +209,7 @@ def _predict_imagenet_val(self): ] } - if self.verbose: + if shifthappens.config.verbose: pred_loader = tqdm(self._predict(imagenet_val_dataloader, targets), desc='Predictions', total=int(len(imagenet_val_dataloader._dataset)/imagenet_val_dataloader.max_batch_size)) else: pred_loader = self._predict(imagenet_val_dataloader, targets)