Skip to content

Commit

Permalink
📖 Refine documentation
Browse files Browse the repository at this point in the history
- Update Tutorials
Co-authored-by: Olivier Laurent <[email protected]>
  • Loading branch information
alafage committed Apr 2, 2024
1 parent b13fb0f commit 9547fac
Show file tree
Hide file tree
Showing 9 changed files with 222 additions and 121 deletions.
165 changes: 84 additions & 81 deletions auto_tutorials_source/tutorial_corruptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
torch_uncertainty.transforms.corruptions. We also need to load utilities from
torchvision and matplotlib.
"""

import torch
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, ToTensor, Resize
Expand All @@ -20,95 +21,97 @@

ds = CIFAR10("./data", train=False, download=True)

def get_images(main_transform, severity):
ds_transforms = Compose([ToTensor(), main_transform(severity), Resize(256)])
ds = CIFAR10("./data", train=False, download=False, transform=ds_transforms)
return make_grid([ds[i][0] for i in range(6)]).permute(1, 2, 0)

def show_images(transform):
print("Original Images")
with torch.no_grad():
plt.axis('off')
plt.imshow(get_images(transform, 0))
plt.show()

for severity in range(1, 6):
print(f"Severity {severity}")
with torch.no_grad():
plt.axis('off')
plt.imshow(get_images(transform, severity))
plt.show()

# %%
# 1. Gaussian Noise
# ~~~~~~~~~~~~~~~~~
from torch_uncertainty.transforms.corruptions import GaussianNoise
def get_images(main_corruption, index: int = 0):
"""Create an image showing the 6 levels of corruption of a given transform."""
images = []
for severity in range(6):
ds_transforms = Compose(
[ToTensor(), main_corruption(severity), Resize(256, antialias=True)]
)
ds = CIFAR10("./data", train=False, download=False, transform=ds_transforms)
images.append(ds[index][0].permute(1, 2, 0).numpy())
return images


def show_images(transforms):
"""Show the effect of all given transforms."""
num_corruptions = len(transforms)
_, ax = plt.subplots(num_corruptions, 6, figsize=(10, int(1.5 * num_corruptions)))
for i, transform in enumerate(transforms):
images = get_images(transform, index=i)
ax[i][0].text(
-0.1,
0.5,
transform.__name__,
transform=ax[i][0].transAxes,
rotation="vertical",
horizontalalignment="right",
verticalalignment="center",
fontsize=12,
)
for j in range(6):
ax[i][j].imshow(images[j])
if i == 0 and j == 0:
ax[i][j].set_title("Original")
elif i == 0:
ax[i][j].set_title(f"Severity {j}")
ax[i][j].axis("off")
plt.show()

show_images(GaussianNoise)

# %%
# 2. Shot Noise
# ~~~~~~~~~~~~~
from torch_uncertainty.transforms.corruptions import ShotNoise

show_images(ShotNoise)
# 1. Noise Corruptions
# ~~~~~~~~~~~~~~~~~~~~
from torch_uncertainty.transforms.corruptions import (
GaussianNoise,
ShotNoise,
ImpulseNoise,
SpeckleNoise,
)

show_images(
[
GaussianNoise,
ShotNoise,
ImpulseNoise,
SpeckleNoise,
]
)

# %%
# 3. Impulse Noise
# ~~~~~~~~~~~~~~~~
from torch_uncertainty.transforms.corruptions import ImpulseNoise

show_images(ImpulseNoise)
# 2. Blur Corruptions
# ~~~~~~~~~~~~~~~~~~~~
from torch_uncertainty.transforms.corruptions import (
GaussianBlur,
GlassBlur,
DefocusBlur,
)

show_images(
[
GaussianBlur,
GlassBlur,
DefocusBlur,
]
)

# %%
# 4. Speckle Noise
# ~~~~~~~~~~~~~~~~
from torch_uncertainty.transforms.corruptions import SpeckleNoise

show_images(SpeckleNoise)

# %%
# 5. Gaussian Blur
# ~~~~~~~~~~~~~~~~
from torch_uncertainty.transforms.corruptions import GaussianBlur

show_images(GaussianBlur)

# %%
# 6. Glass Blur
# ~~~~~~~~~~~~~
from torch_uncertainty.transforms.corruptions import GlassBlur

show_images(GlassBlur)

# %%
# 7. Defocus Blur
# ~~~~~~~~~~~~~~~

from torch_uncertainty.transforms.corruptions import DefocusBlur

show_images(DefocusBlur)

#%%
# 8. JPEG Compression
# ~~~~~~~~~~~~~~~~~~~
from torch_uncertainty.transforms.corruptions import JPEGCompression

show_images(JPEGCompression)

#%%
# 9. Pixelate
# ~~~~~~~~~~~
from torch_uncertainty.transforms.corruptions import Pixelate

show_images(Pixelate)

#%%
# 10. Frost
# ~~~~~~~~~
from torch_uncertainty.transforms.corruptions import Frost

show_images(Frost)
# 3. Other Corruptions
# ~~~~~~~~~~~~~~~~~~~~
from torch_uncertainty.transforms.corruptions import (
JPEGCompression,
Pixelate,
Frost,
)

show_images(
[
JPEGCompression,
Pixelate,
Frost,
]
)

# %%
# Reference
Expand Down
6 changes: 3 additions & 3 deletions auto_tutorials_source/tutorial_mc_batch_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
First, we have to load the following utilities from TorchUncertainty:
- the Trainer from Lightning
- the datamodule that handles dataloaders: MNISTDataModule, which lies in the torch_uncertainty.datamodule
- the datamodule handling dataloaders: MNISTDataModule from torch_uncertainty.datamodules
- the model: LeNet, which lies in torch_uncertainty.models
- the mc-batch-norm wrapper: mc_dropout, which lies in torch_uncertainty.models
- the classification training routine in the torch_uncertainty.training.classification module
- the MC Batch Normalization wrapper: mc_batch_norm, which lies in torch_uncertainty.post_processing
- the classification training routine in the torch_uncertainty.routines
- an optimization recipe in the torch_uncertainty.optim_recipes module.
We also need import the neural network utils within `torch.nn`.
Expand Down
23 changes: 12 additions & 11 deletions auto_tutorials_source/tutorial_mc_dropout.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
For more information on Monte-Carlo Dropout, we refer the reader to the following resources:
- Dropout as a Bayesian Approximation: Representing Model Uncertainty in Deep Learning `ICML 2016 <https://browse.arxiv.org/pdf/1506.02142.pdf>`_
- What Uncertainties Do We Need in Bayesian Deep Learning for Computer Vision? `NeurIPS 2017 <https://browse.arxiv.org/pdf/1703.04977.pdf>`_
- Dropout as a Bayesian Approximation: Representing Model Uncertainty in Deep Learning `PMLR 2016 <https://browse.arxiv.org/pdf/1506.02142.pdf>`_
Training a LeNet with MC Dropout using TorchUncertainty models and PyTorch Lightning
-------------------------------------------------------------------------------------
Expand All @@ -20,14 +20,15 @@
First, we have to load the following utilities from TorchUncertainty:
- the Trainer from Lightning
- the datamodule that handles dataloaders: MNISTDataModule, which lies in the torch_uncertainty.datamodule
- the datamodule handling dataloaders: MNISTDataModule from torch_uncertainty.datamodules
- the model: LeNet, which lies in torch_uncertainty.models
- the mc-dropout wrapper: mc_dropout, which lies in torch_uncertainty.models
- the classification training routine in the torch_uncertainty.training.classification module
- the MC Dropout wrapper: mc_dropout, which lies in torch_uncertainty.models
- the classification training routine in the torch_uncertainty.routines
- an optimization recipe in the torch_uncertainty.optim_recipes module.
We also need import the neural network utils within `torch.nn`.
"""

