389 add torchvision resnet50 support #390

Merged: 3 commits, May 15, 2023
Changes from 2 commits
114 changes: 110 additions & 4 deletions generative/losses/perceptual.py
@@ -14,6 +14,8 @@
import torch
import torch.nn as nn
from lpips import LPIPS
from torchvision.models import ResNet50_Weights, resnet50
from torchvision.models.feature_extraction import create_feature_extractor


class PerceptualLoss(nn.Module):
@@ -22,20 +24,29 @@ class PerceptualLoss(nn.Module):
pretrained on: ImageNet that use the LPIPS approach from Zhang, et al. "The unreasonable effectiveness of deep
features as a perceptual metric." https://arxiv.org/abs/1801.03924 ; RadImagenet from Mei, et al. "RadImageNet: An
Open Radiologic Deep Learning Research Dataset for Effective Transfer Learning"
    https://pubs.rsna.org/doi/full/10.1148/ryai.210315 ; MedicalNet from Chen et al. "Med3D: Transfer Learning for
    3D Medical Image Analysis" https://arxiv.org/abs/1904.00625 ;
    and ResNet50 from Torchvision: https://pytorch.org/vision/main/models/generated/torchvision.models.resnet50.html .

    The fake 3D implementation is based on a 2.5D approach where we calculate the 2D perceptual loss on slices from
    the three axes.

Args:
spatial_dims: number of spatial dimensions.
network_type: {``"alex"``, ``"vgg"``, ``"squeeze"``, ``"radimagenet_resnet50"``,
``"medicalnet_resnet10_23datasets"``, ``"medicalnet_resnet50_23datasets"``}
``"medicalnet_resnet10_23datasets"``, ``"medicalnet_resnet50_23datasets"``, ``"resnet50"``}
Specifies the network architecture to use. Defaults to ``"alex"``.
is_fake_3d: if True use 2.5D approach for a 3D perceptual loss.
fake_3d_ratio: ratio of how many slices per axis are used in the 2.5D approach.
cache_dir: path to cache directory to save the pretrained network weights.
        pretrained: whether to load pretrained weights. This argument only works when using networks from
            LPIPS or Torchvision. Defaults to ``True``.
        pretrained_path: if `pretrained` is `True`, users can specify a weights file to be loaded
            via this argument. This argument only works when ``network_type`` is ``"resnet50"``.
            Defaults to `None`.
        pretrained_state_dict_key: if `pretrained_path` is not `None`, this argument is used to
            extract the expected state dict. This argument only works when ``network_type`` is ``"resnet50"``.
            Defaults to `None`.
"""

def __init__(
@@ -45,6 +56,9 @@ def __init__(
is_fake_3d: bool = True,
fake_3d_ratio: float = 0.5,
cache_dir: str | None = None,
pretrained: bool = True,
pretrained_path: str | None = None,
pretrained_state_dict_key: str | None = None,
):
super().__init__()

@@ -65,8 +79,15 @@ def __init__(
self.perceptual_function = MedicalNetPerceptualSimilarity(net=network_type, verbose=False)
elif "radimagenet_" in network_type:
self.perceptual_function = RadImageNetPerceptualSimilarity(net=network_type, verbose=False)
elif network_type == "resnet50":
self.perceptual_function = TorchvisionModelPerceptualSimilarity(
net=network_type,
pretrained=pretrained,
pretrained_path=pretrained_path,
pretrained_state_dict_key=pretrained_state_dict_key,
)
else:
            self.perceptual_function = LPIPS(pretrained=pretrained, net=network_type, verbose=False)
self.is_fake_3d = is_fake_3d
self.fake_3d_ratio = fake_3d_ratio

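A minimal sketch of how the new option is exercised, assuming the import path follows this file's location (generative/losses/perceptual.py):

    import torch
    from generative.losses.perceptual import PerceptualLoss

    # 2D perceptual loss backed by the torchvision ResNet50 feature extractor added in this PR
    loss_fn = PerceptualLoss(spatial_dims=2, network_type="resnet50", pretrained=True)
    pred = torch.rand(2, 1, 64, 64)    # inputs are expected to be normalised to [0, 1]
    target = torch.rand(2, 1, 64, 64)  # single-channel images are repeated to 3 channels internally
    loss = loss_fn(pred, target)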
@@ -247,10 +268,95 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
return results


class TorchvisionModelPerceptualSimilarity(nn.Module):
"""
Component to perform the perceptual evaluation with TorchVision models.
Currently, only ResNet50 is supported. The network structure is based on:
https://pytorch.org/vision/main/models/generated/torchvision.models.resnet50.html

Args:
net: {``"resnet50"``}
Specifies the network architecture to use. Defaults to ``"resnet50"``.
pretrained: whether to load pretrained weights. Defaults to `True`.
        pretrained_path: if `pretrained` is `True`, users can specify a weights file to be loaded
            via this argument. Defaults to `None`.
pretrained_state_dict_key: if `pretrained_path` is not `None`, this argument is used to
extract the expected state dict. Defaults to `None`.
"""

def __init__(
self,
net: str = "resnet50",
pretrained: bool = True,
pretrained_path: str | None = None,
pretrained_state_dict_key: str | None = None,
) -> None:
super().__init__()
supported_networks = ["resnet50"]
if net not in supported_networks:
raise NotImplementedError(
f"'net' {net} is not supported, please select a network from {supported_networks}."
)

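        # Default branch: build ResNet50 with torchvision's ImageNet weights (or uninitialised
        # weights when `pretrained` is False); the `else` branch loads a user-supplied checkpoint.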
if pretrained_path is None:
network = resnet50(weights=ResNet50_Weights.DEFAULT if pretrained else None)
else:
network = resnet50(weights=None)
if pretrained is True:
state_dict = torch.load(pretrained_path)
if pretrained_state_dict_key is not None:
state_dict = state_dict[pretrained_state_dict_key]
network.load_state_dict(state_dict)
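        # "layer4.2.relu_2" should be the last ReLU activation of the final bottleneck block,
        # as named by torchvision's FX-based feature extraction.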
self.final_layer = "layer4.2.relu_2"
self.model = create_feature_extractor(network, [self.final_layer])
self.eval()

for param in self.parameters():
param.requires_grad = False

def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
        We expect that the input is normalised to [0, 1]. Given the preprocessing performed during training at
        https://pytorch.org/vision/main/models/generated/torchvision.models.resnet50.html#torchvision.models.ResNet50_Weights,
        we make sure that the input and target have 3 channels, and then apply Z-score normalization.
        The outputs are normalised across the channels, and we take the mean over the spatial dimensions (a similar
        approach to the lpips package).
"""
# If input has just 1 channel, repeat channel to have 3 channels
if input.shape[1] == 1 and target.shape[1] == 1:
input = input.repeat(1, 3, 1, 1)
target = target.repeat(1, 3, 1, 1)

# Input normalization
input = torchvision_zscore_norm(input)
target = torchvision_zscore_norm(target)

# Get model outputs
outs_input = self.model.forward(input)[self.final_layer]
outs_target = self.model.forward(target)[self.final_layer]

# Normalise through the channels
feats_input = normalize_tensor(outs_input)
feats_target = normalize_tensor(outs_target)

results = (feats_input - feats_target) ** 2
results = spatial_average(results.sum(dim=1, keepdim=True), keepdim=True)

return results

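A short sketch of the new component in isolation (shapes are illustrative):

    comp = TorchvisionModelPerceptualSimilarity(net="resnet50", pretrained=True)
    x = torch.rand(4, 1, 224, 224)  # normalised to [0, 1]; repeated to 3 channels in forward
    y = torch.rand(4, 1, 224, 224)
    out = comp(x, y)  # shape (4, 1, 1, 1): squared feature differences, summed over channels, spatially averaged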

def spatial_average(x: torch.Tensor, keepdim: bool = True) -> torch.Tensor:
return x.mean([2, 3], keepdim=keepdim)


def torchvision_zscore_norm(x: torch.Tensor) -> torch.Tensor:
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
x[:, 0, :, :] = (x[:, 0, :, :] - mean[0]) / std[0]
x[:, 1, :, :] = (x[:, 1, :, :] - mean[1]) / std[1]
x[:, 2, :, :] = (x[:, 2, :, :] - mean[2]) / std[2]
return x

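Note that torchvision_zscore_norm modifies its argument in place, so callers see their tensors mutated. An out-of-place equivalent (a sketch, not part of this PR) could be:

    def torchvision_zscore_norm_out_of_place(x: torch.Tensor) -> torch.Tensor:
        # Broadcast the per-channel ImageNet statistics over (N, 3, H, W) without mutating x
        mean = torch.tensor([0.485, 0.456, 0.406], device=x.device).view(1, 3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225], device=x.device).view(1, 3, 1, 1)
        return (x - mean) / std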

def subtract_mean(x: torch.Tensor) -> torch.Tensor:
mean = [0.406, 0.456, 0.485]
x[:, 0, :, :] -= mean[0]
    x[:, 1, :, :] -= mean[1]
    x[:, 2, :, :] -= mean[2]
    return x
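
Finally, a sketch of the custom-checkpoint path that the new arguments enable (the file name and state-dict key below are hypothetical):

    loss_fn = PerceptualLoss(
        spatial_dims=2,
        network_type="resnet50",
        pretrained=True,
        pretrained_path="my_resnet50_checkpoint.pt",  # hypothetical local checkpoint
        pretrained_state_dict_key="state_dict",       # hypothetical key wrapping the weights
    )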