Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/methods #85

Merged
merged 5 commits into from
Apr 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions oodeel/datasets/ooddataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def prepare(
with_ood_labels: bool = False,
with_labels: bool = True,
shuffle: bool = False,
shuffle_buffer_size: Optional[int] = None,
**kwargs_prepare,
) -> DatasetType:
"""Prepare self.data for scoring or training

Expand All @@ -282,9 +282,9 @@ def prepare(
Defaults to True.
shuffle (bool, optional): To shuffle the returned dataset or not.
Defaults to False.
shuffle_buffer_size (int, optional): (TF only) Size of the shuffle buffer.
If None, taken as the number of samples in the dataset.
Defaults to None.
kwargs_prepare (dict): Additional parameters to be passed to the
data_handler.prepare_for_training method.
Comment on lines +285 to +286
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a suggestion: we could extend this docstring with usage examples, e.g. "shuffle_buffer_size for the TF data handler or num_workers for the torch data handler."

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point



Returns:
DatasetType: prepared dataset
Expand Down Expand Up @@ -323,7 +323,7 @@ def prepare(
preprocess_fn=preprocess_fn,
augment_fn=augment_fn,
output_keys=keys,
shuffle_buffer_size=shuffle_buffer_size,
**kwargs_prepare,
)

return dataset
3 changes: 3 additions & 0 deletions oodeel/datasets/torch_data_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,7 @@ def prepare_for_training(
output_keys: Optional[list] = None,
dict_based_fns: bool = False,
shuffle_buffer_size: Optional[int] = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suppose this line must be removed, since there is now kwargs_prepare?

num_workers: int = 8,
) -> DataLoader:
"""Prepare a DataLoader for training

Expand All @@ -587,6 +588,7 @@ def prepare_for_training(
shuffle_buffer_size (int, optional): Size of the shuffle buffer. Not used
in torch because we only rely on Map-Style datasets. Still as argument
for API consistency. Defaults to None.
num_workers (int, optional): Number of workers to use for the dataloader.

Returns:
DataLoader: dataloader
Expand Down Expand Up @@ -621,6 +623,7 @@ def collate_fn(batch: List[dict]):
batch_size=batch_size,
shuffle=shuffle,
collate_fn=collate_fn,
num_workers=num_workers,
)
return loader

