From 88f0e96d99a3ae0e2d53459a3eafbca5bead1ea8 Mon Sep 17 00:00:00 2001 From: Yannick Prudent Date: Mon, 25 Mar 2024 14:35:32 +0100 Subject: [PATCH 1/5] feat: add verbose option to feature extractors --- oodeel/extractor/feature_extractor.py | 2 ++ oodeel/extractor/keras_feature_extractor.py | 5 ++++- oodeel/extractor/torch_feature_extractor.py | 5 ++++- oodeel/methods/base.py | 20 ++++++++++++++++---- 4 files changed, 26 insertions(+), 6 deletions(-) diff --git a/oodeel/extractor/feature_extractor.py b/oodeel/extractor/feature_extractor.py index d03a6b2c..e35de0a9 100644 --- a/oodeel/extractor/feature_extractor.py +++ b/oodeel/extractor/feature_extractor.py @@ -114,6 +114,7 @@ def predict( self, dataset: Union[ItemType, DatasetType], postproc_fns: Optional[List[Callable]] = None, + verbose: bool = False, **kwargs, ) -> Tuple[List[TensorType], dict]: """Get the projection of the dataset in the feature space of self.model @@ -122,6 +123,7 @@ def predict( dataset (Union[ItemType, DatasetType]): input dataset postproc_fns (Optional[Callable]): postprocessing function to apply to each feature immediately after forward. Default to None. + verbose (bool): if True, display a progress bar. Defaults to False. kwargs (dict): additional arguments not considered for prediction Returns: diff --git a/oodeel/extractor/keras_feature_extractor.py b/oodeel/extractor/keras_feature_extractor.py index 58af0d34..d3fdf353 100644 --- a/oodeel/extractor/keras_feature_extractor.py +++ b/oodeel/extractor/keras_feature_extractor.py @@ -24,6 +24,7 @@ from typing import Optional import tensorflow as tf +from tqdm import tqdm from ..datasets.tf_data_handler import TFDataHandler from ..types import Callable @@ -190,6 +191,7 @@ def predict( self, dataset: Union[ItemType, tf.data.Dataset], postproc_fns: Optional[List[Callable]] = None, + verbose: bool = False, **kwargs, ) -> Tuple[List[tf.Tensor], dict]: """Get the projection of the dataset in the feature space of self.model @@ -198,6 +200,7 @@ def predict( dataset (Union[ItemType, tf.data.Dataset]): input dataset postproc_fns (Optional[Callable]): postprocessing function to apply to each feature immediately after forward. Default to None. + verbose (bool): if True, display a progress bar. Defaults to False. kwargs (dict): additional arguments not considered for prediction Returns: @@ -218,7 +221,7 @@ def predict( features = [None for i in range(len(self.feature_layers_id))] logits = None contains_labels = TFDataHandler.get_item_length(dataset) > 1 - for elem in dataset: + for elem in tqdm(dataset, desc="Predicting", disable=not verbose): tensor = TFDataHandler.get_input_from_dataset_item(elem) features_batch, logits_batch = self.predict_tensor(tensor, postproc_fns) diff --git a/oodeel/extractor/torch_feature_extractor.py b/oodeel/extractor/torch_feature_extractor.py index 350e291b..2672bf14 100644 --- a/oodeel/extractor/torch_feature_extractor.py +++ b/oodeel/extractor/torch_feature_extractor.py @@ -27,6 +27,7 @@ import torch from torch import nn from torch.utils.data import DataLoader +from tqdm import tqdm from ..datasets.torch_data_handler import TorchDataHandler from ..types import Callable @@ -226,6 +227,7 @@ def predict( dataset: Union[DataLoader, ItemType], postproc_fns: Optional[List[Callable]] = None, detach: bool = True, + verbose: bool = False, **kwargs, ) -> Tuple[List[torch.Tensor], dict]: """Get the projection of the dataset in the feature space of self.model @@ -236,6 +238,7 @@ def predict( each feature immediately after forward. Default to None. detach (bool): if True, return features detached from the computational graph. Defaults to True. + verbose (bool): if True, display a progress bar. Defaults to False. kwargs (dict): additional arguments not considered for prediction Returns: @@ -257,7 +260,7 @@ def predict( logits = None batch = next(iter(dataset)) contains_labels = isinstance(batch, (list, tuple)) and len(batch) > 1 - for elem in dataset: + for elem in tqdm(dataset, desc="Predicting", disable=not verbose): tensor = TorchDataHandler.get_input_from_dataset_item(elem) features_batch, logits_batch = self.predict_tensor( tensor, postproc_fns, detach=detach diff --git a/oodeel/methods/base.py b/oodeel/methods/base.py index 14a4cc15..718b6d70 100644 --- a/oodeel/methods/base.py +++ b/oodeel/methods/base.py @@ -20,11 +20,13 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +import inspect from abc import ABC from abc import abstractmethod from typing import get_args import numpy as np +from tqdm import tqdm from ..extractor.feature_extractor import FeatureExtractor from ..types import Callable @@ -104,6 +106,7 @@ def fit( fit_dataset: Optional[Union[ItemType, DatasetType]] = None, feature_layers_id: List[Union[int, str]] = [], input_layer_id: Optional[Union[int, str]] = None, + verbose: bool = False, **kwargs, ) -> None: """Prepare the detector for scoring: @@ -122,6 +125,7 @@ def fit( layer of the feature extractor. If int, the rank of the layer in the layer list If str, the name of the layer. Defaults to None. + verbose (bool): if True, display a progress bar. Defaults to False. """ ( self.backend, @@ -144,7 +148,7 @@ def fit( " provided to compute react activation threshold" ) else: - self.compute_react_threshold(model, fit_dataset) + self.compute_react_threshold(model, fit_dataset, verbose=verbose) if (feature_layers_id == []) and (self.requires_internal_features): raise ValueError( @@ -160,6 +164,8 @@ def fit( ) if fit_dataset is not None: + if "verbose" in inspect.signature(self._fit_to_dataset).parameters.keys(): + kwargs.update({"verbose": verbose}) self._fit_to_dataset(fit_dataset, **kwargs) def _load_feature_extractor( @@ -207,12 +213,14 @@ def _fit_to_dataset(self, fit_dataset: DatasetType) -> None: def score( self, dataset: Union[ItemType, DatasetType], + verbose: bool = False, ) -> np.ndarray: """ Computes an OOD score for input samples "inputs". Args: dataset (Union[ItemType, DatasetType]): dataset or tensors to score + verbose (bool): if True, display a progress bar. Defaults to False. Returns: tuple: scores or list of scores (depending on the input) and a dictionary @@ -236,7 +244,7 @@ def score( scores = np.array([]) logits = None - for item in dataset: + for item in tqdm(dataset, desc="Scoring", disable=not verbose): tensor = self.data_handler.get_input_from_dataset_item(item) score_batch = self._score_tensor(tensor) logits_batch = self.op.convert_to_numpy( @@ -267,9 +275,13 @@ def score( info = dict(labels=labels, logits=logits) return scores, info - def compute_react_threshold(self, model: Callable, fit_dataset: DatasetType): + def compute_react_threshold( + self, model: Callable, fit_dataset: DatasetType, verbose: bool = False + ): penult_feat_extractor = self._load_feature_extractor(model, [-2]) - unclipped_features, _ = penult_feat_extractor.predict(fit_dataset) + unclipped_features, _ = penult_feat_extractor.predict( + fit_dataset, verbose=verbose + ) self.react_threshold = self.op.quantile( unclipped_features[0], self.react_quantile ) From 16b76d6cd2da4fc5a69fea688a4d5d672a0e2c0b Mon Sep 17 00:00:00 2001 From: Yannick Prudent Date: Mon, 25 Mar 2024 14:40:40 +0100 Subject: [PATCH 2/5] fix: mahalanobis * detach the features at the fit to save VRAM (torch) * compute covariance using accumulated mean to save VRAM * fix Mahalanobis covariance computation (should not be class conditional) --- oodeel/methods/mahalanobis.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/oodeel/methods/mahalanobis.py b/oodeel/methods/mahalanobis.py index a113b9a4..e32d2158 100644 --- a/oodeel/methods/mahalanobis.py +++ b/oodeel/methods/mahalanobis.py @@ -55,7 +55,7 @@ def _fit_to_dataset(self, fit_dataset: DatasetType) -> None: fit_dataset (Union[TensorType, DatasetType]): input dataset (ID) """ # extract features and labels - features, infos = self.feature_extractor.predict(fit_dataset) + features, infos = self.feature_extractor.predict(fit_dataset, detach=True) labels = infos["labels"] # unique sorted classes @@ -63,22 +63,24 @@ def _fit_to_dataset(self, fit_dataset: DatasetType) -> None: # compute mus and covs mus = dict() - covs = dict() + mean_cov = None for cls in self._classes: indexes = self.op.equal(labels, cls) _features_cls = self.op.flatten(features[0][indexes]) mus[cls] = self.op.mean(_features_cls, dim=0) _zero_f_cls = _features_cls - mus[cls] - covs[cls] = ( + cov_cls = ( self.op.matmul(self.op.t(_zero_f_cls), _zero_f_cls) / _zero_f_cls.shape[0] ) + if mean_cov is None: + mean_cov = (len(_features_cls) / len(features)) * cov_cls + else: + mean_cov += (len(_features_cls) / len(features)) * cov_cls - # mean cov and its inverse - mean_cov = self.op.mean(self.op.stack(list(covs.values())), dim=0) - - self._mus = mus + # pseudo-inverse of the mean covariance matrix self._pinv_cov = self.op.pinv(mean_cov) + self._mus = mus def _score_tensor(self, inputs: TensorType) -> Tuple[np.ndarray]: """ From 217dc281b19f0526f32b19f285b0f5d9f99c8035 Mon Sep 17 00:00:00 2001 From: Yannick Prudent Date: Mon, 25 Mar 2024 14:43:28 +0100 Subject: [PATCH 3/5] fix: set num_workers to 8 for torch dataloaders (was equal to 0 so reaaally slow!) --- oodeel/datasets/ooddataset.py | 10 +++++----- oodeel/datasets/torch_data_handler.py | 3 +++ 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/oodeel/datasets/ooddataset.py b/oodeel/datasets/ooddataset.py index abd3c431..21b46fc8 100644 --- a/oodeel/datasets/ooddataset.py +++ b/oodeel/datasets/ooddataset.py @@ -265,7 +265,7 @@ def prepare( with_ood_labels: bool = False, with_labels: bool = True, shuffle: bool = False, - shuffle_buffer_size: Optional[int] = None, + **kwargs_prepare, ) -> DatasetType: """Prepare self.data for scoring or training @@ -282,9 +282,9 @@ def prepare( Defaults to True. shuffle (bool, optional): To shuffle the returned dataset or not. Defaults to False. - shuffle_buffer_size (int, optional): (TF only) Size of the shuffle buffer. - If None, taken as the number of samples in the dataset. - Defaults to None. + kwargs_prepare (dict): Additional parameters to be passed to the + data_handler.prepare_for_training method. + Returns: DatasetType: prepared dataset @@ -323,7 +323,7 @@ def prepare( preprocess_fn=preprocess_fn, augment_fn=augment_fn, output_keys=keys, - shuffle_buffer_size=shuffle_buffer_size, + **kwargs_prepare, ) return dataset diff --git a/oodeel/datasets/torch_data_handler.py b/oodeel/datasets/torch_data_handler.py index e13ec2c5..ea3f4707 100644 --- a/oodeel/datasets/torch_data_handler.py +++ b/oodeel/datasets/torch_data_handler.py @@ -569,6 +569,7 @@ def prepare_for_training( output_keys: Optional[list] = None, dict_based_fns: bool = False, shuffle_buffer_size: Optional[int] = None, + num_workers: int = 8, ) -> DataLoader: """Prepare a DataLoader for training @@ -587,6 +588,7 @@ def prepare_for_training( shuffle_buffer_size (int, optional): Size of the shuffle buffer. Not used in torch because we only rely on Map-Style datasets. Still as argument for API consistency. Defaults to None. + num_workers (int, optional): Number of workers to use for the dataloader. Returns: DataLoader: dataloader @@ -621,6 +623,7 @@ def collate_fn(batch: List[dict]): batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn, + num_workers=num_workers, ) return loader From 356525f7786f29ced54662c486cbfc4c492c072b Mon Sep 17 00:00:00 2001 From: Yannick Prudent Date: Mon, 25 Mar 2024 14:52:44 +0100 Subject: [PATCH 4/5] fix: gram default value for orders limited to power 5 Note: power order > 7 was raising errors because of inf values when benchmarking imagenet OpenOOD's implementation also uses [1, ..., 5] as a default value for the orders argument --- oodeel/methods/gram.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/oodeel/methods/gram.py b/oodeel/methods/gram.py index b4bb1707..1802a283 100644 --- a/oodeel/methods/gram.py +++ b/oodeel/methods/gram.py @@ -76,7 +76,7 @@ class Gram(OODBaseDetector): def __init__( self, - orders: List[int] = [i for i in range(1, 11)], + orders: List[int] = [i for i in range(1, 6)], quantile: float = 0.01, ): super().__init__() @@ -90,6 +90,7 @@ def _fit_to_dataset( self, fit_dataset: Union[TensorType, DatasetType], val_split: float = 0.2, + verbose: bool = False, ) -> None: """ Compute the quantiles of channelwise correlations for each layer, power of @@ -102,13 +103,19 @@ def _fit_to_dataset( construct the index with. val_split (float): The percentage of fit data to use as validation data for normalization. Default to 0.2. + verbose (bool): Whether to print information during the fitting process. + Default to False. """ self.postproc_fns = [ self._stat for i in range(len(self.feature_extractor.feature_layers_id)) ] + # fit_stats shape: [n_features, n_samples, n_orders, n_channels] fit_stats, info = self.feature_extractor.predict( - fit_dataset, postproc_fns=self.postproc_fns, return_labels=True + fit_dataset, + postproc_fns=self.postproc_fns, + return_labels=True, + verbose=verbose, ) labels = info["labels"] self._classes = np.sort(np.unique(self.op.convert_to_numpy(labels))) @@ -256,21 +263,25 @@ def _stat(self, feature_map: TensorType) -> TensorType: (fm_s[0], fm_s[-1], -1), ) else: + # batch, channel, spatial feature_map_p = self.op.reshape( feature_map_p, (fm_s[0], fm_s[1], -1) ) + # batch, channel, channel feature_map_p = self.op.matmul( feature_map_p, self.op.permute(feature_map_p, (0, 2, 1)) ) + # normalize the Gram matrix feature_map_p = self.op.sign(feature_map_p) * ( self.op.abs(feature_map_p) ** (1 / p) ) # get the lower triangular part of the matrix feature_map_p = self.op.tril(feature_map_p) - # directly sum row-wise (to limit computational burden) + # directly sum row-wise (to limit computational burden) -> batch, channel feature_map_p = self.op.sum(feature_map_p, dim=2) # stat.append(self.op.t(feature_map_p)) stat.append(feature_map_p) + # batch, n_orders, channel stat = self.op.stack(stat, 1) return stat From 035c8c04fd130301e5b7b3f8d1d98e07204c18b3 Mon Sep 17 00:00:00 2001 From: Yannick Prudent Date: Mon, 25 Mar 2024 15:27:13 +0100 Subject: [PATCH 5/5] fix: default nearest value of dknn set to 50 --- oodeel/methods/dknn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/oodeel/methods/dknn.py b/oodeel/methods/dknn.py index 86d86853..983316fc 100644 --- a/oodeel/methods/dknn.py +++ b/oodeel/methods/dknn.py @@ -42,7 +42,7 @@ class DKNN(OODBaseDetector): def __init__( self, - nearest: int = 1, + nearest: int = 50, ): super().__init__()