From 26c9d56a19904d953e69e9872d62c547ddc7d962 Mon Sep 17 00:00:00 2001 From: georgeyiasemis Date: Thu, 28 Mar 2024 11:47:43 +0100 Subject: [PATCH 01/10] Adding 3D UNet --- direct/nn/unet/unet_3d.py | 439 +++++++++++++++++++++++++++++++++ tests/tests_nn/test_unet_3d.py | 53 ++++ 2 files changed, 492 insertions(+) create mode 100644 direct/nn/unet/unet_3d.py create mode 100644 tests/tests_nn/test_unet_3d.py diff --git a/direct/nn/unet/unet_3d.py b/direct/nn/unet/unet_3d.py new file mode 100644 index 000000000..0fd9cd3f7 --- /dev/null +++ b/direct/nn/unet/unet_3d.py @@ -0,0 +1,439 @@ +# Copyright (c) DIRECT Contributors + +"""Code for three-dimensional U-Net adapted from the 2D variant.""" + +from __future__ import annotations + +import math + +import torch +from torch import nn +from torch.nn import functional as F + + +class ConvBlock3D(nn.Module): + """3D U-Net convolutional block.""" + + def __init__(self, in_channels: int, out_channels: int, dropout_probability: float): + """Inits :class:`ConvBlock3D`. + + Parameters + ---------- + in_channels : int + Number of channels in the input tensor. + out_channels : int + Number of channels produced by the convolutional layers. + dropout_probability : float + Dropout probability applied after convolutional layers. + """ + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.dropout_probability = dropout_probability + + self.layers = nn.Sequential( + nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm3d(out_channels), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout3d(dropout_probability), + nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm3d(out_channels), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout3d(dropout_probability), + ) + + def forward(self, input_data: torch.Tensor) -> torch.Tensor: + """Performs the forward pass of :class:`ConvBlock3D`.. + + Parameters + ---------- + input_data : torch.Tensor + Input data. + + Returns + ------- + torch.Tensor + """ + return self.layers(input_data) + + +class TransposeConvBlock3D(nn.Module): + """3D U-Net Transpose Convolutional Block.""" + + def __init__(self, in_channels: int, out_channels: int): + """Inits :class:`TransposeConvBlock3D`. + + Parameters + ---------- + in_channels : int + Number of channels in the input tensor. + out_channels : int + Number of channels produced by the convolutional layers. + """ + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + + self.layers = nn.Sequential( + nn.ConvTranspose3d(in_channels, out_channels, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm3d(out_channels), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, input_data: torch.Tensor) -> torch.Tensor: + """Performs the forward pass of :class:`TransposeConvBlock3D`. + + Parameters + ---------- + input_data : torch.Tensor + Input data. + + Returns + ------- + torch.Tensor + """ + return self.layers(input_data) + + +class UnetModel3d(nn.Module): + """PyTorch implementation of a 3D U-Net model. + + This class defines a 3D U-Net architecture consisting of down-sampling and up-sampling layers with 3D convolutional + blocks. This is an extension of Unet2dModel, but for volumes. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + num_filters: int, + num_pool_layers: int, + dropout_probability: float, + ): + """Inits :class:`UnetModel3d`. 
+ + Parameters + ---------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels. + num_filters : int + Number of output channels of the first convolutional layer. + num_pool_layers : int + Number of down-sampling and up-sampling layers (depth). + dropout_probability : float + Dropout probability. + """ + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.num_filters = num_filters + self.num_pool_layers = num_pool_layers + self.dropout_probability = dropout_probability + + self.down_sample_layers = nn.ModuleList([ConvBlock3D(in_channels, num_filters, dropout_probability)]) + ch = num_filters + for _ in range(num_pool_layers - 1): + self.down_sample_layers += [ConvBlock3D(ch, ch * 2, dropout_probability)] + ch *= 2 + self.conv = ConvBlock3D(ch, ch * 2, dropout_probability) + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv += [TransposeConvBlock3D(ch * 2, ch)] + self.up_conv += [ConvBlock3D(ch * 2, ch, dropout_probability)] + ch //= 2 + + self.up_transpose_conv += [TransposeConvBlock3D(ch * 2, ch)] + self.up_conv += [ + nn.Sequential( + ConvBlock3D(ch * 2, ch, dropout_probability), + nn.Conv3d(ch, out_channels, kernel_size=1, stride=1), + ) + ] + + def forward(self, input_data: torch.Tensor) -> torch.Tensor: + """Performs forward pass of :class:`UnetModel3d`. + + Parameters + ---------- + input_data : torch.Tensor + Input tensor of shape (N, in_channels, slice/time, height, width). + + Returns + ------- + torch.Tensor + Output of shape (N, out_channels, slice/time, height, width). + """ + stack = [] + output, inp_pad = pad_to_pow_of_2(input_data, self.num_pool_layers) + + # Apply down-sampling layers + for _, layer in enumerate(self.down_sample_layers): + output = layer(output) + stack.append(output) + output = F.avg_pool3d(output, kernel_size=2, stride=2, padding=0) + + output = self.conv(output) + + # Apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + padding = [0, 0, 0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 + if output.shape[-3] != downsample_layer.shape[-3]: + padding[5] = 1 + if sum(padding) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + if sum(inp_pad) != 0: + output = output[ + :, + :, + inp_pad[4] : output.shape[2] - inp_pad[5], + inp_pad[2] : output.shape[3] - inp_pad[3], + inp_pad[0] : output.shape[4] - inp_pad[1], + ] + + return output + + +class NormUnetModel3d(nn.Module): + """Implementation of a Normalized U-Net model for 3D data.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + num_filters: int, + num_pool_layers: int, + dropout_probability: float, + norm_groups: int = 2, + ): + """Inits :class:`NormUnetModel3D`. + + Parameters + ---------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels. + num_filters : int + Number of output channels of the first convolutional layer. + num_pool_layers : int + Number of down-sampling and up-sampling layers (depth). + dropout_probability : float + Dropout probability. + norm_groups: int, + Number of normalization groups. 
+ """ + super().__init__() + + self.unet3d = UnetModel3d( + in_channels=in_channels, + out_channels=out_channels, + num_filters=num_filters, + num_pool_layers=num_pool_layers, + dropout_probability=dropout_probability, + ) + + self.norm_groups = norm_groups + + @staticmethod + def norm(input_data: torch.Tensor, groups: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Applies group normalization for 3D data. + + Parameters + ---------- + input_data : torch.Tensor + The input tensor to normalize. + groups : int + The number of groups to divide the tensor into for normalization. + + Returns + ------- + tuple[torch.Tensor, torch.Tensor, torch.Tensor] + A tuple containing the normalized tensor, the mean, and the standard deviation used for normalization. + """ + # Group norm + b, c, z, h, w = input_data.shape + input_data = input_data.reshape(b, groups, -1) + + mean = input_data.mean(-1, keepdim=True) + std = input_data.std(-1, keepdim=True) + + output = (input_data - mean) / std + output = output.reshape(b, c, z, h, w) + + return output, mean, std + + @staticmethod + def unnorm( + input_data: torch.Tensor, + mean: torch.Tensor, + std: torch.Tensor, + groups: int, + ) -> torch.Tensor: + """Reverts the normalization applied to the 3D tensor. + + Parameters + ---------- + input_data : torch.Tensor + The normalized tensor to revert normalization on. + mean : torch.Tensor + The mean used during normalization. + std : torch.Tensor + The standard deviation used during normalization. + groups : int + The number of groups the tensor was divided into during normalization. + + Returns + ------- + torch.Tensor + The tensor after reverting the normalization. + """ + b, c, z, h, w = input_data.shape + input_data = input_data.reshape(b, groups, -1) + return (input_data * std + mean).reshape(b, c, z, h, w) + + @staticmethod + def pad( + input_data: torch.Tensor, + ) -> tuple[torch.Tensor, tuple[list[int], list[int], int, int, list[int], list[int]]]: + """Applies padding to the input 3D tensor to ensure its dimensions are multiples of 16. + + Parameters + ---------- + input_data : torch.Tensor + The input tensor to pad. + + Returns + ------- + tuple[torch.Tensor, tuple[list[int], list[int], int, int, list[int], list[int]]] + A tuple containing the padded tensor and a tuple with the padding applied to each dimension + (height, width, depth) and the target dimensions after padding. + """ + _, _, z, h, w = input_data.shape + w_mult = ((w - 1) | 15) + 1 + h_mult = ((h - 1) | 15) + 1 + z_mult = ((z - 1) | 15) + 1 + w_pad = [math.floor((w_mult - w) / 2), math.ceil((w_mult - w) / 2)] + h_pad = [math.floor((h_mult - h) / 2), math.ceil((h_mult - h) / 2)] + z_pad = [math.floor((z_mult - z) / 2), math.ceil((z_mult - z) / 2)] + + output = F.pad(input_data, w_pad + h_pad + z_pad) + return output, (h_pad, w_pad, z_pad, h_mult, w_mult, z_mult) + + @staticmethod + def unpad( + input_data: torch.Tensor, + h_pad: list[int], + w_pad: list[int], + z_pad: list[int], + h_mult: int, + w_mult: int, + z_mult: int, + ) -> torch.Tensor: + """Removes padding from the 3D input tensor, reverting it to its original dimensions before padding was applied. + + This method is typically used after the model has processed the padded input. + + Parameters + ---------- + input_data : torch.Tensor + The tensor from which padding will be removed. + h_pad : list[int] + Padding applied to the height, specified as [top, bottom]. + w_pad : list[int] + Padding applied to the width, specified as [left, right]. 
+ z_pad : list[int] + Padding applied to the depth, specified as [front, back]. + h_mult : int + The height as computed in the `pad` method. + w_mult : int + The width as computed in the `pad` method. + z_mult : int + The depth as computed in the `pad` method. + + Returns + ------- + torch.Tensor + The tensor with padding removed, restored to its original dimensions. + """ + return input_data[ + ..., z_pad[0] : z_mult - z_pad[1], h_pad[0] : h_mult - h_pad[1], w_pad[0] : w_mult - w_pad[1] + ] + + def forward(self, input_data: torch.Tensor) -> torch.Tensor: + """Performs the forward pass of :class:`NormUnetModel3D`. + + Parameters + ---------- + input_data : torch.Tensor + Input tensor of shape (N, in_channels, slice/time, height, width). + + Returns + ------- + torch.Tensor + Output of shape (N, out_channels, slice/time, height, width). + """ + output, mean, std = self.norm(input_data, self.norm_groups) + output, pad_sizes = self.pad(output) + output = self.unet3d(output) + + output = self.unpad(output, *pad_sizes) + output = self.unnorm(output, mean, std, self.norm_groups) + + return output + + +def pad_to_pow_of_2(inp: torch.Tensor, k: int) -> tuple[torch.Tensor, list[int]]: + """Pads the input tensor along the spatial dimensions (depth, height, width) to the nearest power of 2. + + This is necessary for certain operations in the 3D U-Net architecture to maintain dimensionality. + + Parameters + ---------- + inp : torch.Tensor + The input tensor to be padded. + k : int + The exponent to which the base of 2 is raised to determine the padding. Used to calculate + the target dimension size as a power of 2. + + Returns + ------- + tuple[torch.Tensor, list[int]] + A tuple containing the padded tensor and a list of padding applied to each spatial dimension + in the format [depth_front, depth_back, height_top, height_bottom, width_left, width_right]. 
+ + Examples + -------- + >>> inp = torch.rand(1, 1, 15, 15, 15) # A random tensor with shape [1, 1, 15, 15, 15] + >>> padded_inp, padding = pad_to_pow_of_2(inp, 4) + >>> print(padded_inp.shape, padding) + torch.Size([...]), [1, 1, 1, 1, 1, 1] + """ + diffs = [_ - 2**k for _ in inp.shape[2:]] + padding = [0, 0, 0, 0, 0, 0] + for i, diff in enumerate(diffs[::-1]): + if diff < 1: + padding[2 * i] = abs(diff) // 2 + padding[2 * i + 1] = abs(diff) - padding[2 * i] + + if sum(padding) > 0: + inp = F.pad(inp, padding) + + return inp, padding diff --git a/tests/tests_nn/test_unet_3d.py b/tests/tests_nn/test_unet_3d.py new file mode 100644 index 000000000..834a81343 --- /dev/null +++ b/tests/tests_nn/test_unet_3d.py @@ -0,0 +1,53 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors + + +import numpy as np +import pytest +import torch + +from direct.nn.unet.unet_3d import NormUnetModel3d, UnetModel3d + + +def create_input(shape): + data = np.random.randn(*shape).copy() + data = torch.from_numpy(data).float() + + return data + + +@pytest.mark.parametrize( + "shape", + [ + [2, 3, 20, 16, 16], + [4, 2, 30, 20, 30], + [4, 2, 21, 24, 20], + ], +) +@pytest.mark.parametrize( + "num_filters", + [4, 6, 8], +) +@pytest.mark.parametrize( + "num_pool_layers", + [2, 3], +) +@pytest.mark.parametrize( + "normalized", + [True, False], +) +def test_unet_3d(shape, num_filters, num_pool_layers, normalized): + model_architecture = NormUnetModel3d if normalized else UnetModel3d + model = model_architecture( + in_channels=shape[1], + out_channels=shape[1], + num_filters=num_filters, + num_pool_layers=num_pool_layers, + dropout_probability=0.05, + ).cpu() + + data = create_input(shape).cpu() + + out = model(data) + + assert list(out.shape) == shape From a77ddf355cf9829ddaa448e49b61edb5f32277c3 Mon Sep 17 00:00:00 2001 From: georgeyiasemis Date: Thu, 28 Mar 2024 11:49:14 +0100 Subject: [PATCH 02/10] Adding 3D UNet config --- direct/nn/unet/config.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/direct/nn/unet/config.py b/direct/nn/unet/config.py index 19ccdc95e..7eca4c354 100644 --- a/direct/nn/unet/config.py +++ b/direct/nn/unet/config.py @@ -31,3 +31,12 @@ class Unet2dConfig(ModelConfig): skip_connection: bool = False normalized: bool = False image_initialization: str = "zero_filled" + + +@dataclass +class UnetModel3dConfig(ModelConfig): + in_channels: int = 2 + out_channels: int = 2 + num_filters: int = 16 + num_pool_layers: int = 4 + dropout_probability: float = 0.0 From ecdab9ec5e3d07063e63952ab8a927ad38927813 Mon Sep 17 00:00:00 2001 From: georgeyiasemis Date: Thu, 28 Mar 2024 11:50:08 +0100 Subject: [PATCH 03/10] Add 3d vsharp --- direct/nn/vsharp/config.py | 17 ++ direct/nn/vsharp/vsharp.py | 281 ++++++++++++++++++++++++++- direct/nn/vsharp/vsharp_engine.py | 128 ++++++++++++ tests/tests_nn/test_vsharp.py | 60 +++++- tests/tests_nn/test_vsharp_engine.py | 76 +++++++- 5 files changed, 547 insertions(+), 15 deletions(-) diff --git a/direct/nn/vsharp/config.py b/direct/nn/vsharp/config.py index 3acfd84f7..58e37d54e 100644 --- a/direct/nn/vsharp/config.py +++ b/direct/nn/vsharp/config.py @@ -34,3 +34,20 @@ class VSharpNetConfig(ModelConfig): image_conv_n_convs: int = 15 image_conv_activation: str = ActivationType.RELU image_conv_batchnorm: bool = False + + +@dataclass +class VSharpNet3DConfig(ModelConfig): + num_steps: int = 8 + num_steps_dc_gd: int = 6 + image_init: InitType = InitType.SENSE + no_parameter_sharing: bool = True + auxiliary_steps: int = -1 + initializer_channels: tuple[int, ...] 
= (32, 32, 64, 64) + initializer_dilations: tuple[int, ...] = (1, 1, 2, 4) + initializer_multiscale: int = 1 + initializer_activation: ActivationType = ActivationType.PRELU + unet_num_filters: int = 32 + unet_num_pool_layers: int = 4 + unet_dropout: float = 0.0 + unet_norm: bool = False diff --git a/direct/nn/vsharp/vsharp.py b/direct/nn/vsharp/vsharp.py index cd5bbac20..3a6d95921 100644 --- a/direct/nn/vsharp/vsharp.py +++ b/direct/nn/vsharp/vsharp.py @@ -12,8 +12,6 @@ from __future__ import annotations -from typing import Callable - import numpy as np import torch import torch.nn.functional as F @@ -23,6 +21,7 @@ from direct.data.transforms import apply_mask, expand_operator, reduce_operator from direct.nn.get_nn_model_config import ModelName, _get_model_config, _get_relu_activation from direct.nn.types import ActivationType, InitType +from direct.nn.unet.unet_3d import NormUnetModel3d, UnetModel3d class LagrangeMultipliersInitializer(nn.Module): @@ -51,6 +50,8 @@ def __init__( Tuple of integers specifying the dilation factor for each convolutional layer in the network. multiscale_depth : int Number of multiscale features to include in the output. Default: 1. + activation : ActivationType + Activation function to use on the output. Default: ActivationType.PRELU. """ super().__init__() @@ -147,8 +148,8 @@ class VSharpNet(nn.Module): def __init__( self, - forward_operator: Callable, - backward_operator: Callable, + forward_operator: callable, + backward_operator: callable, num_steps: int, num_steps_dc_gd: int, image_init: InitType = InitType.SENSE, @@ -165,9 +166,9 @@ def __init__( Parameters ---------- - forward_operator : Callable + forward_operator : callable Forward operator function. - backward_operator : Callable + backward_operator : callable Backward operator function. num_steps : int Number of steps in the ADMM algorithm. @@ -271,13 +272,14 @@ def forward( masked_kspace: torch.Tensor Masked k-space of shape (N, coil, height, width, complex=2). sensitivity_map: torch.Tensor - Sensitivity map of shape (N, coil, height, width, complex=2). Default: None. + Sensitivity map of shape (N, coil, height, width, complex=2). sampling_mask: torch.Tensor + Sampling mask of shape (N, 1, height, width, 1). Returns ------- - image: torch.Tensor - Output image of shape (N, height, width, complex=2). + out : list of torch.Tensors + List of output images of shape (N, height, width, complex=2). """ out = [] if self.image_init == "sense": @@ -319,3 +321,264 @@ def forward( u = u + self.rho[admm_step] * (x - z) return out + + +class LagrangeMultipliersInitializer3D(torch.nn.Module): + """A convolutional neural network model that initializes the Lagrange multiplier of :class:`VSharpNet3D`.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + channels: tuple[int, ...], + dilations: tuple[int, ...], + multiscale_depth: int = 1, + activation: ActivationType = ActivationType.PRELU, + ): + """Initializes LagrangeMultipliersInitializer3D. + + Parameters + ---------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels. + channels : tuple of ints + Tuple of integers specifying the number of output channels for each convolutional layer in the network. + dilations : tuple of ints + Tuple of integers specifying the dilation factor for each convolutional layer in the network. + multiscale_depth : int + Number of multiscale features to include in the output. Default: 1. + activation : ActivationType + Activation function to use on the output. 
Default: ActivationType.PRELU. + """ + super().__init__() + + # Define convolutional blocks + self.conv_blocks = nn.ModuleList() + tch = in_channels + for curr_channels, curr_dilations in zip(channels, dilations): + block = nn.Sequential( + nn.ReplicationPad3d(curr_dilations), + nn.Conv3d(tch, curr_channels, 3, padding=0, dilation=curr_dilations), + ) + tch = curr_channels + self.conv_blocks.append(block) + + # Define output block + tch = np.sum(channels[-multiscale_depth:]) + block = nn.Conv3d(tch, out_channels, 1, padding=0) + self.out_block = nn.Sequential(block) + + self.multiscale_depth = multiscale_depth + self.activation = _get_relu_activation(activation) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass of :class:`LagrangeMultipliersInitializer3D`. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape (batch_size, in_channels, z, x, y). + + Returns + ------- + torch.Tensor + Output tensor of shape (batch_size, out_channels, z, x, y). + """ + + features = [] + for block in self.conv_blocks: + x = F.relu(block(x), inplace=True) + if self.multiscale_depth > 1: + features.append(x) + + if self.multiscale_depth > 1: + x = torch.cat(features[-self.multiscale_depth :], dim=1) + + return self.activation(self.out_block(x)) + + +class VSharpNet3D(nn.Module): + """VharpNet 3D version using 3D U-Nets as denoisers.""" + + def __init__( + self, + forward_operator: callable, + backward_operator: callable, + num_steps: int, + num_steps_dc_gd: int, + image_init: InitType = InitType.SENSE, + no_parameter_sharing: bool = True, + initializer_channels: tuple[int, ...] = (32, 32, 64, 64), + initializer_dilations: tuple[int, ...] = (1, 1, 2, 4), + initializer_multiscale: int = 1, + initializer_activation: ActivationType = ActivationType.PRELU, + auxiliary_steps: int = -1, + unet_num_filters: int = 32, + unet_num_pool_layers: int = 4, + unet_dropout: float = 0.0, + unet_norm: bool = False, + **kwargs, + ): + """Inits :class:`VSharpNet3D`. + + Parameters + ---------- + forward_operator : callable + Forward operator function. + backward_operator : callable + Backward operator function. + num_steps : int + Number of steps in the ADMM algorithm. + num_steps_dc_gd : int + Number of steps in the Data Consistency using Gradient Descent step of ADMM. + image_init : str + Image initialization method. Default: 'sense'. + no_parameter_sharing : bool + Flag indicating whether parameter sharing is enabled in the denoiser blocks. + initializer_channels : tuple[int, ...] + Tuple of integers specifying the number of output channels for each convolutional layer in the + Lagrange multiplier initializer. Default: (32, 32, 64, 64). + initializer_dilations : tuple[int, ...] + Tuple of integers specifying the dilation factor for each convolutional layer in the Lagrange multiplier + initializer. Default: (1, 1, 2, 4). + initializer_multiscale : int + Number of multiscale features to include in the Lagrange multiplier initializer output. Default: 1. + initializer_activation : ActivationType + Activation type for the Lagrange multiplier initializer. Default: ActivationType.PReLU. + auxiliary_steps : int + Number of auxiliary steps to output. Can be -1 or a positive integer lower or equal to `num_steps`. + If -1, it uses all steps. If I, the last I steps will be used. + unet_num_filters : int + U-Net denoisers number of output channels of the first convolutional layer. Default: 32. + unet_num_pool_layers : int + U-Net denoisers number of down-sampling and up-sampling layers (depth). 
Default: 4. + unet_dropout : float + U-Net denoisers dropout probability. Default: 0.0 + unet_norm : bool + Whether to use normalized U-Net as denoiser or not. Default: False. + **kwargs: Additional keyword arguments. + Can be `model_name`. + """ + # pylint: disable=too-many-locals + super().__init__() + for extra_key in kwargs: + if extra_key != "model_name": + raise ValueError(f"{type(self).__name__} got key `{extra_key}` which is not supported.") + self.num_steps = num_steps + self.num_steps_dc_gd = num_steps_dc_gd + + self.no_parameter_sharing = no_parameter_sharing + + self.denoiser_blocks = nn.ModuleList() + for _ in range(num_steps if self.no_parameter_sharing else 1): + self.denoiser_blocks.append( + (UnetModel3d if not unet_norm else NormUnetModel3d)( + in_channels=COMPLEX_SIZE * 3, + out_channels=COMPLEX_SIZE, + num_filters=unet_num_filters, + num_pool_layers=unet_num_pool_layers, + dropout_probability=unet_dropout, + ) + ) + + self.initializer = LagrangeMultipliersInitializer3D( + in_channels=COMPLEX_SIZE, + out_channels=COMPLEX_SIZE, + channels=initializer_channels, + dilations=initializer_dilations, + multiscale_depth=initializer_multiscale, + activation=initializer_activation, + ) + + self.learning_rate_eta = nn.Parameter(torch.ones(num_steps_dc_gd, requires_grad=True)) + nn.init.trunc_normal_(self.learning_rate_eta, 0.0, 1.0, 0.0) + + self.rho = nn.Parameter(torch.ones(num_steps, requires_grad=True)) + nn.init.trunc_normal_(self.rho, 0, 0.1, 0.0) + + self.forward_operator = forward_operator + self.backward_operator = backward_operator + + if image_init not in ["sense", "zero_filled"]: + raise ValueError(f"Unknown image_initialization. Expected 'sense' or 'zero_filled'. " f"Got {image_init}.") + + self.image_init = image_init + + if not (auxiliary_steps == -1 or 0 < auxiliary_steps <= num_steps): + raise ValueError( + f"Number of auxiliary steps should be -1 to use all steps or a positive" + f" integer <= than `num_steps`. Received {auxiliary_steps}." + ) + if auxiliary_steps == -1: + self.auxiliary_steps = list(range(num_steps)) + else: + self.auxiliary_steps = list(range(num_steps - min(auxiliary_steps, num_steps), num_steps)) + + self._coil_dim = 1 + self._complex_dim = -1 + self._spatial_dims = (3, 4) + + def forward( + self, + masked_kspace: torch.Tensor, + sensitivity_map: torch.Tensor, + sampling_mask: torch.Tensor, + ) -> list[torch.Tensor]: + """Computes forward pass of :class:`VSharpNet3D`. + + Parameters + ---------- + masked_kspace : torch.Tensor + Masked k-space of shape (N, coil, slice, height, width, complex=2). + sensitivity_map : torch.Tensor + Sensitivity map of shape (N, coil, slice, height, width, complex=2). + sampling_mask : torch.Tensor + Sampling mask of shape (N, 1, 1 or slice, height, width, 1). + + Returns + ------- + out : list of torch.Tensors + List of output images each of shape (N, slice, height, width, complex=2). 
+ """ + out = [] + if self.image_init == "sense": + x = reduce_operator( + coil_data=self.backward_operator(masked_kspace, dim=self._spatial_dims), + sensitivity_map=sensitivity_map, + dim=self._coil_dim, + ) + else: + x = self.backward_operator(masked_kspace, dim=self._spatial_dims).sum(self._coil_dim) + + z = x.clone() + + u = self.initializer(x.permute(0, 4, 1, 2, 3)).permute(0, 2, 3, 4, 1) + + for admm_step in range(self.num_steps): + z = self.denoiser_blocks[admm_step if self.no_parameter_sharing else 0]( + torch.cat( + [z, x, u / self.rho[admm_step]], + dim=self._complex_dim, + ).permute(0, 4, 1, 2, 3) + ).permute(0, 2, 3, 4, 1) + + for dc_gd_step in range(self.num_steps_dc_gd): + dc = apply_mask( + self.forward_operator(expand_operator(x, sensitivity_map, self._coil_dim), dim=self._spatial_dims) + - masked_kspace, + sampling_mask, + return_mask=False, + ) + dc = self.backward_operator(dc, dim=self._spatial_dims) + dc = reduce_operator(dc, sensitivity_map, self._coil_dim) + + x = x - self.learning_rate_eta[dc_gd_step] * (dc + self.rho[admm_step] * (x - z) + u) + + if admm_step in self.auxiliary_steps: + out.append(x) + + u = u + self.rho[admm_step] * (x - z) + + return out diff --git a/direct/nn/vsharp/vsharp_engine.py b/direct/nn/vsharp/vsharp_engine.py index 06dc2cd88..17037546d 100644 --- a/direct/nn/vsharp/vsharp_engine.py +++ b/direct/nn/vsharp/vsharp_engine.py @@ -18,6 +18,134 @@ from direct.utils import detach_dict, dict_to_device +class VSharpNet3DEngine(MRIModelEngine): + """VSharpNet Engine.""" + + def __init__( + self, + cfg: BaseConfig, + model: nn.Module, + device: str, + forward_operator: Optional[Callable] = None, + backward_operator: Optional[Callable] = None, + mixed_precision: bool = False, + **models: nn.Module, + ): + """Inits :class:`VSharpNetEngine`. + + Parameters + ---------- + cfg: BaseConfig + Configuration file. + model: nn.Module + Model. + device: str + Device. Can be "cuda:{idx}" or "cpu". + forward_operator: Callable, optional + The forward operator. Default: None. + backward_operator: Callable, optional + The backward operator. Default: None. + mixed_precision: bool + Use mixed precision. Default: False. + **models: nn.Module + Additional models. + """ + super().__init__( + cfg, + model, + device, + forward_operator=forward_operator, + backward_operator=backward_operator, + mixed_precision=mixed_precision, + **models, + ) + + self._spatial_dims = (3, 4) + + def _do_iteration( + self, + data: Dict[str, Any], + loss_fns: Optional[Dict[str, Callable]] = None, + regularizer_fns: Optional[Dict[str, Callable]] = None, + ) -> DoIterationOutput: + """Performs forward method and calculates loss functions. + + Parameters + ---------- + data : Dict[str, Any] + Data containing keys with values tensors such as k-space, image, sensitivity map, etc. + loss_fns : Optional[Dict[str, Callable]] + Callable loss functions. + regularizer_fns : Optional[Dict[str, Callable]] + Callable regularization functions. + + Returns + ------- + DoIterationOutput + Contains outputs. + """ + + # loss_fns can be None, e.g. 
during validation + if loss_fns is None: + loss_fns = {} + + data = dict_to_device(data, self.device) + + output_image: TensorOrNone + output_kspace: TensorOrNone + + with autocast(enabled=self.mixed_precision): + output_images, output_kspace = self.forward_function(data) + output_images = [T.modulus_if_complex(_, complex_axis=self._complex_dim) for _ in output_images] + loss_dict = {k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in loss_fns.keys()} + + auxiliary_loss_weights = torch.logspace(-1, 0, steps=len(output_images)).to(output_images[0]) + for i, output_image in enumerate(output_images): + loss_dict = self.compute_loss_on_data( + loss_dict, loss_fns, data, output_image, None, auxiliary_loss_weights[i] + ) + # Compute loss on k-space + loss_dict = self.compute_loss_on_data(loss_dict, loss_fns, data, None, output_kspace) + + loss = sum(loss_dict.values()) # type: ignore + + if self.model.training: + self._scaler.scale(loss).backward() + + loss_dict = detach_dict(loss_dict) # Detach dict, only used for logging. + + output_image = output_images[-1] + return DoIterationOutput( + output_image=output_image, + sensitivity_map=data["sensitivity_map"], + data_dict={**loss_dict}, + ) + + def forward_function(self, data: Dict[str, Any]) -> Tuple[torch.Tensor, None]: + data["sensitivity_map"] = self.compute_sensitivity_map(data["sensitivity_map"]) + + output_images = self.model( + masked_kspace=data["masked_kspace"], + sampling_mask=data["sampling_mask"], + sensitivity_map=data["sensitivity_map"], + ) # shape (batch, height, width, complex[=2]) + + output_image = output_images[-1] + output_kspace = data["masked_kspace"] + T.apply_mask( + T.apply_padding( + self.forward_operator( + T.expand_operator(output_image, data["sensitivity_map"], dim=self._coil_dim), + dim=self._spatial_dims, + ), + padding=data.get("padding", None), + ), + ~data["sampling_mask"], + return_mask=False, + ) + + return output_images, output_kspace + + class VSharpNetEngine(MRIModelEngine): """VSharpNet 2D Model Engine.""" diff --git a/tests/tests_nn/test_vsharp.py b/tests/tests_nn/test_vsharp.py index 2d4d10a23..2ce635cb3 100644 --- a/tests/tests_nn/test_vsharp.py +++ b/tests/tests_nn/test_vsharp.py @@ -7,8 +7,8 @@ from direct.data.transforms import fft2, ifft2 from direct.nn.get_nn_model_config import ModelName -from direct.nn.types import InitType -from direct.nn.vsharp.vsharp import VSharpNet +from direct.nn.types import ActivationType, InitType +from direct.nn.vsharp.vsharp import VSharpNet, VSharpNet3D def create_input(shape): @@ -37,7 +37,7 @@ def create_input(shape): ], ) @pytest.mark.parametrize("aux_steps", [-1, 1, -2, 4]) -def test_varsplitnet( +def test_vsharpnet( shape, num_steps, num_steps_dc_gd, @@ -94,3 +94,57 @@ def test_varsplitnet( for i in range(len(out)): assert list(out[i].shape) == [shape[0]] + shape[2:] + [2] + + +@pytest.mark.parametrize("shape", [[1, 3, 10, 16, 16]]) +@pytest.mark.parametrize("num_steps", [3]) +@pytest.mark.parametrize("num_steps_dc_gd", [2]) +@pytest.mark.parametrize("image_init", [InitType.SENSE, InitType.ZERO_FILLED]) +@pytest.mark.parametrize( + "image_model_kwargs", + [ + {"unet_num_filters": 4, "unet_num_pool_layers": 2}, + ], +) +@pytest.mark.parametrize( + "initializer_channels, initializer_dilations, initializer_multiscale, initializer_activation", + [ + [(8, 8, 8, 16), (1, 1, 2, 4), 2, ActivationType.RELU], + [(8, 8, 16), (1, 1, 4), 1, ActivationType.LEAKY_RELU], + ], +) +@pytest.mark.parametrize("aux_steps", [-1, 1]) +def test_vsharpnet3d( + 
shape, + num_steps, + num_steps_dc_gd, + image_init, + image_model_kwargs, + initializer_channels, + initializer_dilations, + initializer_multiscale, + initializer_activation, + aux_steps, +): + model = VSharpNet3D( + fft2, + ifft2, + num_steps=num_steps, + num_steps_dc_gd=num_steps_dc_gd, + image_init=image_init, + no_parameter_sharing=False, + initializer_channels=initializer_channels, + initializer_dilations=initializer_dilations, + initializer_multiscale=initializer_multiscale, + initializer_activation=initializer_activation, + auxiliary_steps=aux_steps, + **image_model_kwargs, + ).cpu() + + kspace = create_input(shape + [2]).cpu() + mask = create_input([shape[0]] + [1] + shape[2:] + [1]).round().int().cpu() + sens = create_input(shape + [2]).cpu() + out = model(kspace, sens, mask) + + for i in range(len(out)): + assert list(out[i].shape) == [shape[0]] + shape[2:] + [2] diff --git a/tests/tests_nn/test_vsharp_engine.py b/tests/tests_nn/test_vsharp_engine.py index 1bb30ce82..903294e60 100644 --- a/tests/tests_nn/test_vsharp_engine.py +++ b/tests/tests_nn/test_vsharp_engine.py @@ -10,9 +10,9 @@ from direct.config.defaults import DefaultConfig, FunctionConfig, LossConfig, TrainingConfig, ValidationConfig from direct.data.transforms import fft2, ifft2 -from direct.nn.vsharp.config import VSharpNetConfig -from direct.nn.vsharp.vsharp import VSharpNet -from direct.nn.vsharp.vsharp_engine import VSharpNetEngine +from direct.nn.vsharp.config import VSharpNet3DConfig, VSharpNetConfig +from direct.nn.vsharp.vsharp import VSharpNet, VSharpNet3D +from direct.nn.vsharp.vsharp_engine import VSharpNet3DEngine, VSharpNetEngine def create_sample(shape, **kwargs): @@ -81,3 +81,73 @@ def test_vsharpnet_engine(shape, loss_fns, num_steps, num_steps_dc_gd, num_filte loss_fns = engine.build_loss() out = engine._do_iteration(data, loss_fns) out.output_image.shape == (shape[0],) + tuple(shape[2:-1]) + + +@pytest.mark.parametrize( + "shape", + [(2, 3, 4, 10, 16, 2), (1, 11, 8, 12, 16, 2)], +) +@pytest.mark.parametrize( + "loss_fns", + [ + [ + "l1_loss", + "snr_loss", + "hfen_l1_loss", + "hfen_l2_loss", + "hfen_l1_norm_loss", + "hfen_l2_norm_loss", + "kspace_nmse_loss", + "kspace_nmae_loss", + "ssim_3d_loss", + ] + ], +) +@pytest.mark.parametrize( + "num_steps, num_steps_dc_gd, num_filters, num_pool_layers", + [[4, 2, 10, 2]], +) +@pytest.mark.parametrize( + "normalized", + [True, False], +) +def test_vsharpnet3d_engine(shape, loss_fns, num_steps, num_steps_dc_gd, num_filters, num_pool_layers, normalized): + # Operators + forward_operator = functools.partial(fft2, centered=True) + backward_operator = functools.partial(ifft2, centered=True) + # Configs + loss_config = LossConfig(losses=[FunctionConfig(loss) for loss in loss_fns]) + training_config = TrainingConfig(loss=loss_config) + validation_config = ValidationConfig(crop=None) + model_config = VSharpNet3DConfig( + num_steps=num_steps, + num_steps_dc_gd=num_steps_dc_gd, + unet_num_filters=num_filters, + unet_num_pool_layers=num_pool_layers, + auxiliary_steps=-1, + ) + config = DefaultConfig(training=training_config, validation=validation_config, model=model_config) + # Models + model = VSharpNet3D( + forward_operator, + backward_operator, + num_steps=model_config.num_steps, + num_steps_dc_gd=model_config.num_steps_dc_gd, + unet_num_filters=model_config.unet_num_filters, + unet_num_pool_layers=model_config.unet_num_pool_layers, + auxiliary_steps=model_config.auxiliary_steps, + ) + sensitivity_model = torch.nn.Conv2d(2, 2, kernel_size=1) + # Define engine + 
engine = VSharpNet3DEngine(config, model, "cpu", fft2, ifft2, sensitivity_model=sensitivity_model) + engine.ndim = 3 + # Test _do_iteration function with a single data batch + data = create_sample( + shape, + sampling_mask=torch.from_numpy(np.random.randn(1, 1, 1, shape[3], shape[4], 1)).bool(), + target=torch.from_numpy(np.random.randn(shape[0], shape[2], shape[3], shape[4])).float(), + scaling_factor=torch.ones(shape[0]), + ) + loss_fns = engine.build_loss() + out = engine._do_iteration(data, loss_fns) + out.output_image.shape == (shape[0],) + tuple(shape[2:-1]) From 1dbe19639731485b12ddcffc77f058532d7f5f15 Mon Sep 17 00:00:00 2001 From: georgeyiasemis Date: Thu, 28 Mar 2024 11:59:42 +0100 Subject: [PATCH 04/10] Typing fixes --- direct/nn/vsharp/vsharp_engine.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/direct/nn/vsharp/vsharp_engine.py b/direct/nn/vsharp/vsharp_engine.py index 17037546d..fe2bac542 100644 --- a/direct/nn/vsharp/vsharp_engine.py +++ b/direct/nn/vsharp/vsharp_engine.py @@ -26,8 +26,8 @@ def __init__( cfg: BaseConfig, model: nn.Module, device: str, - forward_operator: Optional[Callable] = None, - backward_operator: Optional[Callable] = None, + forward_operator: Optional[callable] = None, + backward_operator: Optional[callable] = None, mixed_precision: bool = False, **models: nn.Module, ): @@ -41,9 +41,9 @@ def __init__( Model. device: str Device. Can be "cuda:{idx}" or "cpu". - forward_operator: Callable, optional + forward_operator: callable, optional The forward operator. Default: None. - backward_operator: Callable, optional + backward_operator: callable, optional The backward operator. Default: None. mixed_precision: bool Use mixed precision. Default: False. @@ -64,20 +64,20 @@ def __init__( def _do_iteration( self, - data: Dict[str, Any], - loss_fns: Optional[Dict[str, Callable]] = None, - regularizer_fns: Optional[Dict[str, Callable]] = None, + data: dict[str, Any], + loss_fns: Optional[dict[str, callable]] = None, + regularizer_fns: Optional[dict[str, callable]] = None, ) -> DoIterationOutput: """Performs forward method and calculates loss functions. Parameters ---------- - data : Dict[str, Any] + data : dict[str, Any] Data containing keys with values tensors such as k-space, image, sensitivity map, etc. - loss_fns : Optional[Dict[str, Callable]] - Callable loss functions. - regularizer_fns : Optional[Dict[str, Callable]] - Callable regularization functions. + loss_fns : Optional[dict[str, callable]] + callable loss functions. + regularizer_fns : Optional[dict[str, callable]] + callable regularization functions. 
Returns ------- @@ -121,7 +121,7 @@ def _do_iteration( data_dict={**loss_dict}, ) - def forward_function(self, data: Dict[str, Any]) -> Tuple[torch.Tensor, None]: + def forward_function(self, data: dict[str, Any]) -> tuple[torch.Tensor, None]: data["sensitivity_map"] = self.compute_sensitivity_map(data["sensitivity_map"]) output_images = self.model( @@ -247,7 +247,7 @@ def _do_iteration( data_dict={**loss_dict}, ) - def forward_function(self, data: dict[str, Any]) -> tuple[torch.Tensor, None]: + def forward_function(self, data: dict[str, Any]) -> tuple[torch.Tensor, torch.Tensor]: data["sensitivity_map"] = self.compute_sensitivity_map(data["sensitivity_map"]) output_images = self.model( From d92d3c29d72de12613cfbdda0fa3e13790caa9b8 Mon Sep 17 00:00:00 2001 From: georgeyiasemis Date: Thu, 28 Mar 2024 16:13:29 +0100 Subject: [PATCH 05/10] Init type for unet --- direct/nn/unet/config.py | 5 +++-- projects/calgary_campinas/configs/base_unet.yaml | 2 +- .../calgary_campinas/configs/comparisons/base_unet.yaml | 2 +- .../configs_inference/10x/comparisons/base_unet.yaml | 2 +- .../configs_inference/5x/comparisons/base_unet.yaml | 2 +- .../fastmri/AXT1_brain/configs/base_unet.yaml | 2 +- .../fastmri/AXT1_brain/configs_inference/4x/base_unet.yaml | 2 +- .../fastmri/AXT1_brain/configs_inference/8x/base_unet.yaml | 2 +- projects/toy/shepp_logan/base_unet.yaml | 2 +- projects/vSHARP/fastmri_prostate/configs/base_unet.yaml | 2 +- 10 files changed, 12 insertions(+), 11 deletions(-) diff --git a/direct/nn/unet/config.py b/direct/nn/unet/config.py index 7eca4c354..f87d874c2 100644 --- a/direct/nn/unet/config.py +++ b/direct/nn/unet/config.py @@ -1,8 +1,9 @@ -# coding=utf-8 # Copyright (c) DIRECT Contributors + from dataclasses import dataclass from direct.config.defaults import ModelConfig +from direct.nn.types import InitType @dataclass @@ -30,7 +31,7 @@ class Unet2dConfig(ModelConfig): dropout_probability: float = 0.0 skip_connection: bool = False normalized: bool = False - image_initialization: str = "zero_filled" + image_initialization: InitType = InitType.ZERO_FILLED @dataclass diff --git a/projects/calgary_campinas/configs/base_unet.yaml b/projects/calgary_campinas/configs/base_unet.yaml index 39f635911..248830d81 100644 --- a/projects/calgary_campinas/configs/base_unet.yaml +++ b/projects/calgary_campinas/configs/base_unet.yaml @@ -90,7 +90,7 @@ validation: model: model_name: unet.unet_2d.Unet2d num_filters: 64 - image_initialization: sense + image_initialization: SENSE additional_models: sensitivity_model: model_name: unet.unet_2d.UnetModel2d diff --git a/projects/cvpr2022_recurrentvarnet/calgary_campinas/configs/comparisons/base_unet.yaml b/projects/cvpr2022_recurrentvarnet/calgary_campinas/configs/comparisons/base_unet.yaml index d048f56dd..e1fa60a30 100644 --- a/projects/cvpr2022_recurrentvarnet/calgary_campinas/configs/comparisons/base_unet.yaml +++ b/projects/cvpr2022_recurrentvarnet/calgary_campinas/configs/comparisons/base_unet.yaml @@ -96,7 +96,7 @@ validation: model: model_name: unet.unet_2d.Unet2d num_filters: 64 - image_initialization: sense + image_initialization: SENSE additional_models: sensitivity_model: model_name: unet.unet_2d.UnetModel2d diff --git a/projects/cvpr2022_recurrentvarnet/calgary_campinas/configs_inference/10x/comparisons/base_unet.yaml b/projects/cvpr2022_recurrentvarnet/calgary_campinas/configs_inference/10x/comparisons/base_unet.yaml index 190beac79..dd9a18914 100644 --- 
a/projects/cvpr2022_recurrentvarnet/calgary_campinas/configs_inference/10x/comparisons/base_unet.yaml +++ b/projects/cvpr2022_recurrentvarnet/calgary_campinas/configs_inference/10x/comparisons/base_unet.yaml @@ -96,7 +96,7 @@ validation: model: model_name: unet.unet_2d.Unet2d num_filters: 64 - image_initialization: sense + image_initialization: SENSE additional_models: sensitivity_model: model_name: unet.unet_2d.UnetModel2d diff --git a/projects/cvpr2022_recurrentvarnet/calgary_campinas/configs_inference/5x/comparisons/base_unet.yaml b/projects/cvpr2022_recurrentvarnet/calgary_campinas/configs_inference/5x/comparisons/base_unet.yaml index 118f00b5c..39a528e50 100644 --- a/projects/cvpr2022_recurrentvarnet/calgary_campinas/configs_inference/5x/comparisons/base_unet.yaml +++ b/projects/cvpr2022_recurrentvarnet/calgary_campinas/configs_inference/5x/comparisons/base_unet.yaml @@ -96,7 +96,7 @@ validation: model: model_name: unet.unet_2d.Unet2d num_filters: 64 - image_initialization: sense + image_initialization: SENSE additional_models: sensitivity_model: model_name: unet.unet_2d.UnetModel2d diff --git a/projects/cvpr2022_recurrentvarnet/fastmri/AXT1_brain/configs/base_unet.yaml b/projects/cvpr2022_recurrentvarnet/fastmri/AXT1_brain/configs/base_unet.yaml index 3e768593a..b6c008bda 100644 --- a/projects/cvpr2022_recurrentvarnet/fastmri/AXT1_brain/configs/base_unet.yaml +++ b/projects/cvpr2022_recurrentvarnet/fastmri/AXT1_brain/configs/base_unet.yaml @@ -92,7 +92,7 @@ validation: model: model_name: unet.unet_2d.Unet2d num_filters: 32 - image_initialization: sense + image_initialization: SENSE additional_models: sensitivity_model: model_name: unet.unet_2d.UnetModel2d diff --git a/projects/cvpr2022_recurrentvarnet/fastmri/AXT1_brain/configs_inference/4x/base_unet.yaml b/projects/cvpr2022_recurrentvarnet/fastmri/AXT1_brain/configs_inference/4x/base_unet.yaml index 5cfe9bc4e..bed2d7be4 100644 --- a/projects/cvpr2022_recurrentvarnet/fastmri/AXT1_brain/configs_inference/4x/base_unet.yaml +++ b/projects/cvpr2022_recurrentvarnet/fastmri/AXT1_brain/configs_inference/4x/base_unet.yaml @@ -92,7 +92,7 @@ validation: model: model_name: unet.unet_2d.Unet2d num_filters: 32 - image_initialization: sense + image_initialization: SENSE additional_models: sensitivity_model: model_name: unet.unet_2d.UnetModel2d diff --git a/projects/cvpr2022_recurrentvarnet/fastmri/AXT1_brain/configs_inference/8x/base_unet.yaml b/projects/cvpr2022_recurrentvarnet/fastmri/AXT1_brain/configs_inference/8x/base_unet.yaml index 1d4d3f24f..d4aaa9a55 100644 --- a/projects/cvpr2022_recurrentvarnet/fastmri/AXT1_brain/configs_inference/8x/base_unet.yaml +++ b/projects/cvpr2022_recurrentvarnet/fastmri/AXT1_brain/configs_inference/8x/base_unet.yaml @@ -92,7 +92,7 @@ validation: model: model_name: unet.unet_2d.Unet2d num_filters: 32 - image_initialization: sense + image_initialization: SENSE additional_models: sensitivity_model: model_name: unet.unet_2d.UnetModel2d diff --git a/projects/toy/shepp_logan/base_unet.yaml b/projects/toy/shepp_logan/base_unet.yaml index 7713f6821..42de128cc 100644 --- a/projects/toy/shepp_logan/base_unet.yaml +++ b/projects/toy/shepp_logan/base_unet.yaml @@ -94,7 +94,7 @@ validation: model: model_name: unet.unet_2d.Unet2d num_filters: 64 - image_initialization: sense + image_initialization: SENSE additional_models: sensitivity_model: model_name: unet.unet_2d.UnetModel2d diff --git a/projects/vSHARP/fastmri_prostate/configs/base_unet.yaml b/projects/vSHARP/fastmri_prostate/configs/base_unet.yaml index 
2640d747e..f795b32ef 100644 --- a/projects/vSHARP/fastmri_prostate/configs/base_unet.yaml +++ b/projects/vSHARP/fastmri_prostate/configs/base_unet.yaml @@ -203,7 +203,7 @@ additional_models: model: model_name: unet.unet_2d.Unet2d num_filters: 64 - image_initialization: sense + image_initialization: SENSE logging: tensorboard: num_images: 4 From 51f7ed335ce70bd64ebaad9a0c331db9af91ea27 Mon Sep 17 00:00:00 2001 From: georgeyiasemis Date: Thu, 28 Mar 2024 16:14:06 +0100 Subject: [PATCH 06/10] Minor documentation fixes --- direct/nn/unet/unet_2d.py | 15 ++++++++------- direct/nn/unet/unet_3d.py | 15 +++++++++------ direct/nn/vsharp/vsharp.py | 35 +++++++++++++++++++++++++++------- tests/tests_nn/test_unet_3d.py | 2 +- 4 files changed, 46 insertions(+), 21 deletions(-) diff --git a/direct/nn/unet/unet_2d.py b/direct/nn/unet/unet_2d.py index 5fd9853c8..32b9b9de1 100644 --- a/direct/nn/unet/unet_2d.py +++ b/direct/nn/unet/unet_2d.py @@ -10,6 +10,7 @@ from torch.nn import functional as F from direct.data import transforms as T +from direct.nn.types import InitType class ConvBlock(nn.Module): @@ -334,7 +335,7 @@ def __init__( dropout_probability: float, skip_connection: bool = False, normalized: bool = False, - image_initialization: str = "zero_filled", + image_initialization: InitType = InitType.ZERO_FILLED, **kwargs, ): """Inits :class:`Unet2d`. @@ -355,8 +356,8 @@ def __init__( If True, skip connection is used for the output. Default: False. normalized: bool If True, Normalized Unet is used. Default: False. - image_initialization: str - Type of image initialization. Default: "zero-filled". + image_initialization: InitType + Type of image initialization. Default: InitType.ZERO_FILLED. kwargs: dict """ super().__init__() @@ -437,18 +438,18 @@ def forward( output: torch.Tensor Output image of shape (N, height, width, complex=2). """ - if self.image_initialization == "sense": + if self.image_initialization == InitType.SENSE: if sensitivity_map is None: - raise ValueError("Expected sensitivity_map not to be None with 'sense' image_initialization.") + raise ValueError("Expected sensitivity_map not to be None with InitType.SENSE image_initialization.") input_image = self.compute_sense_init( kspace=masked_kspace, sensitivity_map=sensitivity_map, ) - elif self.image_initialization == "zero_filled": + elif self.image_initialization == InitType.SENSE: input_image = self.backward_operator(masked_kspace, dim=self._spatial_dims).sum(self._coil_dim) else: raise ValueError( - f"Unknown image_initialization. Expected `sense` or `zero_filled`. " + f"Unknown image_initialization. Expected InitType.ZERO_FILLED or InitType.SENSE. " f"Got {self.image_initialization}." ) diff --git a/direct/nn/unet/unet_3d.py b/direct/nn/unet/unet_3d.py index 0fd9cd3f7..513ab6069 100644 --- a/direct/nn/unet/unet_3d.py +++ b/direct/nn/unet/unet_3d.py @@ -14,7 +14,7 @@ class ConvBlock3D(nn.Module): """3D U-Net convolutional block.""" - def __init__(self, in_channels: int, out_channels: int, dropout_probability: float): + def __init__(self, in_channels: int, out_channels: int, dropout_probability: float) -> None: """Inits :class:`ConvBlock3D`. Parameters @@ -61,7 +61,7 @@ def forward(self, input_data: torch.Tensor) -> torch.Tensor: class TransposeConvBlock3D(nn.Module): """3D U-Net Transpose Convolutional Block.""" - def __init__(self, in_channels: int, out_channels: int): + def __init__(self, in_channels: int, out_channels: int) -> None: """Inits :class:`TransposeConvBlock3D`. 
Parameters @@ -101,7 +101,7 @@ class UnetModel3d(nn.Module): """PyTorch implementation of a 3D U-Net model. This class defines a 3D U-Net architecture consisting of down-sampling and up-sampling layers with 3D convolutional - blocks. This is an extension of Unet2dModel, but for volumes. + blocks. This is an extension to 3D volumes of :class:`direct.nn.unet.unet_2d.UnetModel2d`. """ def __init__( @@ -111,7 +111,7 @@ def __init__( num_filters: int, num_pool_layers: int, dropout_probability: float, - ): + ) -> None: """Inits :class:`UnetModel3d`. Parameters @@ -212,7 +212,10 @@ def forward(self, input_data: torch.Tensor) -> torch.Tensor: class NormUnetModel3d(nn.Module): - """Implementation of a Normalized U-Net model for 3D data.""" + """Implementation of a Normalized U-Net model for 3D data. + + This is an extension to 3D volumes of :class:`direct.nn.unet.unet_2d.NormUnetModel2d`. + """ def __init__( self, @@ -222,7 +225,7 @@ def __init__( num_pool_layers: int, dropout_probability: float, norm_groups: int = 2, - ): + ) -> None: """Inits :class:`NormUnetModel3D`. Parameters diff --git a/direct/nn/vsharp/vsharp.py b/direct/nn/vsharp/vsharp.py index 3a6d95921..d3cfce2cd 100644 --- a/direct/nn/vsharp/vsharp.py +++ b/direct/nn/vsharp/vsharp.py @@ -25,7 +25,19 @@ class LagrangeMultipliersInitializer(nn.Module): - """A convolutional neural network model that initializers the Lagrange multiplier of the vSHARPNet.""" + """A convolutional neural network model that initializers the Lagrange multiplier of the :class:`vSHARPNet` [1]_. + + More specifically, it produces an initial value for the Lagrange Multiplier based on the zero-filled image: + + .. math:: + + u^0 = \mathcal{G}_{\psi}(x^0). + + References + ---------- + .. [1] George Yiasemis et al., "VSHARP: Variable Splitting Half-quadratic ADMM Algorithm for Reconstruction + of Inverse Problems" (2023). https://arxiv.org/abs/2309.09954. + """ def __init__( self, @@ -140,10 +152,8 @@ class VSharpNet(nn.Module): References ---------- - .. [1] George Yiasemis et al., "VSHARP: Variable Splitting Half-quadratic ADMM Algorithm for Reconstruction of Inverse Problems" (2023). https://arxiv.org/abs/2309.09954. - """ def __init__( @@ -282,7 +292,7 @@ def forward( List of output images of shape (N, height, width, complex=2). """ out = [] - if self.image_init == "sense": + if self.image_init == InitType.SENSE: x = reduce_operator( coil_data=self.backward_operator(masked_kspace, dim=self._spatial_dims), sensitivity_map=sensitivity_map, @@ -324,7 +334,10 @@ def forward( class LagrangeMultipliersInitializer3D(torch.nn.Module): - """A convolutional neural network model that initializes the Lagrange multiplier of :class:`VSharpNet3D`.""" + """A convolutional neural network model that initializes the Lagrange multiplier of :class:`VSharpNet3D`. + + This is an extension to 3D data of :class:`LagrangeMultipliersInitializer`. + """ def __init__( self, @@ -400,7 +413,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class VSharpNet3D(nn.Module): - """VharpNet 3D version using 3D U-Nets as denoisers.""" + """VharpNet 3D version using 3D U-Nets as denoisers. + + This is an extension to 3D of :class:`VSharpNet`. For the original paper refer to [1]_. + + References + ---------- + .. [1] George Yiasemis et al., "VSHARP: Variable Splitting Half-quadratic ADMM Algorithm for Reconstruction + of Inverse Problems" (2023). https://arxiv.org/abs/2309.09954. 
+ """ def __init__( self, @@ -543,7 +564,7 @@ def forward( List of output images each of shape (N, slice, height, width, complex=2). """ out = [] - if self.image_init == "sense": + if self.image_init == InitType.SENSE: x = reduce_operator( coil_data=self.backward_operator(masked_kspace, dim=self._spatial_dims), sensitivity_map=sensitivity_map, diff --git a/tests/tests_nn/test_unet_3d.py b/tests/tests_nn/test_unet_3d.py index 834a81343..1b27a93da 100644 --- a/tests/tests_nn/test_unet_3d.py +++ b/tests/tests_nn/test_unet_3d.py @@ -1,6 +1,6 @@ -# coding=utf-8 # Copyright (c) DIRECT Contributors +"""Tests for direct.nn.unet.unet_3d module.""" import numpy as np import pytest From e2654a16aade2b7bb6064e4c864b6a20a15919aa Mon Sep 17 00:00:00 2001 From: georgeyiasemis Date: Tue, 2 Apr 2024 10:41:13 +0200 Subject: [PATCH 07/10] Unet Image init minor issue fix --- direct/nn/unet/unet_2d.py | 2 +- tests/tests_nn/test_unet_2d.py | 12 +++++++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/direct/nn/unet/unet_2d.py b/direct/nn/unet/unet_2d.py index 32b9b9de1..dcd0958cf 100644 --- a/direct/nn/unet/unet_2d.py +++ b/direct/nn/unet/unet_2d.py @@ -445,7 +445,7 @@ def forward( kspace=masked_kspace, sensitivity_map=sensitivity_map, ) - elif self.image_initialization == InitType.SENSE: + elif self.image_initialization == InitType.ZERO_FILLED: input_image = self.backward_operator(masked_kspace, dim=self._spatial_dims).sum(self._coil_dim) else: raise ValueError( diff --git a/tests/tests_nn/test_unet_2d.py b/tests/tests_nn/test_unet_2d.py index 38652282a..c3944544d 100644 --- a/tests/tests_nn/test_unet_2d.py +++ b/tests/tests_nn/test_unet_2d.py @@ -1,13 +1,14 @@ -# coding=utf-8 # Copyright (c) DIRECT Contributors +"""Tests for `direct.nn.unet.unet_2d.Unet2d` model.""" import numpy as np import pytest import torch from direct.data.transforms import fft2, ifft2 -from direct.nn.unet.unet_2d import NormUnetModel2d, Unet2d +from direct.nn.types import InitType +from direct.nn.unet.unet_2d import Unet2d def create_input(shape): @@ -42,7 +43,11 @@ def create_input(shape): "normalized", [True, False], ) -def test_unet_2d(shape, num_filters, num_pool_layers, skip, normalized): +@pytest.mark.parametrize( + "image_init", + [InitType.SENSE, InitType.ZERO_FILLED], +) +def test_unet_2d(shape, num_filters, num_pool_layers, skip, normalized, image_init): model = Unet2d( fft2, ifft2, @@ -50,6 +55,7 @@ def test_unet_2d(shape, num_filters, num_pool_layers, skip, normalized): num_pool_layers=num_pool_layers, skip_connection=skip, normalized=normalized, + image_initialization=image_init, dropout_probability=0.05, ).cpu() From cda0ead7163913b7a0407b686f57d3bbdbb5111d Mon Sep 17 00:00:00 2001 From: georgeyiasemis Date: Tue, 2 Apr 2024 10:42:51 +0200 Subject: [PATCH 08/10] Update Makefile to remove ipynb artifacts --- Makefile | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 8938e976d..a6c9f30bd 100644 --- a/Makefile +++ b/Makefile @@ -26,7 +26,7 @@ BROWSER := python -c "$$BROWSER_PYSCRIPT" help: @python -c "$$PRINT_HELP_PYSCRIPT" < $(MAKEFILE_LIST) -clean: clean-build clean-pyc clean-cpy clean-test clean-docs ## remove all build, test, coverage, docs and Python and cython artifacts +clean: clean-build clean-pyc clean-cpy clean-ipynb clean-test clean-docs ## remove all build, test, coverage, docs and Python and cython artifacts clean-build: ## remove build artifacts rm -fr build/ @@ -46,6 +46,9 @@ clean-cpy: ## remove cython file artifacts find . 
-name '*.cpp' -exec rm -f {} + find . -name '*.so' -exec rm -f {} + +clean-ipynb: ## remove ipynb artifacts + find . -name '.ipynb_checkpoints' -exec rm -rf {} + + clean-test: ## remove test and coverage artifacts rm -fr .tox/ rm -f .coverage From 1be817e8bd58a7cdaf62a8a8fc21195bcdb2694c Mon Sep 17 00:00:00 2001 From: georgeyiasemis Date: Tue, 2 Apr 2024 11:42:34 +0200 Subject: [PATCH 09/10] Update for new version of xsdata --- tests/tests_data/test_datasets.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/tests_data/test_datasets.py b/tests/tests_data/test_datasets.py index 50a3f5e8e..416b2b0e3 100644 --- a/tests/tests_data/test_datasets.py +++ b/tests/tests_data/test_datasets.py @@ -1,7 +1,6 @@ -# coding=utf-8 # Copyright (c) DIRECT Contributors -"""Tests for the direct.data.datasets module.""" +"""Tests for the `direct.data.datasets` module.""" import pathlib import tempfile @@ -10,6 +9,7 @@ import ismrmrd import numpy as np import pytest +from xsdata.formats.dataclass.serializers import XmlSerializer from direct.data.datasets import ( CalgaryCampinasDataset, @@ -83,7 +83,10 @@ def create_fastmri_h5file(filename, shape, recon_shape): h5file = h5py.File(filename, "w") h5file.create_dataset("kspace", data=kspace) h5file.create_dataset("reconstruction_rss", data=rss) - h5file.create_dataset("ismrmrd_header", data=header.toXML()) + + # Serializing 'header' object to XML string. + xml_string = XmlSerializer().render(header) + h5file.create_dataset("ismrmrd_header", data=xml_string) h5file.attrs["norm"] = np.linalg.norm(kspace) h5file.attrs["max"] = np.abs(kspace).max() From 725a1198224d4014a50dbcb2ecffd0be748a2f70 Mon Sep 17 00:00:00 2001 From: georgeyiasemis Date: Tue, 2 Apr 2024 12:23:44 +0200 Subject: [PATCH 10/10] callable -> Callable --- direct/nn/vsharp/vsharp.py | 18 ++++++++++-------- direct/nn/vsharp/vsharp_engine.py | 18 +++++++++--------- 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/direct/nn/vsharp/vsharp.py b/direct/nn/vsharp/vsharp.py index d3cfce2cd..6108bff13 100644 --- a/direct/nn/vsharp/vsharp.py +++ b/direct/nn/vsharp/vsharp.py @@ -12,6 +12,8 @@ from __future__ import annotations +from typing import Any, Callable + import numpy as np import torch import torch.nn.functional as F @@ -158,8 +160,8 @@ class VSharpNet(nn.Module): def __init__( self, - forward_operator: callable, - backward_operator: callable, + forward_operator: Callable[[tuple[Any, ...]], torch.Tensor], + backward_operator: Callable[[tuple[Any, ...]], torch.Tensor], num_steps: int, num_steps_dc_gd: int, image_init: InitType = InitType.SENSE, @@ -176,9 +178,9 @@ def __init__( Parameters ---------- - forward_operator : callable + forward_operator : Callable[[tuple[Any, ...]], torch.Tensor] Forward operator function. - backward_operator : callable + backward_operator : Callable[[tuple[Any, ...]], torch.Tensor] Backward operator function. num_steps : int Number of steps in the ADMM algorithm. @@ -425,8 +427,8 @@ class VSharpNet3D(nn.Module): def __init__( self, - forward_operator: callable, - backward_operator: callable, + forward_operator: Callable[[tuple[Any, ...]], torch.Tensor], + backward_operator: Callable[[tuple[Any, ...]], torch.Tensor], num_steps: int, num_steps_dc_gd: int, image_init: InitType = InitType.SENSE, @@ -446,9 +448,9 @@ def __init__( Parameters ---------- - forward_operator : callable + forward_operator : Callable[[tuple[Any, ...]], torch.Tensor] Forward operator function. 
- backward_operator : callable + backward_operator : Callable[[tuple[Any, ...]], torch.Tensor] Backward operator function. num_steps : int Number of steps in the ADMM algorithm. diff --git a/direct/nn/vsharp/vsharp_engine.py b/direct/nn/vsharp/vsharp_engine.py index fe2bac542..0ec3f8d4f 100644 --- a/direct/nn/vsharp/vsharp_engine.py +++ b/direct/nn/vsharp/vsharp_engine.py @@ -4,7 +4,7 @@ from __future__ import annotations -from typing import Any, Optional +from typing import Any, Callable, Optional import torch from torch import nn @@ -26,8 +26,8 @@ def __init__( cfg: BaseConfig, model: nn.Module, device: str, - forward_operator: Optional[callable] = None, - backward_operator: Optional[callable] = None, + forward_operator: Optional[Callable[[tuple[Any, ...]], torch.Tensor]] = None, + backward_operator: Optional[Callable[[tuple[Any, ...]], torch.Tensor]] = None, mixed_precision: bool = False, **models: nn.Module, ): @@ -41,9 +41,9 @@ def __init__( Model. device: str Device. Can be "cuda:{idx}" or "cpu". - forward_operator: callable, optional + forward_operator: Callable[[tuple[Any, ...]], torch.Tensor], optional The forward operator. Default: None. - backward_operator: callable, optional + backward_operator: Callable[[tuple[Any, ...]], torch.Tensor], optional The backward operator. Default: None. mixed_precision: bool Use mixed precision. Default: False. @@ -154,8 +154,8 @@ def __init__( cfg: BaseConfig, model: nn.Module, device: str, - forward_operator: Optional[callable] = None, - backward_operator: Optional[callable] = None, + forward_operator: Optional[Callable[[tuple[Any, ...]], torch.Tensor]] = None, + backward_operator: Optional[Callable[[tuple[Any, ...]], torch.Tensor]] = None, mixed_precision: bool = False, **models: nn.Module, ) -> None: @@ -169,9 +169,9 @@ def __init__( Model. device: str Device. Can be "cuda:{idx}" or "cpu". - forward_operator: callable, optional + forward_operator: Callable[[tuple[Any, ...]], torch.Tensor], optional The forward operator. Default: None. - backward_operator: callable, optional + backward_operator: Callable[[tuple[Any, ...]], torch.Tensor], optional The backward operator. Default: None. mixed_precision: bool Use mixed precision. Default: False.
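
A minimal usage sketch of the components introduced in this series, mirroring the shapes exercised by the new tests in tests/tests_nn/test_unet_3d.py and tests/tests_nn/test_vsharp.py. It assumes the `direct` package from this branch is importable; all shapes and hyperparameters below are illustrative rather than recommended settings.

# Sketch: smoke-test the 3D U-Nets and 3D vSHARP added by this patch series.
# Assumes the `direct` package from this branch is on PYTHONPATH; shapes and
# hyperparameters are illustrative and mirror the new unit tests.
import torch

from direct.data.transforms import fft2, ifft2
from direct.nn.unet.unet_3d import NormUnetModel3d, UnetModel3d
from direct.nn.vsharp.vsharp import VSharpNet3D

# The 3D U-Nets operate on (batch, channels, slice/time, height, width) volumes
# and preserve the input shape (internal padding is removed before returning).
volume = torch.randn(2, 3, 20, 16, 16)
for model_cls in (UnetModel3d, NormUnetModel3d):
    unet = model_cls(
        in_channels=3,
        out_channels=3,
        num_filters=4,
        num_pool_layers=2,
        dropout_probability=0.05,
    )
    assert unet(volume).shape == volume.shape

# 3D vSHARP takes multicoil k-space of shape (N, coil, slice, height, width, complex=2)
# together with sensitivity maps and a sampling mask.
vsharp = VSharpNet3D(
    fft2,   # forward operator
    ifft2,  # backward operator
    num_steps=3,
    num_steps_dc_gd=2,
    unet_num_filters=4,
    unet_num_pool_layers=2,
)
masked_kspace = torch.randn(1, 3, 10, 16, 16, 2)
sensitivity_map = torch.randn(1, 3, 10, 16, 16, 2)
sampling_mask = torch.rand(1, 1, 10, 16, 16, 1).round().int()
out = vsharp(masked_kspace, sensitivity_map, sampling_mask)  # one image per auxiliary step
assert out[-1].shape == (1, 10, 16, 16, 2)

With the default auxiliary_steps=-1 the network returns the intermediate estimate of every ADMM step; VSharpNet3DEngine weights these with torch.logspace(-1, 0, steps=len(output_images)) when accumulating the training loss, so the final estimate contributes most.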