
Commit

Merge pull request #115 from ENSTA-U2IS-AI/dev
🐛 Fix OOD & Post Processing at the same time & other small changes
o-laurent authored Sep 29, 2024
2 parents 58e5ad1 + dccedc7 commit 7df97be
Showing 16 changed files with 124 additions and 49 deletions.
4 changes: 3 additions & 1 deletion .gitignore
@@ -9,7 +9,9 @@ docs/*/auto_tutorials/
*.ckpt
*.out
docs/source/sg_execution_times.rst
-test**/*.csv
+test
+**/*.csv
+pyrightconfig.json

# Byte-compiled / optimized / DLL files
__pycache__/
2 changes: 1 addition & 1 deletion docs/source/conf.py
@@ -15,7 +15,7 @@
f"{datetime.now().year!s}, Adrien Lafage and Olivier Laurent"
)
author = "Adrien Lafage and Olivier Laurent"
release = "0.2.2.post0"
release = "0.2.2.post1"

# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi"

[project]
name = "torch_uncertainty"
version = "0.2.2.post0"
version = "0.2.2.post1"
authors = [
{ name = "ENSTA U2IS", email = "[email protected]" },
{ name = "Adrien Lafage", email = "[email protected]" },
4 changes: 0 additions & 4 deletions torch_uncertainty/metrics/classification/risk_coverage.py
@@ -147,9 +147,7 @@ def plot(
ax.set_ylabel("Risk - Error Rate (%)", fontsize=16)
ax.set_xlim(0, 100)
ax.set_ylim(0, min(100, np.ceil(error_rates.max() * 100)))
-ax.set_aspect("equal", "box")
ax.legend(loc="upper right")
-fig.tight_layout()
return fig, ax


@@ -270,9 +268,7 @@ def plot(
ax.set_ylabel("Generalized Risk (%)", fontsize=16)
ax.set_xlim(0, 100)
ax.set_ylim(0, min(100, np.ceil(error_rates.max() * 100)))
-ax.set_aspect("equal", "box")
ax.legend(loc="upper right")
-fig.tight_layout()
return fig, ax


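The two deletions above drop the forced square aspect ratio and the automatic tight_layout() call from both risk-coverage plot() methods, leaving figure styling to the caller. A minimal sketch of the downstream effect, assuming the AURC metric and import path (only the plot() bodies are visible in this diff):

```python
# Hedged sketch: after this change, the caller styles the returned figure.
import matplotlib.pyplot as plt
import torch
from torch_uncertainty.metrics.classification import AURC  # assumed import path

metric = AURC()
probs = torch.softmax(torch.randn(64, 10), dim=-1)  # placeholder confidence scores
targets = torch.randint(0, 10, (64,))               # placeholder labels
metric.update(probs, targets)

fig, ax = metric.plot()
ax.set_aspect("equal", "box")  # opt back into the old square look if desired
fig.tight_layout()             # layout is now the caller's responsibility
plt.show()
```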
4 changes: 2 additions & 2 deletions torch_uncertainty/models/resnet/batched.py
@@ -248,7 +248,7 @@ def __init__(
self.layer4 = nn.Identity()
linear_multiplier = 4

-self.dropout = nn.Dropout(p=dropout_rate)
+self.final_dropout = nn.Dropout(p=dropout_rate)
self.pool = nn.AdaptiveAvgPool2d(output_size=1)
self.flatten = nn.Flatten(1)

@@ -297,7 +297,7 @@ def forward(self, x: Tensor) -> Tensor:
out = self.layer3(out)
out = self.layer4(out)
out = self.pool(out)
-out = self.dropout(self.flatten(out))
+out = self.final_dropout(self.flatten(out))
return self.linear(out)


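The dropout to final_dropout rename above recurs in every backbone below (lpbnn, masked, packed, std, and both WideResNets). The commit message does not state the motivation, but given the "Fix OOD & Post Processing" title, a plausible reading is that the classifier-head dropout needed a name distinct from the block-level Dropout2d modules so the two can be told apart by module name. A hypothetical illustration of such name-based targeting, not code from torch_uncertainty itself:

```python
# Hypothetical sketch: target only the renamed head dropout by module name.
import torch.nn as nn

def keep_final_dropout_active(model: nn.Module) -> None:
    """After model.eval(), re-enable only the classifier-head dropout,
    leaving the block-level nn.Dropout2d modules deterministic."""
    for name, module in model.named_modules():
        if isinstance(module, nn.Dropout) and name.endswith("final_dropout"):
            module.train()  # dropout stays stochastic for MC-style sampling
```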
4 changes: 2 additions & 2 deletions torch_uncertainty/models/resnet/lpbnn.py
@@ -258,7 +258,7 @@ def __init__(
self.layer4 = nn.Identity()
linear_multiplier = 4

-self.dropout = nn.Dropout(p=dropout_rate)
+self.final_dropout = nn.Dropout(p=dropout_rate)
self.pool = nn.AdaptiveAvgPool2d(output_size=1)
self.flatten = nn.Flatten(1)

@@ -309,7 +309,7 @@ def feats_forward(self, x: Tensor) -> Tensor:
out = self.layer3(out)
out = self.layer4(out)
out = self.pool(out)
-return self.dropout(self.flatten(out))
+return self.final_dropout(self.flatten(out))

def forward(self, x: Tensor) -> Tensor:
return self.linear(self.feats_forward(x))
4 changes: 2 additions & 2 deletions torch_uncertainty/models/resnet/masked.py
@@ -262,7 +262,7 @@ def __init__(
self.layer4 = nn.Identity()
linear_multiplier = 4

-self.dropout = nn.Dropout(p=dropout_rate)
+self.final_dropout = nn.Dropout(p=dropout_rate)
self.pool = nn.AdaptiveAvgPool2d(output_size=1)
self.flatten = nn.Flatten(1)

@@ -315,7 +315,7 @@ def forward(self, x: Tensor) -> Tensor:
out = self.layer4(out)

out = self.pool(out)
-out = self.dropout(self.flatten(out))
+out = self.final_dropout(self.flatten(out))
return self.linear(out)


4 changes: 2 additions & 2 deletions torch_uncertainty/models/resnet/packed.py
@@ -315,7 +315,7 @@ def __init__(
self.layer4 = nn.Identity()
linear_multiplier = 4

-self.dropout = nn.Dropout(p=dropout_rate)
+self.final_dropout = nn.Dropout(p=dropout_rate)
self.pool = nn.AdaptiveAvgPool2d(output_size=1)
self.flatten = nn.Flatten(1)

@@ -374,7 +374,7 @@ def forward(self, x: Tensor) -> Tensor:
)

out = self.pool(out)
-out = self.dropout(self.flatten(out))
+out = self.final_dropout(self.flatten(out))
return self.linear(out)

def check_config(self, config: dict[str, Any]) -> bool:
5 changes: 3 additions & 2 deletions torch_uncertainty/models/resnet/std.py
@@ -293,7 +293,7 @@ def __init__(
self.layer4 = nn.Identity()
linear_multiplier = 4

-self.dropout = nn.Dropout(p=dropout_rate)
+self.final_dropout = nn.Dropout(p=dropout_rate)
self.pool = nn.AdaptiveAvgPool2d(output_size=1)
self.flatten = nn.Flatten(1)

@@ -340,7 +340,7 @@ def feats_forward(self, x: Tensor) -> Tensor:
out = self.layer3(out)
out = self.layer4(out)
out = self.pool(out)
-return self.dropout(self.flatten(out))
+return self.final_dropout(self.flatten(out))

def forward(self, x: Tensor) -> Tensor:
return self.linear(self.feats_forward(x))
@@ -374,6 +374,7 @@ def resnet(
activation_fn (Callable, optional): Activation function. Defaults to
``torch.nn.functional.relu``.
normalization_layer (nn.Module, optional): Normalization layer.
+Defaults to ``torch.nn.BatchNorm2d``.
Returns:
_ResNet: The ResNet model.
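With normalization_layer now documented on the resnet() factory, swapping the normalization scheme looks roughly like the sketch below. Only the docstring fragment is confirmed by this diff; the import path and the remaining factory arguments are assumptions. Because the layer is instantiated with a single channel count (normalization_layer(planes) in the hunks below), a GroupNorm swap needs its group count bound in advance:

```python
# Hedged sketch: passing a custom normalization layer to the resnet factory.
from functools import partial

import torch.nn as nn
from torch_uncertainty.models.resnet import resnet  # assumed import path

group_norm = partial(nn.GroupNorm, 8)  # called as group_norm(num_features)
model = resnet(
    in_channels=3,                   # assumed argument name
    num_classes=10,                  # assumed argument name
    arch=18,                         # assumed name for the depth argument
    normalization_layer=group_norm,  # default is torch.nn.BatchNorm2d
)
```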
25 changes: 20 additions & 5 deletions torch_uncertainty/models/wideresnet/batched.py
@@ -22,6 +22,7 @@ def __init__(
groups: int,
conv_bias: bool,
activation_fn: Callable,
+normalization_layer: type[nn.Module],
) -> None:
super().__init__()
self.activation_fn = activation_fn
@@ -35,7 +36,7 @@ def __init__(
bias=conv_bias,
)
self.dropout = nn.Dropout2d(p=dropout_rate)
-self.bn1 = nn.BatchNorm2d(planes)
+self.bn1 = normalization_layer(planes)
self.conv2 = BatchConv2d(
planes,
planes,
@@ -46,7 +47,7 @@ def __init__(
groups=groups,
bias=conv_bias,
)
-self.bn2 = nn.BatchNorm2d(planes)
+self.bn2 = normalization_layer(planes)

self.shortcut = nn.Sequential()
if stride != 1 or in_planes != planes:
@@ -82,6 +83,7 @@ def __init__(
groups: int = 1,
style: Literal["imagenet", "cifar"] = "imagenet",
activation_fn: Callable = relu,
+normalization_layer: type[nn.Module] = nn.BatchNorm2d,
) -> None:
super().__init__()
self.num_estimators = num_estimators
@@ -123,7 +125,7 @@ def __init__(
else:
raise ValueError(f"Unknown WideResNet style: {style}. ")

-self.bn1 = nn.BatchNorm2d(num_stages[0])
+self.bn1 = normalization_layer(num_stages[0])

if style == "imagenet":
self.optional_pool = nn.MaxPool2d(
@@ -142,6 +144,7 @@ def __init__(
groups=groups,
conv_bias=conv_bias,
activation_fn=activation_fn,
+normalization_layer=normalization_layer,
)
self.layer2 = self._wide_layer(
_WideBasicBlock,
@@ -153,6 +156,7 @@ def __init__(
groups=groups,
conv_bias=conv_bias,
activation_fn=activation_fn,
+normalization_layer=normalization_layer,
)
self.layer3 = self._wide_layer(
_WideBasicBlock,
@@ -164,9 +168,10 @@ def __init__(
groups=groups,
conv_bias=conv_bias,
activation_fn=activation_fn,
+normalization_layer=normalization_layer,
)

-self.dropout = nn.Dropout(p=dropout_rate)
+self.final_dropout = nn.Dropout(p=dropout_rate)
self.pool = nn.AdaptiveAvgPool2d(output_size=1)
self.flatten = nn.Flatten(1)
self.linear = BatchLinear(
@@ -186,6 +191,7 @@ def _wide_layer(
groups: int,
conv_bias: bool,
activation_fn: Callable,
+normalization_layer: type[nn.Module],
) -> nn.Module:
strides = [stride] + [1] * (int(num_blocks) - 1)
layers = []
@@ -201,6 +207,7 @@ def _wide_layer(
num_estimators=num_estimators,
groups=groups,
activation_fn=activation_fn,
+normalization_layer=normalization_layer,
)
)
self.in_planes = planes
@@ -214,7 +221,7 @@ def feats_forward(self, x: Tensor) -> Tensor:
out = self.layer2(out)
out = self.layer3(out)
out = self.pool(out)
-return self.dropout(self.flatten(out))
+return self.final_dropout(self.flatten(out))

def forward(self, x: Tensor) -> Tensor:
return self.linear(self.feats_forward(x))
@@ -228,6 +235,8 @@ def batched_wideresnet28x10(
dropout_rate: float = 0.3,
groups: int = 1,
style: Literal["imagenet", "cifar"] = "imagenet",
+activation_fn: Callable = relu,
+normalization_layer: type[nn.Module] = nn.BatchNorm2d,
) -> _BatchWideResNet:
"""BatchEnsemble of Wide-ResNet-28x10.
@@ -241,6 +250,10 @@
groups (int): Number of groups in the convolutions. Defaults to ``1``.
style (bool, optional): Whether to use the ImageNet
structure. Defaults to ``True``.
+activation_fn (Callable, optional): Activation function. Defaults to
+``torch.nn.functional.relu``.
+normalization_layer (nn.Module, optional): Normalization layer.
+Defaults to ``torch.nn.BatchNorm2d``.
Returns:
_BatchWideResNet: A BatchEnsemble-style Wide-ResNet-28x10.
@@ -255,4 +268,6 @@
num_estimators=num_estimators,
groups=groups,
style=style,
+activation_fn=activation_fn,
+normalization_layer=normalization_layer,
)
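End to end, batched_wideresnet28x10() now exposes both new arguments and forwards them to the backbone. A usage sketch under the signature shown above; in_channels and num_classes are assumed, since they are not visible in these hunks:

```python
# Hedged usage sketch for the extended BatchEnsemble factory.
import torch.nn as nn
import torch.nn.functional as F
from torch_uncertainty.models.wideresnet import batched_wideresnet28x10  # assumed path

model = batched_wideresnet28x10(
    in_channels=3,    # assumed argument, not visible in this diff
    num_classes=10,   # assumed argument, not visible in this diff
    num_estimators=4,
    activation_fn=F.silu,                # new in this PR; defaults to relu
    normalization_layer=nn.BatchNorm2d,  # new in this PR; the default shown above
)
```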
31 changes: 23 additions & 8 deletions torch_uncertainty/models/wideresnet/masked.py
@@ -23,6 +23,7 @@ def __init__(
scale: float,
groups: int,
activation_fn: Callable,
+normalization_layer: type[nn.Module],
) -> None:
super().__init__()
self.activation_fn = activation_fn
@@ -37,7 +38,7 @@ def __init__(
bias=conv_bias,
)
self.dropout = nn.Dropout2d(p=dropout_rate)
-self.bn1 = nn.BatchNorm2d(planes)
+self.bn1 = normalization_layer(planes)
self.conv2 = MaskedConv2d(
planes,
planes,
@@ -49,7 +50,7 @@ def __init__(
groups=groups,
bias=conv_bias,
)
-self.bn2 = nn.BatchNorm2d(planes)
+self.bn2 = normalization_layer(planes)

self.shortcut = nn.Sequential()
if stride != 1 or in_planes != planes:
@@ -87,6 +88,7 @@ def __init__(
groups: int = 1,
style: Literal["imagenet", "cifar"] = "imagenet",
activation_fn: Callable = relu,
+normalization_layer: type[nn.Module] = nn.BatchNorm2d,
) -> None:
super().__init__()
self.num_estimators = num_estimators
@@ -126,7 +128,7 @@ def __init__(
else:
raise ValueError(f"Unknown WideResNet style: {style}. ")

-self.bn1 = nn.BatchNorm2d(num_stages[0])
+self.bn1 = normalization_layer(num_stages[0])

if style == "imagenet":
self.optional_pool = nn.MaxPool2d(
@@ -146,6 +148,7 @@ def __init__(
scale=scale,
groups=groups,
activation_fn=activation_fn,
+normalization_layer=normalization_layer,
)
self.layer2 = self._wide_layer(
_WideBasicBlock,
@@ -158,6 +161,7 @@ def __init__(
scale=scale,
groups=groups,
activation_fn=activation_fn,
+normalization_layer=normalization_layer,
)
self.layer3 = self._wide_layer(
_WideBasicBlock,
@@ -170,9 +174,10 @@ def __init__(
scale=scale,
groups=groups,
activation_fn=activation_fn,
+normalization_layer=normalization_layer,
)

-self.dropout = nn.Dropout(p=dropout_rate)
+self.final_dropout = nn.Dropout(p=dropout_rate)
self.pool = nn.AdaptiveAvgPool2d(output_size=1)
self.flatten = nn.Flatten(1)

@@ -189,9 +194,10 @@ def _wide_layer(
dropout_rate: float,
stride: int,
num_estimators: int,
-scale: float = 2.0,
-groups: int = 1,
-activation_fn: Callable = relu,
+scale: float,
+groups: int,
+activation_fn: Callable,
+normalization_layer: type[nn.Module],
) -> nn.Module:
strides = [stride] + [1] * (int(num_blocks) - 1)
layers = []
@@ -208,6 +214,7 @@ def _wide_layer(
scale=scale,
groups=groups,
activation_fn=activation_fn,
+normalization_layer=normalization_layer,
)
)
self.in_planes = planes
@@ -221,7 +228,7 @@ def feats_forward(self, x: Tensor) -> Tensor:
out = self.layer2(out)
out = self.layer3(out)
out = self.pool(out)
-return self.dropout(self.flatten(out))
+return self.final_dropout(self.flatten(out))

def forward(self, x: Tensor) -> Tensor:
return self.linear(self.feats_forward(x))
@@ -236,6 +243,8 @@ def masked_wideresnet28x10(
dropout_rate: float = 0.3,
groups: int = 1,
style: Literal["imagenet", "cifar"] = "imagenet",
+activation_fn: Callable = relu,
+normalization_layer: type[nn.Module] = nn.BatchNorm2d,
) -> _MaskedWideResNet:
"""Masksembles of Wide-ResNet-28x10.
@@ -251,6 +260,10 @@
``1``.
style (bool, optional): Whether to use the ImageNet
structure. Defaults to ``True``.
+activation_fn (Callable, optional): Activation function. Defaults to
+``torch.nn.functional.relu``.
+normalization_layer (nn.Module, optional): Normalization layer.
+Defaults to ``torch.nn.BatchNorm2d``.
Returns:
_MaskedWideResNet: A Masksembles-style Wide-ResNet-28x10.
@@ -266,4 +279,6 @@
scale=scale,
groups=groups,
style=style,
+activation_fn=activation_fn,
+normalization_layer=normalization_layer,
)
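The Masksembles factory gets the same extension; the extra wrinkle visible here is the scale argument forwarded in the closing call. A parallel sketch with the same caveats as the BatchEnsemble example above:

```python
# Hedged usage sketch for the extended Masksembles factory.
import torch.nn as nn
from torch_uncertainty.models.wideresnet import masked_wideresnet28x10  # assumed path

model = masked_wideresnet28x10(
    in_channels=3,   # assumed argument, not visible in this diff
    num_classes=10,  # assumed argument, not visible in this diff
    num_estimators=4,
    scale=2.0,       # mask scale, forwarded in the closing call above
    normalization_layer=nn.BatchNorm2d,  # new in this PR
)
```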