Skip to content

Commit

Permalink
Merge pull request #89 from deel-ai/rmds
Browse files Browse the repository at this point in the history
RMDS (Relative Mahalanobis Distance), a new OOD detector
  • Loading branch information
cofri authored Apr 22, 2024
2 parents 1110a01 + 3cb6de7 commit e3784c3
Show file tree
Hide file tree
Showing 10 changed files with 1,305 additions and 18 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ __pycache__
oodeel_dev_env
oodeel_env*
*_env
.venv
._tf
.venv_tf

# Files generated:
logs
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ Currently, **oodeel** includes the following baselines:
| NMD | [Neural Mean Discrepancy for Efficient Out-of-Distribution Detection](https://openaccess.thecvf.com/content/CVPR2022/html/Dong_Neural_Mean_Discrepancy_for_Efficient_Out-of-Distribution_Detection_CVPR_2022_paper.html) | CVPR 2022 | planned |
| Gram | [Detecting Out-of-Distribution Examples with Gram Matrices](https://proceedings.mlr.press/v119/sastry20a.html) | ICML 2020 | avail [tensorflow](docs/notebooks/tensorflow/demo_gram_tf.ipynb) or [torch](docs/notebooks/torch/demo_gram_torch.ipynb) |
| GEN | [GEN: Pushing the Limits of Softmax-Based Out-of-Distribution Detection](https://openaccess.thecvf.com/content/CVPR2023/html/Liu_GEN_Pushing_the_Limits_of_Softmax-Based_Out-of-Distribution_Detection_CVPR_2023_paper.html) | CVPR 2023 | avail [tensorflow](docs/notebooks/tensorflow/demo_gen_tf.ipynb) or [torch](docs/notebooks/torch/demo_gen_torch.ipynb) |
| RMDS | [A Simple Fix to Mahalanobis Distance for Improving Near-OOD Detection](https://arxiv.org/abs/2106.09022) | preprint | avail [tensorflow](docs/notebooks/tensorflow/demo_rmds_tf.ipynb) or [torch](docs/notebooks/torch/demo_rmds_torch.ipynb) |



Expand Down
493 changes: 493 additions & 0 deletions docs/notebooks/tensorflow/demo_rmds_tf.ipynb

Large diffs are not rendered by default.

98 changes: 80 additions & 18 deletions docs/notebooks/torch/demo_mahalanobis_torch.ipynb

Large diffs are not rendered by default.

517 changes: 517 additions & 0 deletions docs/notebooks/torch/demo_rmds_torch.ipynb

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ nav:
- React: notebooks/tensorflow/demo_react_tf.ipynb
- Gram: notebooks/tensorflow/demo_gram_tf.ipynb
- GEN: notebooks/tensorflow/demo_gen_tf.ipynb
- RMDS: notebooks/tensorflow/demo_rmds_tf.ipynb
- OOD Baselines (Torch):
- MLS/MSP: notebooks/torch/demo_mls_msp_torch.ipynb
- ODIN: notebooks/torch/demo_odin_torch.ipynb
Expand All @@ -28,6 +29,7 @@ nav:
- React: notebooks/torch/demo_react_torch.ipynb
- Gram: notebooks/torch/demo_gram_torch.ipynb
- GEN: notebooks/torch/demo_gen_torch.ipynb
- RMDS: notebooks/torch/demo_rmds_torch.ipynb
- Advanced Topics:
- Seamlessly handling torch and tf datasets with DataHandler: pages/datahandler_tuto.md
- Seamlessly handling torch and tf Tensors with Operator: pages/operator_tuto.md
Expand Down
2 changes: 2 additions & 0 deletions oodeel/methods/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from .mahalanobis import Mahalanobis
from .mls import MLS
from .odin import ODIN
from .rmds import RMDS
from .vim import VIM

__all__ = [
Expand All @@ -39,5 +40,6 @@
"Mahalanobis",
"MLS",
"ODIN",
"RMDS",
"VIM",
]
122 changes: 122 additions & 0 deletions oodeel/methods/rmds.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# -*- coding: utf-8 -*-
# Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
# rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
# CRIAQ and ANITI - https://www.deel.ai/
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import numpy as np

from ..types import DatasetType
from ..types import TensorType
from ..types import Tuple
from oodeel.methods.mahalanobis import Mahalanobis


class RMDS(Mahalanobis):
"""
"A Simple Fix to Mahalanobis Distance for Improving Near-OOD Detection"
https://arxiv.org/abs/2106.09022
Args:
eps (float): magnitude for gradient based input perturbation.
Defaults to 0.02.
"""

def __init__(self, eps: float = 0.002):
super().__init__(eps=eps)

def _fit_to_dataset(self, fit_dataset: DatasetType) -> None:
"""
Constructs the per class means and the covariance matrix,
as well as the background mean and covariance matrix,
from ID data "fit_dataset".
The means and pseudo-inverses of the covariance matrices
will be used for RMDS score computation.
Args:
fit_dataset (Union[TensorType, DatasetType]): input dataset (ID)
"""
# means and pseudo-inverse of the mean convariance matrix from Mahalanobis
super()._fit_to_dataset(fit_dataset)

# extract features
features, _ = self.feature_extractor.predict(fit_dataset)

# compute background mu and cov
_features_bg = self.op.flatten(features[0])
mu_bg = self.op.mean(_features_bg, dim=0)
_zero_f_bg = _features_bg - mu_bg
cov_bg = self.op.matmul(self.op.t(_zero_f_bg), _zero_f_bg) / _zero_f_bg.shape[0]

# background mu and pseudo-inverse of the mean covariance matrices
self._mu_bg = mu_bg
self._pinv_cov_bg = self.op.pinv(cov_bg)

def _score_tensor(self, inputs: TensorType) -> Tuple[np.ndarray]:
"""
Computes an OOD score for input samples "inputs" based on the RMDS
distance with respect to the closest class-conditional Gaussian distribution,
and the background distribution.
Args:
inputs (TensorType): input samples
Returns:
Tuple[np.ndarray]: scores, logits
"""
# input preprocessing (perturbation)
if self.eps > 0:
inputs_p = self._input_perturbation(inputs)
else:
inputs_p = inputs

# mahalanobis score on perturbed inputs
features_p, _ = self.feature_extractor.predict_tensor(inputs_p)
features_p = self.op.flatten(features_p[0])
gaussian_score_p = self._mahalanobis_score(features_p)

# background score on perturbed inputs
gaussian_score_bg = self._background_score(features_p)

# take the highest score for each sample
gaussian_score_corrected = self.op.max(
gaussian_score_bg - gaussian_score_p, dim=1
)
return -self.op.convert_to_numpy(gaussian_score_corrected)

def _background_score(self, out_features: TensorType) -> TensorType:
"""
Mahalanobis distance-based background score. For each test sample, it computes
the log of the probability densities of some observations (assuming a
normal distribution) using the mahalanobis distance with respect to the
background distribution.
Args:
out_features (TensorType): test samples features
Returns:
TensorType: confidence scores (with respect to the background distribution)
"""
zero_f = out_features - self._mu_bg
# gaussian log prob density (mahalanobis)
log_probs_f = -0.5 * self.op.diag(
self.op.matmul(self.op.matmul(zero_f, self._pinv_cov_bg), self.op.t(zero_f))
)
gaussian_score = self.op.reshape(log_probs_f, (-1, 1))
return gaussian_score
43 changes: 43 additions & 0 deletions tests/tests_tensorflow/methods/test_tf_rmds.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# -*- coding: utf-8 -*-
# Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
# rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
# CRIAQ and ANITI - https://www.deel.ai/
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import pytest

from oodeel.methods import RMDS
from tests.tests_tensorflow import eval_detector_on_blobs


@pytest.mark.parametrize("auroc_thr,fpr95_thr", [(0.95, 0.05)])
def test_rmds(auroc_thr, fpr95_thr):
"""
Test RMDS on toy blobs OOD dataset-wise task
We check that the area under ROC is above a certain threshold, and that the FPR95TPR
is below an other threshold.
"""
rmds = RMDS()
eval_detector_on_blobs(
detector=rmds,
auroc_thr=auroc_thr,
fpr95_thr=fpr95_thr,
batch_size=64,
)
42 changes: 42 additions & 0 deletions tests/tests_torch/methods/test_torch_rmds.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# -*- coding: utf-8 -*-
# Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
# rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
# CRIAQ and ANITI - https://www.deel.ai/
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import pytest

from oodeel.methods import RMDS
from tests.tests_torch import eval_detector_on_blobs


@pytest.mark.parametrize("auroc_thr,fpr95_thr", [(0.95, 0.05)])
def test_rmds(auroc_thr, fpr95_thr):
"""
Test RMDS on toy blobs OOD dataset-wise task
We check that the area under ROC is above a certain threshold, and that the FPR95TPR
is below an other threshold.
"""
rmds = RMDS()
eval_detector_on_blobs(
detector=rmds,
auroc_thr=auroc_thr,
fpr95_thr=fpr95_thr,
)

0 comments on commit e3784c3

Please sign in to comment.