From c8667c860a8fcb6d7ff3684334687f2104c89b9a Mon Sep 17 00:00:00 2001
From: Walter Hugo Lopez Pinaya
Date: Wed, 21 Dec 2022 21:43:58 +0000
Subject: [PATCH 1/3] Add RadImageNet component

Signed-off-by: Walter Hugo Lopez Pinaya
---
 generative/losses/perceptual.py | 75 ++++++++++++++++++++++++++++++---
 1 file changed, 68 insertions(+), 7 deletions(-)

diff --git a/generative/losses/perceptual.py b/generative/losses/perceptual.py
index 682c1e76..559e7413 100644
--- a/generative/losses/perceptual.py
+++ b/generative/losses/perceptual.py
@@ -18,8 +18,6 @@
 # TODO: Define model_path for lpips networks.
 # TODO: Add MedicalNet for true 3D computation (https://github.com/Tencent/MedicalNet)
-# TODO: Add RadImageNet for 2D computation with networks pretrained using radiological images
-# (https://github.com/BMEII-AI/RadImageNet)
 class PerceptualLoss(nn.Module):
     """
     Perceptual loss using features from pretrained deep neural networks. The function supports networks
     pretrained on ImageNet that use the LPIPS approach from: Zhang, et al. "The unreasonable effectiveness of deep
     features as a perceptual metric." https://arxiv.org/abs/1801.03924
@@ -52,11 +50,14 @@ def __init__(
             raise NotImplementedError("True 3D perceptual loss is not implemented. Try setting is_fake_3d=False")
 
         self.spatial_dims = spatial_dims
-        self.perceptual_function = LPIPS(
-            pretrained=True,
-            net=network_type,
-            verbose=False,
-        )
+        if "radimagenet_" in network_type:
+            self.perceptual_function = RadimagenetPerceptualComponent(net=network_type, verbose=False)
+        else:
+            self.perceptual_function = LPIPS(
+                pretrained=True,
+                net=network_type,
+                verbose=False,
+            )
 
         self.is_fake_3d = is_fake_3d
         self.fake_3d_ratio = fake_3d_ratio
@@ -129,3 +130,63 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
             loss = loss_sagittal + loss_axial + loss_coronal
 
         return torch.mean(loss)
+
+
+class RadimagenetPerceptualComponent(nn.Module):
+    def __init__(
+        self,
+        net: str = "radimagenet_resnet50",
+        verbose: bool = False,
+    ):
+        super().__init__()
+        self.model = torch.hub.load("Warvito/radimagenet-models", model=net, verbose=verbose)
+        self.eval()
+
+    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        """
+        We expect that the input is normalised between [0, 1]. Given the preprocessing performed during the training at
+        https://github.com/BMEII-AI/RadImageNet/blob/ec1f78cd57909399f9b6c39a412281027cd12a07/breast/breast_train.py#L154
+        we remove the mean components of each input data channel.
+        """
+        # If input has just 1 channel, repeat channel to have 3 channels
+        if input.shape[1] == 1 and target.shape[1] == 1:
+            input.repeat(1, 3, 1, 1)
+            target.repeat(1, 3, 1, 1)
+
+        # Change order from 'RGB' to 'BGR'
+        input = input[:, [2, 1, 0], ...]
+        target = target[:, [2, 1, 0], ...]
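+        # The 'BGR' ordering and the mean subtraction below match the preprocessing
+        # used when the RadImageNet weights were trained (see the docstring above).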
+
+        # Subtract mean used during training
+        input = subtract_mean(input)
+        target = subtract_mean(target)
+
+        # Get model outputs
+        outs_input = self.model.forward(input)
+        outs_target = self.model.forward(target)
+
+        # Normalise through the channels
+        feats_input = normalize_tensor(outs_input)
+        feats_target = normalize_tensor(outs_target)
+
+        results = (feats_input - feats_target) ** 2
+        results = spatial_average(results.sum(dim=1, keepdim=True), keepdim=True)
+
+        return results
+
+
+def spatial_average(x: torch.Tensor, keepdim: bool = True) -> torch.Tensor:
+    return x.mean([2, 3], keepdim=keepdim)
+
+
+def subtract_mean(x: torch.Tensor) -> torch.Tensor:
+    # ImageNet channel means, given in 'BGR' order to match the reordered input
+    mean = [0.406, 0.456, 0.485]
+    x[:, 0, :, :] -= mean[0]
+    x[:, 1, :, :] -= mean[1]
+    x[:, 2, :, :] -= mean[2]
+    return x
+
+
+def normalize_tensor(x: torch.Tensor, eps: float = 1e-10) -> torch.Tensor:
+    norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
+    return x / (norm_factor + eps)

From 90635b65042742a017da65c24718fa8a139c8e5f Mon Sep 17 00:00:00 2001
From: Walter Hugo Lopez Pinaya
Date: Thu, 22 Dec 2022 13:49:24 +0000
Subject: [PATCH 2/3] Add tests and update docstrings

Signed-off-by: Walter Hugo Lopez Pinaya
---
 generative/losses/perceptual.py | 34 +++++++++++++++++++++++----------
 tests/test_perceptual_loss.py   | 10 ++++++++++
 2 files changed, 34 insertions(+), 10 deletions(-)

diff --git a/generative/losses/perceptual.py b/generative/losses/perceptual.py
index 559e7413..8eb23ee2 100644
--- a/generative/losses/perceptual.py
+++ b/generative/losses/perceptual.py
@@ -21,14 +21,16 @@
 class PerceptualLoss(nn.Module):
     """
     Perceptual loss using features from pretrained deep neural networks. The function supports networks
-    pretrained on ImageNet that use the LPIPS approach from: Zhang, et al. "The unreasonable effectiveness of deep
-    features as a perceptual metric." https://arxiv.org/abs/1801.03924
+    pretrained on ImageNet, using the LPIPS approach from Zhang, et al. "The unreasonable effectiveness of deep
+    features as a perceptual metric." (https://arxiv.org/abs/1801.03924), as well as networks pretrained on
+    RadImageNet, from Mei, et al. "RadImageNet: An Open Radiologic Deep Learning Research Dataset for Effective
+    Transfer Learning".
+
+    The fake 3D implementation is based on a 2.5D approach, where we calculate the 2D perceptual loss on slices
+    drawn from the three axes.
 
     Args:
         spatial_dims: number of spatial dimensions.
-        network_type: {``"alex"``, ``"vgg"``, ``"squeeze"``}
+        network_type: {``"alex"``, ``"vgg"``, ``"squeeze"``, ``"radimagenet_resnet50"``}
             Specifies the network architecture to use. Defaults to ``"alex"``.
         is_fake_3d: if True, use the 2.5D approach for a 3D perceptual loss.
         fake_3d_ratio: ratio of how many slices per axis are used in the 2.5D approach.
@@ -51,7 +53,7 @@ def __init__(
 
         self.spatial_dims = spatial_dims
         if "radimagenet_" in network_type:
-            self.perceptual_function = RadimagenetPerceptualComponent(net=network_type, verbose=False)
+            self.perceptual_function = RadImageNetPerceptualComponent(net=network_type, verbose=False)
         else:
             self.perceptual_function = LPIPS(
                 pretrained=True,
@@ -132,12 +134,23 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         return torch.mean(loss)
 
 
-class RadimagenetPerceptualComponent(nn.Module):
+class RadImageNetPerceptualComponent(nn.Module):
+    """
+    Component to perform the perceptual evaluation with networks pretrained on RadImageNet (Mei, et al. "RadImageNet:
+    An Open Radiologic Deep Learning Research Dataset for Effective Transfer Learning").
+    This class uses Torch Hub to download the networks from "Warvito/radimagenet-models".
+
+    Args:
+        net: {``"radimagenet_resnet50"``}
+            Specifies the network architecture to use. Defaults to ``"radimagenet_resnet50"``.
+        verbose: if False, mute messages from the Torch Hub load function.
+    """
+
     def __init__(
         self,
         net: str = "radimagenet_resnet50",
         verbose: bool = False,
-    ):
+    ) -> None:
         super().__init__()
         self.model = torch.hub.load("Warvito/radimagenet-models", model=net, verbose=verbose)
         self.eval()
@@ -145,13 +158,14 @@ def __init__(
     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         """
         We expect that the input is normalised between [0, 1]. Given the preprocessing performed during the training at
-        https://github.com/BMEII-AI/RadImageNet/blob/ec1f78cd57909399f9b6c39a412281027cd12a07/breast/breast_train.py#L154
-        we remove the mean components of each input data channel.
+        https://github.com/BMEII-AI/RadImageNet, we make sure that the input and target have 3 channels, reorder them
+        from 'RGB' to 'BGR', and then remove the mean components of each input data channel. The outputs are normalised
+        across the channels, and we obtain the mean over the spatial dimensions (a similar approach to the lpips
+        package).
         """
         # If input has just 1 channel, repeat channel to have 3 channels
         if input.shape[1] == 1 and target.shape[1] == 1:
-            input.repeat(1, 3, 1, 1)
-            target.repeat(1, 3, 1, 1)
+            input = input.repeat(1, 3, 1, 1)
+            target = target.repeat(1, 3, 1, 1)
 
         # Change order from 'RGB' to 'BGR'
         input = input[:, [2, 1, 0], ...]
diff --git a/tests/test_perceptual_loss.py b/tests/test_perceptual_loss.py
index a1e8b098..0a8921db 100644
--- a/tests/test_perceptual_loss.py
+++ b/tests/test_perceptual_loss.py
@@ -28,6 +28,16 @@
         (2, 1, 64, 64, 64),
         (2, 1, 64, 64, 64),
     ],
+    [
+        {"spatial_dims": 2, "network_type": "radimagenet_resnet50"},
+        (2, 1, 64, 64),
+        (2, 1, 64, 64),
+    ],
+    [
+        {"spatial_dims": 3, "network_type": "radimagenet_resnet50", "is_fake_3d": True, "fake_3d_ratio": 0.1},
+        (2, 1, 64, 64, 64),
+        (2, 1, 64, 64, 64),
+    ],
 ]

From c30a8017eca5955568299eb7ebc3d1c87b4691e2 Mon Sep 17 00:00:00 2001
From: Walter Hugo Lopez Pinaya
Date: Thu, 22 Dec 2022 13:53:42 +0000
Subject: [PATCH 3/3] Add tests with 3 channels

Signed-off-by: Walter Hugo Lopez Pinaya
---
 tests/test_perceptual_loss.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/tests/test_perceptual_loss.py b/tests/test_perceptual_loss.py
index 0a8921db..dadd02bf 100644
--- a/tests/test_perceptual_loss.py
+++ b/tests/test_perceptual_loss.py
@@ -33,6 +33,11 @@
         (2, 1, 64, 64),
         (2, 1, 64, 64),
     ],
+    [
+        {"spatial_dims": 2, "network_type": "radimagenet_resnet50"},
+        (2, 3, 64, 64),
+        (2, 3, 64, 64),
+    ],
     [
         {"spatial_dims": 3, "network_type": "radimagenet_resnet50", "is_fake_3d": True, "fake_3d_ratio": 0.1},
         (2, 1, 64, 64, 64),
         (2, 1, 64, 64, 64),
     ],
 ]
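
A minimal usage sketch of the component added by these patches (assuming the `generative`
package layout above; the `radimagenet_resnet50` weights are fetched through Torch Hub on
first use, so the first call needs network access):

    import torch

    from generative.losses.perceptual import PerceptualLoss

    # 2D case: inputs are expected to be normalised to [0, 1]; single-channel
    # images are repeated internally to 3 channels before the RadImageNet network
    loss_fn = PerceptualLoss(spatial_dims=2, network_type="radimagenet_resnet50")
    loss = loss_fn(torch.rand(2, 1, 64, 64), torch.rand(2, 1, 64, 64))

    # Fake 3D (2.5D) case: the 2D loss is computed on slices drawn from the three axes
    loss_fn_3d = PerceptualLoss(
        spatial_dims=3,
        network_type="radimagenet_resnet50",
        is_fake_3d=True,
        fake_3d_ratio=0.1,
    )
    loss_3d = loss_fn_3d(torch.rand(2, 1, 64, 64, 64), torch.rand(2, 1, 64, 64, 64))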