# %%
from pathlib import Path

Expand All @@ -48,22 +49,22 @@
# logs, and to fake-parse the arguments needed for using the PyTorch Lightning
# Trainer. We also create the datamodule that handles the MNIST dataset,
# dataloaders and transforms. We create the model using the
# blueprint from torch_uncertainty.models and we wrap it into mc-dropout.
# blueprint from torch_uncertainty.models and we wrap it into mc_dropout.
#
# It is important to specify the arguments ``version`` as ``mc-dropout``,
# ``num_estimators`` and the ``dropout_rate`` to use Monte Carlo dropout.
# It is important to specify the arguments,``num_estimators`` and the ``dropout_rate``
# to use Monte Carlo dropout.

trainer = Trainer(accelerator="cpu", max_epochs=2, enable_progress_bar=False)

# datamodule
root = Path("") / "data"
root = Path("") / "data"
datamodule = MNISTDataModule(root=root, batch_size=128)


model = lenet(
in_channels=datamodule.num_channels,
num_classes=datamodule.num_classes,
dropout_rate=0.6,
dropout_rate=0.4,
)

mc_model = mc_dropout(model, num_estimators=16, last_layer=False)
Expand All @@ -84,7 +85,6 @@
loss=nn.CrossEntropyLoss(),
optim_recipe=optim_cifar10_resnet18(mc_model),
num_estimators=16,

)

# %%
Expand Down Expand Up @@ -134,5 +134,6 @@ def imshow(img):
" ".join([str(image_id.item()) for image_id in predicted]),
)

# %% We see that there is some disagreement between the samples of the dropout
# %%
# We see that there is some disagreement between the samples of the dropout
# approximation of the posterior distribution.
11 changes: 8 additions & 3 deletions auto_tutorials_source/tutorial_pe_cifar10.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@
.. figure:: /_static/img/cifar10.png
:alt: cifar10
:figclass: figure-caption
cifar10
Sample of the CIFAR-10 dataset
Training an image Packed-Ensemble classifier
--------------------------------------------
Expand All @@ -39,6 +40,8 @@
1. Load and normalize CIFAR10
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
"""

# %%
import torch
import torchvision
import torchvision.transforms as transforms
Expand Down Expand Up @@ -103,6 +106,8 @@
def imshow(img):
img = img / 2 + 0.5 # unnormalize
npimg = img.numpy()
plt.figure(figsize=(10, 3))
plt.axis("off")
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show()

Expand All @@ -112,7 +117,7 @@ def imshow(img):
images, labels = next(dataiter)

# show images
imshow(torchvision.utils.make_grid(images))
imshow(torchvision.utils.make_grid(images, pad_value=1))
# print labels
print(" ".join(f"{classes[labels[j]]:5s}" for j in range(batch_size)))

Expand Down Expand Up @@ -244,7 +249,7 @@ def forward(self, x):
images, labels = next(dataiter)

# print images
imshow(torchvision.utils.make_grid(images))
imshow(torchvision.utils.make_grid(images, pad_value=1))
print(
"GroundTruth: ",
" ".join(f"{classes[labels[j]]:5s}" for j in range(batch_size)),
Expand Down
Loading

0 comments on commit 9547fac

Please sign in to comment.