Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add RadImageNet to Perceptual Loss #153

Merged
merged 4 commits into from
Jan 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 79 additions & 8 deletions generative/losses/perceptual.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,18 @@
class PerceptualLoss(nn.Module):
"""
Perceptual loss using features from pretrained deep neural networks trained. The function supports networks
pretrained on ImageNet that use the LPIPS approach from: Zhang, et al. "The unreasonable effectiveness of deep
features as a perceptual metric." https://arxiv.org/abs/1801.03924
pretrained on: ImageNet that use the LPIPS approach from Zhang, et al. "The unreasonable effectiveness of deep
features as a perceptual metric." https://arxiv.org/abs/1801.03924 ; RadImagenet from Mei, et al. "RadImageNet: An
Open Radiologic Deep Learning Research Dataset for Effective Transfer Learning"; and MedicalNet from Chen et al.
"Med3D: Transfer Learning for 3D Medical Image Analysis" .

The fake 3D implementation is based on a 2.5D approach where we calculate the 2D perceptual on slices from the
three axis.

Args:
spatial_dims: number of spatial dimensions.
network_type: {``"alex"``, ``"vgg"``, ``"squeeze"``, ``"medicalnet_resnet10_23datasets"``,
``"medicalnet_resnet50_23datasets"``}
network_type: {``"alex"``, ``"vgg"``, ``"squeeze"``, ``"radimagenet_resnet50"``,
Warvito marked this conversation as resolved.
Show resolved Hide resolved
``"medicalnet_resnet10_23datasets"``, ``"medicalnet_resnet50_23datasets"``}
Specifies the network architecture to use. Defaults to ``"alex"``.
is_fake_3d: if True use 2.5D approach for a 3D perceptual loss.
fake_3d_ratio: ratio of how many slices per axis are used in the 2.5D approach.
Expand All @@ -48,6 +51,8 @@ def __init__(
self.spatial_dims = spatial_dims
if spatial_dims == 3 and is_fake_3d is False:
self.perceptual_function = MedicalNetPerceptualComponent(net=network_type, verbose=False)
elif "radimagenet_" in network_type:
self.perceptual_function = RadImageNetPerceptualComponent(net=network_type, verbose=False)
else:
self.perceptual_function = LPIPS(
pretrained=True,
Expand Down Expand Up @@ -116,15 +121,14 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
if target.shape != input.shape:
raise ValueError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})")

if self.spatial_dims == 2:
loss = self.perceptual_function(input, target)
elif self.spatial_dims == 3 and self.is_fake_3d:
if self.spatial_dims == 3 and self.is_fake_3d:
Warvito marked this conversation as resolved.
Show resolved Hide resolved
# Compute 2.5D approach
loss_sagittal = self._calculate_axis_loss(input, target, spatial_axis=2)
loss_coronal = self._calculate_axis_loss(input, target, spatial_axis=3)
loss_axial = self._calculate_axis_loss(input, target, spatial_axis=4)
loss = loss_sagittal + loss_axial + loss_coronal
if self.spatial_dims == 3 and self.is_fake_3d is False:
else:
# 2D and real 3D cases
loss = self.perceptual_function(input, target)

return torch.mean(loss)
Expand Down Expand Up @@ -194,3 +198,70 @@ def medicalnet_intensity_normalisation(volume):
mean = volume.mean()
std = volume.std()
return (volume - mean) / std


class RadImageNetPerceptualComponent(nn.Module):
"""
Component to perform the perceptual evaluation with the networks pretrained on RadImagenet (pretrained by Mei, et
al. "RadImageNet: An Open Radiologic Deep Learning Research Dataset for Effective Transfer Learning"). This class
uses torch Hub to download the networks from "Warvito/radimagenet-models".

Args:
net: {``"radimagenet_resnet50"``}
Specifies the network architecture to use. Defaults to ``"radimagenet_resnet50"``.
verbose: if false, mute messages from torch Hub load function.
"""

def __init__(
self,
net: str = "radimagenet_resnet50",
verbose: bool = False,
) -> None:
super().__init__()
self.model = torch.hub.load("Warvito/radimagenet-models", model=net, verbose=verbose)
self.eval()

def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
Warvito marked this conversation as resolved.
Show resolved Hide resolved
"""
We expect that the input is normalised between [0, 1]. Given the preprocessing performed during the training at
https://github.com/BMEII-AI/RadImageNet, we make sure that the input and target have 3 channels, reorder it from
'RGB' to 'BGR', and then remove the mean components of each input data channel. The outputs are normalised
across the channels, and we obtain the mean from the spatial dimensions (similar approach to the lpips package).
"""
# If input has just 1 channel, repeat channel to have 3 channels
if input.shape[1] == 1 and target.shape[1] == 1:
input = input.repeat(1, 3, 1, 1)
target = target.repeat(1, 3, 1, 1)

# Change order from 'RGB' to 'BGR'
input = input[:, [2, 1, 0], ...]
target = target[:, [2, 1, 0], ...]

# Subtract mean used during training
input = subtract_mean(input)
target = subtract_mean(target)

# Get model outputs
outs_input = self.model.forward(input)
outs_target = self.model.forward(target)

# Normalise through the channels
feats_input = normalize_tensor(outs_input)
feats_target = normalize_tensor(outs_target)

results = (feats_input - feats_target) ** 2
results = spatial_average(results.sum(dim=1, keepdim=True), keepdim=True)

return results


def spatial_average(x: torch.Tensor, keepdim: bool = True) -> torch.Tensor:
return x.mean([2, 3], keepdim=keepdim)


def subtract_mean(x: torch.Tensor) -> torch.Tensor:
mean = [0.406, 0.456, 0.485]
x[:, 0, :, :] -= mean[0]
x[:, 1, :, :] -= mean[1]
x[:, 2, :, :] -= mean[2]
return x
15 changes: 15 additions & 0 deletions tests/test_perceptual_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,21 @@
(2, 1, 64, 64, 64),
(2, 1, 64, 64, 64),
],
[
{"spatial_dims": 2, "network_type": "radimagenet_resnet50"},
(2, 1, 64, 64),
(2, 1, 64, 64),
],
[
{"spatial_dims": 2, "network_type": "radimagenet_resnet50"},
(2, 3, 64, 64),
(2, 3, 64, 64),
],
[
{"spatial_dims": 3, "network_type": "radimagenet_resnet50", "is_fake_3d": True, "fake_3d_ratio": 0.1},
(2, 1, 64, 64, 64),
(2, 1, 64, 64, 64),
],
[
{"spatial_dims": 3, "network_type": "medicalnet_resnet10_23datasets", "is_fake_3d": False},
(2, 1, 64, 64, 64),
Expand Down