Skip to content

Commit

Permalink
Merge pull request #93 from ENSTA-U2IS-AI/dev
Browse files Browse the repository at this point in the history
✨ Add LPBNN, Adaptive ECE, start supporting Depth estimation & Improve segmentation
  • Loading branch information
alafage authored May 29, 2024
2 parents 42fa423 + d22221c commit 7d7aec8
Show file tree
Hide file tree
Showing 154 changed files with 5,934 additions and 2,667 deletions.
28 changes: 11 additions & 17 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ _TorchUncertainty_ is a package designed to help you leverage [uncertainty quant

:construction: _TorchUncertainty_ is in early development :construction: - expect changes, but reach out and contribute if you are interested in the project! **Please raise an issue if you have any bugs or difficulties and join the [discord server](https://discord.gg/HMCawt5MJu).**

Our webpage and documentation is available here: [torch-uncertainty.github.io](https://torch-uncertainty.github.io).
:books: Our webpage and documentation is available here: [torch-uncertainty.github.io](https://torch-uncertainty.github.io). :books:

TorchUncertainty contains the *official implementations* of multiple papers from *major machine-learning and computer vision conferences* and was/will be featured in tutorials at **WACV 2024** and **ECCV 2024**.

---

Expand Down Expand Up @@ -47,7 +49,14 @@ We make a quickstart available at [torch-uncertainty.github.io/quickstart](https

## :books: Implemented methods

TorchUncertainty currently supports **Classification**, **probabilistic** and pointwise **Regression** and **Segmentation**.
TorchUncertainty currently supports **classification**, **probabilistic** and pointwise **regression**, **segmentation** and **pixelwise regression** (such as monocular depth estimation). It includes the official codes of the following papers:

- *A Symmetry-Aware Exploration of Bayesian Neural Network Posteriors* - [ICLR 2024](https://arxiv.org/abs/2310.08287)
- *LP-BNN: Encoding the latent posterior of Bayesian Neural Networks for uncertainty quantification* - [IEEE TPAMI](https://arxiv.org/abs/2012.02818)
- *Packed-Ensembles for Efficient Uncertainty Estimation* - [ICLR 2023](https://arxiv.org/abs/2210.09184) - [Tutorial](https://torch-uncertainty.github.io/auto_tutorials/tutorial_pe_cifar10.html)
- *MUAD: Multiple Uncertainties for Autonomous Driving, a benchmark for multiple uncertainty types and tasks* - [BMVC 2022](https://arxiv.org/abs/2203.01437)

We also provide the following methods:

### Baselines

Expand Down Expand Up @@ -86,18 +95,3 @@ Our documentation contains the following tutorials:
- [Deep Evidential Regression on a Toy Example](https://torch-uncertainty.github.io/auto_tutorials/tutorial_der_cubic.html)
- [Training a LeNet with Monte-Carlo Dropout](https://torch-uncertainty.github.io/auto_tutorials/tutorial_mc_dropout.html)
- [Training a LeNet with Deep Evidential Classification](https://torch-uncertainty.github.io/auto_tutorials/tutorial_evidential_classification.html)

## Other References

This package also contains the official implementation of Packed-Ensembles.

If you find the corresponding models interesting, please consider citing our [paper](https://arxiv.org/abs/2210.09184):

```text
@inproceedings{laurent2023packed,
title={Packed-Ensembles for Efficient Uncertainty Estimation},
author={Laurent, Olivier and Lafage, Adrien and Tartaglione, Enzo and Daniel, Geoffrey and Martinez, Jean-Marc and Bursuc, Andrei and Franchi, Gianni},
booktitle={ICLR},
year={2023}
}
```
2 changes: 0 additions & 2 deletions auto_tutorials_source/tutorial_corruptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,9 @@
torchvision and matplotlib.
"""

import torch
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, ToTensor, Resize

from torchvision.utils import make_grid
import matplotlib.pyplot as plt

ds = CIFAR10("./data", train=False, download=True)
Expand Down
18 changes: 10 additions & 8 deletions auto_tutorials_source/tutorial_mc_batch_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,11 @@
# %%
# 4. Gathering Everything and Training the Model
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# You can also save the results in a variable by saving the output of
# `trainer.test`.

trainer.fit(model=routine, datamodule=datamodule)
trainer.test(model=routine, datamodule=datamodule)
trainer.test(model=routine, datamodule=datamodule);

# %%
# 5. Wrapping the Model in a MCBatchNorm
Expand All @@ -88,10 +90,10 @@
# to highlight the effect of stochasticity on the predictions.

routine.model = MCBatchNorm(
routine.model, num_estimators=8, convert=True, mc_batch_size=4
routine.model, num_estimators=8, convert=True, mc_batch_size=16
)
routine.model.fit(datamodule.train)
routine.eval()
routine.eval();

# %%
# 6. Testing the Model
Expand All @@ -118,17 +120,17 @@ def imshow(img):
dataiter = iter(datamodule.val_dataloader())
images, labels = next(dataiter)

# print images
imshow(torchvision.utils.make_grid(images[:4, ...]))
print("Ground truth: ", " ".join(f"{labels[j]}" for j in range(4)))

routine.eval()
logits = routine(images).reshape(8, 128, 10)

probs = torch.nn.functional.softmax(logits, dim=-1)
most_uncertain = sorted(probs.var(0).sum(-1).topk(4).indices)

# print images
imshow(torchvision.utils.make_grid(images[most_uncertain, ...]))
print("Ground truth: ", " ".join(f"{labels[j]}" for j in range(4)))

for j in sorted(probs.var(0).sum(-1).topk(4).indices):
for j in most_uncertain:
values, predicted = torch.max(probs[:, j], 1)
print(
f"Predicted digits for the image {j}: ",
Expand Down
12 changes: 5 additions & 7 deletions auto_tutorials_source/tutorial_scaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,17 @@
In this tutorial, we will need:
- torch for its objects
- the "calibration error" metric to compute the ECE and evaluate the top-label calibration
- TorchUncertainty's Calibration Error metric to compute to evaluate the top-label calibration with ECE and plot the reliability diagrams
- the CIFAR-100 datamodule to handle the data
- a ResNet 18 as starting model
- the temperature scaler to improve the top-label calibration
- a utility to download hf models easily
- the calibration plot to visualize the calibration.
- a utility function to download HF models easily
If you use the classification routine, the plots will be automatically available in the tensorboard logs.
If you use the classification routine, the plots will be automatically available in the tensorboard logs if you use the `log_plots` flag.
"""

from torch_uncertainty.datamodules import CIFAR100DataModule
from torch_uncertainty.metrics import CE
from torch_uncertainty.metrics import CalibrationError
from torch_uncertainty.models.resnet import resnet18
from torch_uncertainty.post_processing import TemperatureScaler
from torch_uncertainty.utils import load_hf
Expand Down Expand Up @@ -88,7 +86,7 @@
test_dataloader = DataLoader(test_dataset, batch_size=32)

# Initialize the ECE
ece = CE(task="multiclass", num_classes=100)
ece = CalibrationError(task="multiclass", num_classes=100)

# Iterate on the calibration dataloader
for sample, target in test_dataloader:
Expand Down
33 changes: 30 additions & 3 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ API Reference
Routines
--------

The routine are the main building blocks of the library. They define the framework
The routine are the main building blocks of the library. They define the framework
in which the models are trained and evaluated. They allow for easy computation of different
metrics crucial for uncertainty estimation in different contexts, namely classification, regression and segmentation.

Expand Down Expand Up @@ -42,10 +42,20 @@ Segmentation

SegmentationRoutine

Pixelwise Regression
^^^^^^^^^^^^^^^^^^^^

.. autosummary::
:toctree: generated/
:nosignatures:
:template: class.rst

PixelRegressionRoutine

Baselines
---------

TorchUncertainty provide lightning-based models that can be easily trained and evaluated.
TorchUncertainty provide lightning-based models that can be easily trained and evaluated.
These models inherit from the routines and are specifically designed to benchmark
different methods in similar settings, here with constant architectures.

Expand Down Expand Up @@ -85,8 +95,19 @@ Segmentation
:nosignatures:
:template: class.rst

DeepLabBaseline
SegFormerBaseline

Monocular Depth Estimation
^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autosummary::
:toctree: generated/
:nosignatures:
:template: class.rst

BTSBaseline

Layers
------

Expand Down Expand Up @@ -122,6 +143,8 @@ Bayesian layers
BayesConv1d
BayesConv2d
BayesConv3d
LPBNNLinear
LPBNNConv2d

Models
------
Expand Down Expand Up @@ -158,9 +181,12 @@ Metrics
:template: class.rst

AUSE
AURC
AdaptiveCalibrationError
BrierScore
CategoricalNLL
CE
CalibrationError
CovAt5Risk,
Disagreement
DistributionNLL
Entropy
Expand All @@ -169,6 +195,7 @@ Metrics
MeanGTRelativeAbsoluteError
MeanGTRelativeSquaredError
MutualInformation
RiskAt80Cov,
SILog
ThresholdAccuracy

Expand Down
9 changes: 4 additions & 5 deletions docs/source/contributing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ The scope of TorchUncertainty
TorchUncertainty can host any method - if possible linked to a paper - and
roughly contained in the following fields:

* Uncertainty quantification in general, including Bayesian deep learning,
Monte Carlo dropout, ensemble methods, etc.
* Uncertainty quantification in general, including Bayesian deep learning, Monte Carlo dropout, ensemble methods, etc.
* Out-of-distribution detection methods
* Applications (e.g. object detection, segmentation, etc.)

Expand Down Expand Up @@ -54,7 +53,7 @@ group:
Then navigate to ``./docs`` and build the documentation with:

.. parsed-literal::
make html
Optionally, specify ``html-noplot`` instead of ``html`` to avoid running the tutorials.
Expand All @@ -73,7 +72,7 @@ PR. This will avoid multiplying the number featureless commits. To do this,
run, at the root of the folder:

.. parsed-literal::
python3 -m pytest tests
Try to include an emoji at the start of each commit message following the suggestions
Expand Down Expand Up @@ -118,4 +117,4 @@ License

If you feel that the current license is an obstacle to your contribution, let
us know, and we may reconsider. However, the models’ weights hosted on Hugging
Face are likely to stay Apache 2.0.
Face are likely to remain Apache 2.0.
36 changes: 30 additions & 6 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ Welcome to Torch Uncertainty

Welcome to the documentation of TorchUncertainty.

This website contains the documentation for
This website contains the documentation for
`installing <https://torch-uncertainty.github.io/installation.html>`_
and `contributing <https://torch-uncertainty.github.io/>`_ to TorchUncertainty,
details on the `API <https://torch-uncertainty.github.io/api.html>`_, and a
and `contributing <https://torch-uncertainty.github.io/>`_ to TorchUncertainty,
details on the `API <https://torch-uncertainty.github.io/api.html>`_, and a
`comprehensive list of the references <https://torch-uncertainty.github.io/references.html>`_ of
the models and metrics implemented.

Expand All @@ -29,12 +29,36 @@ Installation
To install TorchUncertainty with contribution in mind, check the
`contribution page <https://torch-uncertainty.github.io/contributing.html>`_.

-----

Official Implementations
^^^^^^^^^^^^^^^^^^^^^^^^

TorchUncertainty also houses multiple official implementations of papers from major conferences & journals.

**A Symmetry-Aware Exploration of Bayesian Neural Network Posteriors**

* Authors: *Olivier Laurent, Emanuel Aldea, and Gianni Franchi*
* Paper: `ICLR 2024 <https://arxiv.org/abs/2310.08287>`_.

**Encoding the latent posterior of Bayesian Neural Networks for uncertainty quantification**

* Authors: *Gianni Franchi, Andrei Bursuc, Emanuel Aldea, Severine Dubuisson, and Isabelle Bloch*
* Paper: `IEEE TPAMI <https://arxiv.org/abs/2012.02818>`_.

**Packed-Ensembles for Efficient Uncertainty Estimation**

* Authors: *Olivier Laurent, Adrien Lafage, Enzo Tartaglione, Geoffrey Daniel, Jean-Marc Martinez, Andrei Bursuc, and Gianni Franchi*
* Paper: `ICLR 2023 <https://arxiv.org/abs/2210.09184>`_.

**MUAD: Multiple Uncertainties for Autonomous Driving, a benchmark for multiple uncertainty types and tasks**

* Authors: *Gianni Franchi, Xuanlong Yu, Andrei Bursuc, Angel Tena, Rémi Kazmierczak, Séverine Dubuisson, Emanuel Aldea, David Filliat*
* Paper: `BMVC 2022 <https://arxiv.org/abs/2203.01437>`_.

Packed-Ensembles
^^^^^^^^^^^^^^^^

Finally, TorchUncertainty also includes the official PyTorch implementation for
the following paper:

**Packed-Ensembles for Efficient Uncertainty Estimation**

* Authors: *Olivier Laurent, Adrien Lafage, Enzo Tartaglione, Geoffrey Daniel, Jean-Marc Martinez, Andrei Bursuc, and Gianni Franchi*
Expand Down
4 changes: 2 additions & 2 deletions docs/source/installation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ From PyPI
---------

Check that you have PyTorch (cpu or gpu) installed on your system. Then, install
the package via pip:
the package via pip:

.. parsed-literal::
Expand All @@ -24,7 +24,7 @@ To update the package, run:

.. parsed-literal::
pip install -U torch-uncertainty
pip install -U torch-uncertainty
From source
-----------
Expand Down
Loading

0 comments on commit 7d7aec8

Please sign in to comment.