Expand Down
2 changes: 2 additions & 0 deletions oodeel/extractor/feature_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def predict(
self,
dataset: Union[ItemType, DatasetType],
postproc_fns: Optional[List[Callable]] = None,
verbose: bool = False,
**kwargs,
) -> Tuple[List[TensorType], dict]:
"""Get the projection of the dataset in the feature space of self.model
Expand All @@ -122,6 +123,7 @@ def predict(
dataset (Union[ItemType, DatasetType]): input dataset
postproc_fns (Optional[Callable]): postprocessing function to apply to each
feature immediately after forward. Default to None.
verbose (bool): if True, display a progress bar. Defaults to False.
kwargs (dict): additional arguments not considered for prediction

Returns:
Expand Down
5 changes: 4 additions & 1 deletion oodeel/extractor/keras_feature_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from typing import Optional

import tensorflow as tf
from tqdm import tqdm

from ..datasets.tf_data_handler import TFDataHandler
from ..types import Callable
Expand Down Expand Up @@ -190,6 +191,7 @@ def predict(
self,
dataset: Union[ItemType, tf.data.Dataset],
postproc_fns: Optional[List[Callable]] = None,
verbose: bool = False,
**kwargs,
) -> Tuple[List[tf.Tensor], dict]:
"""Get the projection of the dataset in the feature space of self.model
Expand All @@ -198,6 +200,7 @@ def predict(
dataset (Union[ItemType, tf.data.Dataset]): input dataset
postproc_fns (Optional[Callable]): postprocessing function to apply to each
feature immediately after forward. Default to None.
verbose (bool): if True, display a progress bar. Defaults to False.
kwargs (dict): additional arguments not considered for prediction

Returns:
Expand All @@ -218,7 +221,7 @@ def predict(
features = [None for i in range(len(self.feature_layers_id))]
logits = None
contains_labels = TFDataHandler.get_item_length(dataset) > 1
for elem in dataset:
for elem in tqdm(dataset, desc="Predicting", disable=not verbose):
tensor = TFDataHandler.get_input_from_dataset_item(elem)
features_batch, logits_batch = self.predict_tensor(tensor, postproc_fns)

Expand Down
5 changes: 4 additions & 1 deletion oodeel/extractor/torch_feature_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import torch
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm

from ..datasets.torch_data_handler import TorchDataHandler
from ..types import Callable
Expand Down Expand Up @@ -226,6 +227,7 @@ def predict(
dataset: Union[DataLoader, ItemType],
postproc_fns: Optional[List[Callable]] = None,
detach: bool = True,
verbose: bool = False,
**kwargs,
) -> Tuple[List[torch.Tensor], dict]:
"""Get the projection of the dataset in the feature space of self.model
Expand All @@ -236,6 +238,7 @@ def predict(
each feature immediately after forward. Default to None.
detach (bool): if True, return features detached from the computational
graph. Defaults to True.
verbose (bool): if True, display a progress bar. Defaults to False.
kwargs (dict): additional arguments not considered for prediction

Returns:
Expand All @@ -257,7 +260,7 @@ def predict(
logits = None
batch = next(iter(dataset))
contains_labels = isinstance(batch, (list, tuple)) and len(batch) > 1
for elem in dataset:
for elem in tqdm(dataset, desc="Predicting", disable=not verbose):
tensor = TorchDataHandler.get_input_from_dataset_item(elem)
features_batch, logits_batch = self.predict_tensor(
tensor, postproc_fns, detach=detach
Expand Down
20 changes: 16 additions & 4 deletions oodeel/methods/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import inspect
from abc import ABC
from abc import abstractmethod
from typing import get_args

import numpy as np
from tqdm import tqdm

from ..extractor.feature_extractor import FeatureExtractor
from ..types import Callable
Expand Down Expand Up @@ -104,6 +106,7 @@ def fit(
fit_dataset: Optional[Union[ItemType, DatasetType]] = None,
feature_layers_id: List[Union[int, str]] = [],
input_layer_id: Optional[Union[int, str]] = None,
verbose: bool = False,
**kwargs,
) -> None:
"""Prepare the detector for scoring:
Expand All @@ -122,6 +125,7 @@ def fit(
layer of the feature extractor.
If int, the rank of the layer in the layer list
If str, the name of the layer. Defaults to None.
verbose (bool): if True, display a progress bar. Defaults to False.
"""
(
self.backend,
Expand All @@ -144,7 +148,7 @@ def fit(
" provided to compute react activation threshold"
)
else:
self.compute_react_threshold(model, fit_dataset)
self.compute_react_threshold(model, fit_dataset, verbose=verbose)

if (feature_layers_id == []) and (self.requires_internal_features):
raise ValueError(
Expand All @@ -160,6 +164,8 @@ def fit(
)

if fit_dataset is not None:
if "verbose" in inspect.signature(self._fit_to_dataset).parameters.keys():
kwargs.update({"verbose": verbose})
Comment on lines +167 to +168
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know what the best practice is, but I think it's clearer to make verbose a real argument of _fit_to_dataset() rather than passing it through kwargs. Otherwise, we should make it clear in the docstring of _fit_to_dataset() (and in the tutorial "implement your own baseline") that verbose is a supported argument.

self._fit_to_dataset(fit_dataset, **kwargs)

def _load_feature_extractor(
Expand Down Expand Up @@ -207,12 +213,14 @@ def _fit_to_dataset(self, fit_dataset: DatasetType) -> None:
def score(
self,
dataset: Union[ItemType, DatasetType],
verbose: bool = False,
) -> np.ndarray:
"""
Computes an OOD score for input samples "inputs".

Args:
dataset (Union[ItemType, DatasetType]): dataset or tensors to score
verbose (bool): if True, display a progress bar. Defaults to False.

Returns:
tuple: scores or list of scores (depending on the input) and a dictionary
Expand All @@ -236,7 +244,7 @@ def score(
scores = np.array([])
logits = None

for item in dataset:
for item in tqdm(dataset, desc="Scoring", disable=not verbose):
tensor = self.data_handler.get_input_from_dataset_item(item)
score_batch = self._score_tensor(tensor)
logits_batch = self.op.convert_to_numpy(
Expand Down Expand Up @@ -267,9 +275,13 @@ def score(
info = dict(labels=labels, logits=logits)
return scores, info

def compute_react_threshold(self, model: Callable, fit_dataset: DatasetType):
def compute_react_threshold(
self, model: Callable, fit_dataset: DatasetType, verbose: bool = False
):
penult_feat_extractor = self._load_feature_extractor(model, [-2])
unclipped_features, _ = penult_feat_extractor.predict(fit_dataset)
unclipped_features, _ = penult_feat_extractor.predict(
fit_dataset, verbose=verbose
)
self.react_threshold = self.op.quantile(
unclipped_features[0], self.react_quantile
)
Expand Down
2 changes: 1 addition & 1 deletion oodeel/methods/dknn.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class DKNN(OODBaseDetector):

def __init__(
self,
nearest: int = 1,
nearest: int = 50,
):
super().__init__()

Expand Down
17 changes: 14 additions & 3 deletions oodeel/methods/gram.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class Gram(OODBaseDetector):

def __init__(
self,
orders: List[int] = [i for i in range(1, 11)],
orders: List[int] = [i for i in range(1, 6)],
quantile: float = 0.01,
):
super().__init__()
Expand All @@ -90,6 +90,7 @@ def _fit_to_dataset(
self,
fit_dataset: Union[TensorType, DatasetType],
val_split: float = 0.2,
verbose: bool = False,
) -> None:
"""
Compute the quantiles of channelwise correlations for each layer, power of
Expand All @@ -102,13 +103,19 @@ def _fit_to_dataset(
construct the index with.
val_split (float): The percentage of fit data to use as validation data for
normalization. Default to 0.2.
verbose (bool): Whether to print information during the fitting process.
Default to False.
"""
self.postproc_fns = [
self._stat for i in range(len(self.feature_extractor.feature_layers_id))
]

# fit_stats shape: [n_features, n_samples, n_orders, n_channels]
fit_stats, info = self.feature_extractor.predict(
fit_dataset, postproc_fns=self.postproc_fns, return_labels=True
fit_dataset,
postproc_fns=self.postproc_fns,
return_labels=True,
verbose=verbose,
)
labels = info["labels"]
self._classes = np.sort(np.unique(self.op.convert_to_numpy(labels)))
Expand Down Expand Up @@ -256,21 +263,25 @@ def _stat(self, feature_map: TensorType) -> TensorType:
(fm_s[0], fm_s[-1], -1),
)
else:
# batch, channel, spatial
feature_map_p = self.op.reshape(
feature_map_p, (fm_s[0], fm_s[1], -1)
)
# batch, channel, channel
feature_map_p = self.op.matmul(
feature_map_p, self.op.permute(feature_map_p, (0, 2, 1))
)
# normalize the Gram matrix
feature_map_p = self.op.sign(feature_map_p) * (
self.op.abs(feature_map_p) ** (1 / p)
)
# get the lower triangular part of the matrix
feature_map_p = self.op.tril(feature_map_p)
# directly sum row-wise (to limit computational burden)
# directly sum row-wise (to limit computational burden) -> batch, channel
feature_map_p = self.op.sum(feature_map_p, dim=2)
# stat.append(self.op.t(feature_map_p))
stat.append(feature_map_p)
# batch, n_orders, channel
stat = self.op.stack(stat, 1)
return stat

Expand Down
16 changes: 9 additions & 7 deletions oodeel/methods/mahalanobis.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,30 +55,32 @@ def _fit_to_dataset(self, fit_dataset: DatasetType) -> None:
fit_dataset (Union[TensorType, DatasetType]): input dataset (ID)
"""
# extract features and labels
features, infos = self.feature_extractor.predict(fit_dataset)
features, infos = self.feature_extractor.predict(fit_dataset, detach=True)
labels = infos["labels"]

# unique sorted classes
self._classes = np.sort(np.unique(self.op.convert_to_numpy(labels)))

# compute mus and covs
mus = dict()
covs = dict()
mean_cov = None
for cls in self._classes:
indexes = self.op.equal(labels, cls)
_features_cls = self.op.flatten(features[0][indexes])
mus[cls] = self.op.mean(_features_cls, dim=0)
_zero_f_cls = _features_cls - mus[cls]
covs[cls] = (
cov_cls = (
self.op.matmul(self.op.t(_zero_f_cls), _zero_f_cls)
/ _zero_f_cls.shape[0]
)
if mean_cov is None:
mean_cov = (len(_features_cls) / len(features)) * cov_cls
else:
mean_cov += (len(_features_cls) / len(features)) * cov_cls

# mean cov and its inverse
mean_cov = self.op.mean(self.op.stack(list(covs.values())), dim=0)

self._mus = mus
# pseudo-inverse of the mean covariance matrix
self._pinv_cov = self.op.pinv(mean_cov)
self._mus = mus

def _score_tensor(self, inputs: TensorType) -> Tuple[np.ndarray]:
"""
Expand Down
Loading