From 28f28717626978d3266be4505f1ab3d81a3adc6b Mon Sep 17 00:00:00 2001
From: George
Date: Thu, 2 Sep 2021 16:09:10 +0200
Subject: [PATCH 1/4] Merged RIM with MRIReconstruction models

---
 direct/environment.py       |   9 +-
 direct/nn/rim/config.py     |   1 +
 direct/nn/rim/mri_models.py | 334 ------------------------------------
 direct/nn/rim/rim.py        | 313 +++++++++++++++++++++++++++++----
 4 files changed, 289 insertions(+), 368 deletions(-)
 delete mode 100644 direct/nn/rim/mri_models.py

diff --git a/direct/environment.py b/direct/environment.py
index da4e95a6..efebe766 100644
--- a/direct/environment.py
+++ b/direct/environment.py
@@ -18,7 +18,6 @@
 import direct.utils.logging
 from direct.config.defaults import DefaultConfig, InferenceConfig, TrainingConfig, ValidationConfig
-from direct.nn.rim.mri_models import MRIReconstruction
 from direct.utils import communication, count_parameters, str_to_class
 
 logger = logging.getLogger(__name__)
@@ -131,9 +130,11 @@ def initialize_models_from_config(cfg, models, forward_operator, backward_operat
         curr_model_cfg = {kk: vv for kk, vv in v.items() if kk != "model_name"}
         additional_models[k] = curr_model(**curr_model_cfg)
 
-    # MODEL SHOULD LOAD MRI RECONSTRUCTION INSTEAD AND USE A FUNCTOOLS PARTIAL TO PASS THE OPERATORS
-    # the_real_model = models["model"](**{k: v for k, v in cfg.model.items() if k != "model_name"})
-    model = MRIReconstruction(models["model"], forward_operator, backward_operator, 2, **cfg.model).to(device)
+    model = models["model"](
+        forward_operator=forward_operator,
+        backward_operator=backward_operator,
+        **{k: v for (k, v) in cfg.model.items()},
+    ).to(device)
 
     # Log total number of parameters
     count_parameters({"model": model, **additional_models})
diff --git a/direct/nn/rim/config.py b/direct/nn/rim/config.py
index 9c799889..7d9b1f77 100644
--- a/direct/nn/rim/config.py
+++ b/direct/nn/rim/config.py
@@ -8,6 +8,7 @@
 
 @dataclass
 class RIMConfig(ModelConfig):
+    x_channels: int = 2
     hidden_channels: int = 16
     length: int = 8
     depth: int = 2
diff --git a/direct/nn/rim/mri_models.py b/direct/nn/rim/mri_models.py
deleted file mode 100644
index 25d05c96..00000000
--- a/direct/nn/rim/mri_models.py
+++ /dev/null
@@ -1,334 +0,0 @@
-# coding=utf-8
-# Copyright (c) DIRECT Contributors
-
-from typing import Iterable, Optional, Tuple, Union
-
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from direct.data import transforms as T
-
-
-class MRILogLikelihood(nn.Module):
-    def __init__(self, forward_operator, backward_operator):
-        super().__init__()
-
-        self.forward_operator = forward_operator
-        self.backward_operator = backward_operator
-
-        # TODO UGLY
-        self.ndim = 2
-
-    def forward(
-        self,
-        input_image,
-        masked_kspace,
-        sensitivity_map,
-        sampling_mask,
-        loglikelihood_scaling=None,
-    ):
-        r"""
-        Defines the MRI loglikelihood assuming one noise vector for the complex images for all coils.
-        $$ \frac{1}{\sigma^2} \sum_{i}^{\text{num coils}}
-        S_i^{\text{H}} \mathcal{F}^{-1} P^T (P \mathcal{F} S_i x_\tau - y_\tau) $$
-        for each time step $\tau$
-
-        Parameters
-        ----------
-        input_image : torch.tensor
-            Initial or previous iteration of image with complex first
-            of shape (batch, complex, [slice,] height, width).
-        masked_kspace : torch.tensor
-            Masked k-space of shape (batch, coil, [slice,] height, width, complex).
-        sensitivity_map : torch.tensor
-            Sensitivity Map of shape (batch, coil, [slice,] height, width, complex).
-        sampling_mask : torch.tensor
-        loglikelihood_scaling : float
-            Multiplier for loglikelihood, for instance for the k-space noise, of shape (1,).
-
-        Returns
-        -------
-        torch.Tensor
-        """
-        if input_image.ndim == 5:
-            self.ndim = 3
-
-        input_image = input_image.permute(
-            (0, 2, 3, 1) if self.ndim == 2 else (0, 2, 3, 4, 1)
-        )  # shape (batch, [slice,] height, width, complex)
-
-        loglikelihood_scaling = loglikelihood_scaling.reshape(
-            list(torch.ones(len(sensitivity_map.shape)).int())
-        )  # shape (1, 1, 1, [1,] 1, 1)
-
-        # We multiply by the loglikelihood_scaling here to prevent fp16 information loss,
-        # as this value is typically <<1, and the operators are linear.
-
-        mul = loglikelihood_scaling * T.complex_multiplication(
-            sensitivity_map, input_image.unsqueeze(1)  # (batch, 1, [slice,] height, width, complex)
-        )  # shape (batch, coil, [slice,] height, width, complex)
-
-        coil_dim = 1
-        # TODO(gy): Is if statement needed? Do 3D data pass from here?
-        spatial_dims = (2, 3) if mul.ndim == 5 else (2, 3, 4)
-
-        mr_forward = torch.where(
-            sampling_mask == 0,
-            torch.tensor([0.0], dtype=masked_kspace.dtype).to(masked_kspace.device),
-            self.forward_operator(mul, dim=spatial_dims),
-        )  # shape (batch, coil, [slice], height, width, complex)
-
-        error = mr_forward - loglikelihood_scaling * torch.where(
-            sampling_mask == 0,
-            torch.tensor([0.0], dtype=masked_kspace.dtype).to(masked_kspace.device),
-            masked_kspace,
-        )  # shape (batch, coil, [slice], height, width, complex)
-
-        mr_backward = self.backward_operator(
-            error, dim=spatial_dims
-        )  # shape (batch, coil, [slice], height, width, complex)
-
-        if sensitivity_map is not None:
-            out = T.complex_multiplication(T.conjugate(sensitivity_map), mr_backward).sum(coil_dim)
-        else:
-            out = mr_backward.sum(coil_dim)
-        # out has shape (batch, complex=2, [slice], height, width)
-
-        out = (
-            out.permute(0, 3, 1, 2) if self.ndim == 2 else out.permute(0, 4, 1, 2, 3)
-        )  # complex first: shape (batch, [slice], height, width, complex=2)
-
-        return out
-
-
-class RIMInit(nn.Module):
-    def __init__(
-        self,
-        x_ch: int,
-        out_ch: int,
-        channels: Tuple[int, ...],
-        dilations: Tuple[int, ...],
-        depth: int = 2,
-        multiscale_depth: int = 1,
-    ):
-        """
-        Learned initializer for RIM, based on multi-scale context aggregation with dilated convolutions, that replaces
-        zero initializer for the RIM hidden vector.
-
-        Inspired by "Multi-Scale Context Aggregation by Dilated Convolutions" (https://arxiv.org/abs/1511.07122)
-
-        Parameters
-        ----------
-        x_ch : int
-            Input channels.
-        out_ch : int
-            Number of hidden channels in the RIM.
-        channels : tuple
-            Channels in the convolutional layers of initializer. Typical it could be e.g. (32, 32, 64, 64).
-        dilations: tuple
-            Dilations of the convolutional layers of the initializer. Typically it could be e.g. (1, 1, 2, 4).
-        depth : int
-            RIM depth
-        multiscale_depth : 1
-            Number of feature layers to aggregate for the output, if 1, multi-scale context aggregation is disabled.
-
-        """
-        super().__init__()
-        self.conv_blocks = nn.ModuleList()
-        self.out_blocks = nn.ModuleList()
-        self.depth = depth
-        self.multiscale_depth = multiscale_depth
-        tch = x_ch
-        for (curr_channels, curr_dilations) in zip(channels, dilations):
-            block = [
-                nn.ReplicationPad2d(curr_dilations),
-                nn.Conv2d(tch, curr_channels, 3, padding=0, dilation=curr_dilations),
-            ]
-            tch = curr_channels
-            self.conv_blocks.append(nn.Sequential(*block))
-        tch = np.sum(channels[-multiscale_depth:])
-        for idx in range(depth):
-            block = [nn.Conv2d(tch, out_ch, 1, padding=0)]
-            self.out_blocks.append(nn.Sequential(*block))
-
-    def forward(self, x):
-
-        features = []
-        for block in self.conv_blocks:
-            x = F.relu(block(x), inplace=True)
-            if self.multiscale_depth > 1:
-                features.append(x)
-        if self.multiscale_depth > 1:
-            x = torch.cat(features[-self.multiscale_depth :], dim=1)
-        output_list = []
-        for block in self.out_blocks:
-            y = F.relu(block(x), inplace=True)
-            output_list.append(y)
-        out = torch.stack(output_list, dim=-1)
-        return out
-
-
-class MRIReconstruction(nn.Module):
-    def __init__(
-        self,
-        rim_model,
-        forward_operator,
-        backward_operator,
-        x_ch,
-        hidden_channels: int = 16,
-        length: int = 8,
-        depth: int = 1,
-        no_parameter_sharing: bool = False,
-        instance_norm: bool = False,
-        dense_connect: bool = False,
-        replication_padding: bool = True,
-        image_initialization: str = "zero_filled",
-        learned_initializer: bool = False,
-        initializer_channels: Optional[Tuple[int, ...]] = (32, 32, 64, 64),
-        initializer_dilations: Optional[Tuple[int, ...]] = (1, 1, 2, 4),
-        initializer_multiscale: int = 1,
-        **kwargs,
-    ):
-        # TODO: Code quality
-        # BODY: Constructor can be called with **kwargs as much as possible. Is currently already done for some variables.
-        """
-        MRI Reconstruction model based on RIM
-        """
-        super().__init__()
-
-        # Some other keys are possible. Check here if these are actually relevant for MRI Recon.
-        # TODO: Expand this to a larger class
-        extra_keys = kwargs.keys()
-        for extra_key in extra_keys:
-            if extra_key not in [
-                "steps",
-                "sensitivity_map_model",
-                "model_name",
-                "z_reduction_frequency",
-                "kspace_context",
-                "scale_loglikelihood",
-                "whiten_input",  # should be passed!
-            ]:
-                raise ValueError(f"{type(self).__name__} got key `{extra_key}` which is not supported.")
-
-        self.model = rim_model(
-            x_ch,
-            hidden_channels,
-            MRILogLikelihood(forward_operator, backward_operator),
-            length=length,
-            depth=depth,
-            no_sharing=no_parameter_sharing,
-            instance_norm=instance_norm,
-            dense_connection=dense_connect,
-            replication_padding=replication_padding,
-            **kwargs,
-        )
-        self.initializer: Optional[nn.Module] = None
-        if learned_initializer and initializer_channels is not None and initializer_dilations is not None:
-            # List is because of a omegaconf bug.
-            self.initializer = RIMInit(
-                x_ch,
-                hidden_channels,
-                channels=initializer_channels,
-                dilations=initializer_dilations,
-                depth=depth,
-                multiscale_depth=initializer_multiscale,
-            )
-
-        self.image_initialization = image_initialization
-
-        self.forward_operator = forward_operator
-        self.backward_operator = backward_operator
-
-    def compute_sense_init(self, kspace, sensitivity_map, spatial_dims=(2, 3), coil_dim=1):
-        # kspace is of shape: (batch, coil, [slice,] height, width, complex)
-        # sensitivity_map is of shape (batch, coil, [slice,] height, width, complex)
-
-        input_image = T.complex_multiplication(
-            T.conjugate(sensitivity_map),
-            self.backward_operator(kspace, dim=spatial_dims),
-        )  # shape (batch, coil, [slice,] height, width, complex=2)
-
-        input_image = input_image.sum(coil_dim)
-
-        # shape (batch, [slice,] height, width, complex=2)
-        return input_image
-
-    def forward(
-        self,
-        input_image,
-        masked_kspace,
-        sampling_mask,
-        sensitivity_map=None,
-        hidden_state=None,
-        loglikelihood_scaling=None,
-        **kwargs,
-    ):
-        """
-
-        Parameters
-        ----------
-        input_image
-            initial reconstruction by fft or previous rim step
-        masked_kspace
-            masked k_space
-        sensitivity_map
-        sampling_mask
-        hidden_state
-        loglikelihood_scaling
-
-        Returns
-        -------
-
-        """
-        # Provide input image for the first image
-        if input_image is None:
-            if self.image_initialization == "sense":
-                input_image = self.compute_sense_init(
-                    kspace=masked_kspace,
-                    sensitivity_map=sensitivity_map,
-                    spatial_dims=(3, 4) if masked_kspace.ndim == 6 else (2, 3),
-                )
-            elif self.image_initialization == "input_kspace":
-                if "initial_kspace" not in kwargs:
-                    raise ValueError(
-                        f"`'initial_kspace` is required as input if initialization is {self.image_initialization}."
-                    )
-                input_image = self.compute_sense_init(
-                    kspace=kwargs["initial_kspace"],
-                    sensitivity_map=sensitivity_map,
-                    spatial_dims=(3, 4) if kwargs["initial_kspace"].ndim == 6 else (2, 3),
-                )
-            elif self.image_initialization == "input_image":
-                if "initial_image" not in kwargs:
-                    raise ValueError(
-                        f"`'initial_image` is required as input if initialization is {self.image_initialization}."
-                    )
-                input_image = kwargs["initial_image"]
-
-            elif self.image_initialization == "zero_filled":
-                coil_dim = 1
-                input_image = self.backward_operator(masked_kspace).sum(coil_dim)
-            else:
-                raise ValueError(
-                    f"Unknown image_initialization. Expected `sense`, `input_kspace`, `'input_image` or `zero_filled`. "
-                    f"Got {self.image_initialization}."
-                )
-
-        # Provide an initialization for the first hidden state.
-        if (self.initializer is not None) and (hidden_state is None):
-            hidden_state = self.initializer(
-                input_image.permute((0, 4, 1, 2, 3) if input_image.ndim == 5 else (0, 3, 1, 2))
-            )  # permute to (batch, complex, [slice], height, width),
-
-        return self.model(
-            input_image=input_image,
-            masked_kspace=masked_kspace,
-            sensitivity_map=sensitivity_map,
-            sampling_mask=sampling_mask,
-            previous_state=hidden_state,
-            loglikelihood_scaling=loglikelihood_scaling,
-            **kwargs,
-        )
diff --git a/direct/nn/rim/rim.py b/direct/nn/rim/rim.py
index baffa13f..27587b95 100644
--- a/direct/nn/rim/rim.py
+++ b/direct/nn/rim/rim.py
@@ -2,12 +2,14 @@
 # Copyright (c) DIRECT Contributors
 
 import warnings
-from typing import Optional
+from typing import Callable, Optional, Tuple, Union
 
+import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
+from direct.data import transforms as T
 from direct.utils.asserts import assert_positive_integer
 
 
@@ -24,7 +26,7 @@ def __init__(
         gru_kernel_size=1,
         ortho_init: bool = True,
         instance_norm: bool = False,
-        dense_connection=0,
+        dense_connect=0,
         replication_padding=False,
     ):
         super().__init__()
@@ -32,7 +34,7 @@ def __init__(
         self.x_channels = x_channels
         self.hidden_channels = hidden_channels
         self.instance_norm = instance_norm
-        self.dense_connection = dense_connection
+        self.dense_connect = dense_connect
         self.repl_pad = replication_padding
 
         self.reset_gates = nn.ModuleList([])
@@ -42,7 +44,7 @@ def __init__(
 
         # Create convolutional blocks of RIM cell
         for idx in range(depth + 1):
-            in_ch = x_channels + 2 if idx == 0 else (1 + min(idx, dense_connection)) * hidden_channels
+            in_ch = x_channels + 2 if idx == 0 else (1 + min(idx, dense_connect)) * hidden_channels
             out_ch = hidden_channels if idx < depth else x_channels
             pad = 0 if replication_padding else (2 if idx == 0 else 1)
             block = []
@@ -108,12 +110,12 @@ def forward(self, cell_input, previous_state):
         for idx in range(self.depth):
             if len(conv_skip) > 0:
                 cell_input = F.relu(
-                    self.conv_blocks[idx](torch.cat([*conv_skip[-self.dense_connection :], cell_input], dim=1)),
+                    self.conv_blocks[idx](torch.cat([*conv_skip[-self.dense_connect :], cell_input], dim=1)),
                     inplace=True,
                 )
             else:
                 cell_input = F.relu(self.conv_blocks[idx](cell_input), inplace=True)
-            if self.dense_connection > 0:
+            if self.dense_connect > 0:
                 conv_skip.append(cell_input)
 
             stacked_inputs = torch.cat([cell_input, previous_state[:, :, :, :, idx]], dim=1)
@@ -127,13 +129,176 @@ def forward(self, cell_input, previous_state):
             new_states.append(cell_input)
             cell_input = F.relu(cell_input, inplace=False)
         if len(conv_skip) > 0:
-            out = self.conv_blocks[self.depth](torch.cat([*conv_skip[-self.dense_connection :], cell_input], dim=1))
+            out = self.conv_blocks[self.depth](torch.cat([*conv_skip[-self.dense_connect :], cell_input], dim=1))
         else:
             out = self.conv_blocks[self.depth](cell_input)
 
         return out, torch.stack(new_states, dim=-1)
 
 
+class MRILogLikelihood(nn.Module):
+    def __init__(
+        self,
+        forward_operator: Callable,
+        backward_operator: Callable,
+    ):
+        super().__init__()
+
+        self.forward_operator = forward_operator
+        self.backward_operator = backward_operator
+
+        # TODO UGLY
+        self.ndim = 2
+
+    def forward(
+        self,
+        input_image,
+        masked_kspace,
+        sensitivity_map,
+        sampling_mask,
+        loglikelihood_scaling=None,
+    ):
+        r"""
+        Defines the MRI loglikelihood assuming one noise vector for the complex images for all coils.
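+        The module returns the gradient of this loglikelihood with respect to the image estimate,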
+        $$ \frac{1}{\sigma^2} \sum_{i}^{\text{num coils}}
+        S_i^{\text{H}} \mathcal{F}^{-1} P^T (P \mathcal{F} S_i x_\tau - y_\tau) $$
+        for each time step $\tau$.
+
+        Parameters
+        ----------
+        input_image : torch.tensor
+            Initial or previous iteration of image with complex first
+            of shape (batch, complex, [slice,] height, width).
+        masked_kspace : torch.tensor
+            Masked k-space of shape (batch, coil, [slice,] height, width, complex).
+        sensitivity_map : torch.tensor
+            Sensitivity Map of shape (batch, coil, [slice,] height, width, complex).
+        sampling_mask : torch.tensor
+            Sampling mask, broadcastable against masked_kspace.
+        loglikelihood_scaling : torch.tensor
+            Multiplier for loglikelihood, for instance for the k-space noise, of shape (1,).
+
+        Returns
+        -------
+        torch.Tensor
+        """
+        if input_image.ndim == 5:
+            self.ndim = 3
+
+        input_image = input_image.permute(
+            (0, 2, 3, 1) if self.ndim == 2 else (0, 2, 3, 4, 1)
+        )  # shape (batch, [slice,] height, width, complex)
+
+        loglikelihood_scaling = loglikelihood_scaling.reshape(
+            list(torch.ones(len(sensitivity_map.shape)).int())
+        )  # shape (1, 1, 1, [1,] 1, 1)
+
+        # We multiply by the loglikelihood_scaling here to prevent fp16 information loss,
+        # as this value is typically <<1, and the operators are linear.
+
+        mul = loglikelihood_scaling * T.complex_multiplication(
+            sensitivity_map, input_image.unsqueeze(1)  # (batch, 1, [slice,] height, width, complex)
+        )  # shape (batch, coil, [slice,] height, width, complex)
+
+        coil_dim = 1
+        # TODO(gy): Is if statement needed? Do 3D data pass from here?
+        spatial_dims = (2, 3) if mul.ndim == 5 else (2, 3, 4)
+
+        mr_forward = torch.where(
+            sampling_mask == 0,
+            torch.tensor([0.0], dtype=masked_kspace.dtype).to(masked_kspace.device),
+            self.forward_operator(mul, dim=spatial_dims),
+        )  # shape (batch, coil, [slice], height, width, complex)
+
+        error = mr_forward - loglikelihood_scaling * torch.where(
+            sampling_mask == 0,
+            torch.tensor([0.0], dtype=masked_kspace.dtype).to(masked_kspace.device),
+            masked_kspace,
+        )  # shape (batch, coil, [slice], height, width, complex)
+
+        mr_backward = self.backward_operator(
+            error, dim=spatial_dims
+        )  # shape (batch, coil, [slice], height, width, complex)
+
+        if sensitivity_map is not None:
+            out = T.complex_multiplication(T.conjugate(sensitivity_map), mr_backward).sum(coil_dim)
+        else:
+            out = mr_backward.sum(coil_dim)
+        # out has shape (batch, [slice,] height, width, complex=2)
+
+        out = (
+            out.permute(0, 3, 1, 2) if self.ndim == 2 else out.permute(0, 4, 1, 2, 3)
+        )  # complex first: shape (batch, complex=2, [slice,] height, width)
+
+        return out
+
+
+class RIMInit(nn.Module):
+    def __init__(
+        self,
+        x_ch: int,
+        out_ch: int,
+        channels: Tuple[int, ...],
+        dilations: Tuple[int, ...],
+        depth: int = 2,
+        multiscale_depth: int = 1,
+    ):
+        """
+        Learned initializer for RIM, based on multi-scale context aggregation with dilated convolutions, that replaces
+        the zero initializer for the RIM hidden vector.
+
+        Inspired by "Multi-Scale Context Aggregation by Dilated Convolutions" (https://arxiv.org/abs/1511.07122)
+
+        Parameters
+        ----------
+        x_ch : int
+            Input channels.
+        out_ch : int
+            Number of hidden channels in the RIM.
+        channels : tuple
+            Channels in the convolutional layers of the initializer. Typically it could be e.g. (32, 32, 64, 64).
+        dilations : tuple
+            Dilations of the convolutional layers of the initializer. Typically it could be e.g. (1, 1, 2, 4).
+        depth : int
+            RIM depth.
+        multiscale_depth : int
+            Number of feature layers to aggregate for the output; if 1, multi-scale context aggregation is disabled.
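+            The stacked initializer output has shape (batch, out_ch, height, width, depth).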
+
+        """
+        super().__init__()
+        self.conv_blocks = nn.ModuleList()
+        self.out_blocks = nn.ModuleList()
+        self.depth = depth
+        self.multiscale_depth = multiscale_depth
+        tch = x_ch
+        for (curr_channels, curr_dilations) in zip(channels, dilations):
+            block = [
+                nn.ReplicationPad2d(curr_dilations),
+                nn.Conv2d(tch, curr_channels, 3, padding=0, dilation=curr_dilations),
+            ]
+            tch = curr_channels
+            self.conv_blocks.append(nn.Sequential(*block))
+        tch = np.sum(channels[-multiscale_depth:])
+        for idx in range(depth):
+            block = [nn.Conv2d(tch, out_ch, 1, padding=0)]
+            self.out_blocks.append(nn.Sequential(*block))
+
+    def forward(self, x):
+
+        features = []
+        for block in self.conv_blocks:
+            x = F.relu(block(x), inplace=True)
+            if self.multiscale_depth > 1:
+                features.append(x)
+        if self.multiscale_depth > 1:
+            x = torch.cat(features[-self.multiscale_depth :], dim=1)
+        output_list = []
+        for block in self.out_blocks:
+            y = F.relu(block(x), inplace=True)
+            output_list.append(y)
+        out = torch.stack(output_list, dim=-1)
+        return out
+
+
 class RIM(nn.Module):
     """
     Recurrent Inference Machine Module as in https://arxiv.org/abs/1706.04008.
@@ -141,58 +306,108 @@ class RIM(nn.Module):
 
     def __init__(
         self,
-        x_channels: int,
-        num_hidden_channels: int,
-        grad_likelihood: nn.Module,
+        forward_operator: Callable,
+        backward_operator: Callable,
+        hidden_channels: int,
+        x_channels: int = 2,
         length: int = 8,
         depth: int = 1,
-        no_sharing: bool = True,
+        no_parameter_sharing: bool = True,
         instance_norm: bool = False,
-        dense_connection: bool = False,
+        dense_connect: bool = False,
         skip_connections: bool = True,
         replication_padding: bool = True,
+        image_initialization: str = "zero_filled",
+        learned_initializer: bool = False,
+        initializer_channels: Optional[Tuple[int, ...]] = (32, 32, 64, 64),
+        initializer_dilations: Optional[Tuple[int, ...]] = (1, 1, 2, 4),
+        initializer_multiscale: int = 1,
         **kwargs,
     ):
         super().__init__()
 
-        assert_positive_integer(x_channels, num_hidden_channels, length, depth)
-        # assert_bool(no_sharing, instance_norm, dense_connection, skip_connections, replication_padding)
+        extra_keys = kwargs.keys()
+        for extra_key in extra_keys:
+            if extra_key not in [
+                "steps",
+                "sensitivity_map_model",
+                "model_name",
+                "z_reduction_frequency",
+                "kspace_context",
+                "scale_loglikelihood",
+                "whiten_input",  # should be passed!
+            ]:
+                raise ValueError(f"{type(self).__name__} got key `{extra_key}` which is not supported.")
+
+        assert_positive_integer(x_channels, hidden_channels, length, depth)
+        # assert_bool(no_parameter_sharing, instance_norm, dense_connect, skip_connections, replication_padding)
+
+        self.initializer: Optional[nn.Module] = None
+        if learned_initializer and initializer_channels is not None and initializer_dilations is not None:
+            # List is because of an omegaconf bug.
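+            # RIMInit produces one hidden-state slice per RIM depth level, each with
+            # hidden_channels channels, computed from the complex-first input image.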
+            self.initializer = RIMInit(
+                x_channels,
+                hidden_channels,
+                channels=initializer_channels,
+                dilations=initializer_dilations,
+                depth=depth,
+                multiscale_depth=initializer_multiscale,
+            )
+
+        self.image_initialization = image_initialization
+
+        self.forward_operator = forward_operator
+        self.backward_operator = backward_operator
+
+        self.grad_likelihood = MRILogLikelihood(forward_operator, backward_operator)
 
         self.skip_connections = skip_connections
 
         self.x_channels = x_channels
-        self.num_hidden_channels = num_hidden_channels
+        self.hidden_channels = hidden_channels
 
         self.cell_list = nn.ModuleList()
-        self.no_sharing = no_sharing
-        for _ in range(length if no_sharing else 1):
+        self.no_parameter_sharing = no_parameter_sharing
+        for _ in range(length if no_parameter_sharing else 1):
             self.cell_list.append(
                 ConvGRUCell(
                     x_channels,
-                    num_hidden_channels,
+                    hidden_channels,
                     depth=depth,
                     instance_norm=instance_norm,
-                    dense_connection=dense_connection,
+                    dense_connect=dense_connect,
                     replication_padding=replication_padding,
                 )
             )
+
         self.length = length
-        self.grad_likelihood = grad_likelihood
         self.depth = depth
 
+    def compute_sense_init(self, kspace, sensitivity_map, spatial_dims=(2, 3), coil_dim=1):
+        # kspace is of shape: (batch, coil, [slice,] height, width, complex)
+        # sensitivity_map is of shape (batch, coil, [slice,] height, width, complex)
+
+        input_image = T.complex_multiplication(
+            T.conjugate(sensitivity_map),
+            self.backward_operator(kspace, dim=spatial_dims),
+        )  # shape (batch, coil, [slice,] height, width, complex=2)
+
+        input_image = input_image.sum(coil_dim)
+
+        # shape (batch, [slice,] height, width, complex=2)
+        return input_image
+
     def forward(
         self,
         input_image: torch.Tensor,
         masked_kspace: torch.Tensor,
-        sensitivity_map: torch.Tensor,
         sampling_mask: torch.Tensor,
+        sensitivity_map: Optional[torch.Tensor] = None,
         previous_state: Optional[torch.Tensor] = None,
-        loglikelihood_scaling: Optional[float] = None,
+        loglikelihood_scaling: Optional[torch.Tensor] = None,
         **kwargs,
     ):
         """
-
         Parameters
         ----------
         input_image : torch.Tensor
@@ -205,12 +420,51 @@ def forward(
             Sampling mask.
         previous_state : torch.Tensor
         loglikelihood_scaling : torch.Tensor
+            Float tensor of shape (1,).
 
         Returns
         -------
         torch.Tensor
         """
+        if input_image is None:
+            if self.image_initialization == "sense":
+                input_image = self.compute_sense_init(
+                    kspace=masked_kspace,
+                    sensitivity_map=sensitivity_map,
+                    spatial_dims=(3, 4) if masked_kspace.ndim == 6 else (2, 3),
+                )
+            elif self.image_initialization == "input_kspace":
+                if "initial_kspace" not in kwargs:
+                    raise ValueError(
+                        f"`initial_kspace` is required as input if initialization is {self.image_initialization}."
+                    )
+                input_image = self.compute_sense_init(
+                    kspace=kwargs["initial_kspace"],
+                    sensitivity_map=sensitivity_map,
+                    spatial_dims=(3, 4) if kwargs["initial_kspace"].ndim == 6 else (2, 3),
+                )
+            elif self.image_initialization == "input_image":
+                if "initial_image" not in kwargs:
+                    raise ValueError(
+                        f"`initial_image` is required as input if initialization is {self.image_initialization}."
+                    )
+                input_image = kwargs["initial_image"]
+
+            elif self.image_initialization == "zero_filled":
+                coil_dim = 1
+                input_image = self.backward_operator(masked_kspace).sum(coil_dim)
+            else:
+                raise ValueError(
+                    f"Unknown image_initialization. Expected `sense`, `input_kspace`, `input_image` or `zero_filled`. "
+                    f"Got {self.image_initialization}."
+                )
+
+        # Provide an initialization for the first hidden state.
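+        # The learned initializer returns a state of shape (batch, hidden_channels, height, width, depth),
+        # matching the zero state created below when no initializer is configured.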
+        if (self.initializer is not None) and (previous_state is None):
+            previous_state = self.initializer(
+                input_image.permute((0, 4, 1, 2, 3) if input_image.ndim == 5 else (0, 3, 1, 2))
+            )  # permute to (batch, complex, [slice,] height, width)
 
         # TODO: This has to be made contiguous
         # TODO(gy): Do 3D data pass from here? If not remove if statements below and [slice,] from comments.
@@ -226,16 +480,16 @@ def forward(
         )
 
         # Initialize zero state for RIM
-        state_size = [batch_size, self.num_hidden_channels] + list(spatial_shape) + [self.depth]
+        state_size = [batch_size, self.hidden_channels] + list(spatial_shape) + [self.depth]
 
         if previous_state is None:
-            # shape (batch, num_hidden_channels, [slice,] height, width, depth)
+            # shape (batch, hidden_channels, [slice,] height, width, depth)
             previous_state = torch.zeros(*state_size, dtype=input_image.dtype).to(input_image.device)
 
         cell_outputs = []
         intermediate_image = input_image  # shape (batch, complex=2, [slice,] height, width)
 
         for cell_idx in range(self.length):
-            cell = self.cell_list[cell_idx] if self.no_sharing else self.cell_list[0]
+            cell = self.cell_list[cell_idx] if self.no_parameter_sharing else self.cell_list[0]
 
             grad_loglikelihood = self.grad_likelihood(
                 intermediate_image,
                 masked_kspace,
                 sensitivity_map,
                 sampling_mask,
                 loglikelihood_scaling,
             )
 
             cell_input = torch.cat(
                 [intermediate_image, grad_loglikelihood],
                 dim=1,
             )  # shape (batch, complex=4, [slice,] height, width)
 
-            cell_output, previous_state = cell(
-                cell_input, previous_state
-            )  # shapes (batch, complex=2, [slice,] height, width), (batch, num_hidden_channels, [slice,] height, width, depth)
+            cell_output, previous_state = cell(cell_input, previous_state)
+            # shapes (batch, complex=2, [slice,] height, width), (batch, hidden_channels, [slice,] height, width, depth)
 
             if self.skip_connections:
                 # shape (batch, complex=2, [slice,] height, width)

From 28b9fd062e881a6a1387637e6f6465d7fc7e6340 Mon Sep 17 00:00:00 2001
From: George Yiasemis
Date: Fri, 3 Sep 2021 13:04:19 +0200
Subject: [PATCH 2/4] Removed unnecessary arg

---
 direct/nn/rim/config.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/direct/nn/rim/config.py b/direct/nn/rim/config.py
index 7d9b1f77..9c799889 100644
--- a/direct/nn/rim/config.py
+++ b/direct/nn/rim/config.py
@@ -8,7 +8,6 @@
 
 @dataclass
 class RIMConfig(ModelConfig):
-    x_channels: int = 2
     hidden_channels: int = 16
     length: int = 8
     depth: int = 2

From 13a4a5e210eac5e096efafed9984914af0e12190 Mon Sep 17 00:00:00 2001
From: George Yiasemis
Date: Fri, 3 Sep 2021 13:23:41 +0200
Subject: [PATCH 3/4] Local variables made class attrs

---
 direct/nn/rim/rim.py | 15 +++++++--------
 1 file changed, 7 insertions(+), 8 deletions(-)

diff --git a/direct/nn/rim/rim.py b/direct/nn/rim/rim.py
index 27587b95..ec1c9cb0 100644
--- a/direct/nn/rim/rim.py
+++ b/direct/nn/rim/rim.py
@@ -150,6 +150,9 @@ def __init__(
         # TODO UGLY
         self.ndim = 2
 
+        self._coil_dim = 1
+        self._spatial_dims = (2, 3) if self.ndim == 2 else (2, 3, 4)
+
     def forward(
         self,
         input_image,
@@ -199,14 +202,10 @@ def forward(
             sensitivity_map, input_image.unsqueeze(1)  # (batch, 1, [slice,] height, width, complex)
         )  # shape (batch, coil, [slice,] height, width, complex)
 
-        coil_dim = 1
-        # TODO(gy): Is if statement needed? Do 3D data pass from here?
-        spatial_dims = (2, 3) if mul.ndim == 5 else (2, 3, 4)
-
+        # Re-derive the spatial dims per call: forward() may have set self.ndim to 3 for
+        # 3D input, whereas __init__ fixed self._spatial_dims to the 2D case.
+        self._spatial_dims = (2, 3) if self.ndim == 2 else (2, 3, 4)
+
         mr_forward = torch.where(
             sampling_mask == 0,
             torch.tensor([0.0], dtype=masked_kspace.dtype).to(masked_kspace.device),
-            self.forward_operator(mul, dim=spatial_dims),
+            self.forward_operator(mul, dim=self._spatial_dims),
         )  # shape (batch, coil, [slice], height, width, complex)
 
         error = mr_forward - loglikelihood_scaling * torch.where(
@@ -216,13 +215,13 @@ def forward(
         )  # shape (batch, coil, [slice], height, width, complex)
 
         mr_backward = self.backward_operator(
-            error, dim=spatial_dims
+            error, dim=self._spatial_dims
         )  # shape (batch, coil, [slice], height, width, complex)
 
         if sensitivity_map is not None:
-            out = T.complex_multiplication(T.conjugate(sensitivity_map), mr_backward).sum(coil_dim)
+            out = T.complex_multiplication(T.conjugate(sensitivity_map), mr_backward).sum(self._coil_dim)
         else:
-            out = mr_backward.sum(coil_dim)
+            out = mr_backward.sum(self._coil_dim)
         # out has shape (batch, [slice,] height, width, complex=2)
 
         out = (

From 009ec9b9d8f00401c4005236e21eb4b4081d4d7f Mon Sep 17 00:00:00 2001
From: George Yiasemis
Date: Fri, 3 Sep 2021 16:46:16 +0200
Subject: [PATCH 4/4] Code quality and documentation

---
 direct/nn/rim/rim.py | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/direct/nn/rim/rim.py b/direct/nn/rim/rim.py
index ec1c9cb0..6c20522c 100644
--- a/direct/nn/rim/rim.py
+++ b/direct/nn/rim/rim.py
@@ -2,7 +2,7 @@
 # Copyright (c) DIRECT Contributors
 
 import warnings
-from typing import Callable, Optional, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -89,7 +89,11 @@ def __init__(
             nn.init.constant_(update_gate[-1].bias, 0.0)
             nn.init.constant_(out_gate[-1].bias, 0.0)
 
-    def forward(self, cell_input, previous_state):
+    def forward(
+        self,
+        cell_input: torch.Tensor,
+        previous_state: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
 
         Parameters
@@ -104,8 +108,8 @@ def forward(self, cell_input, previous_state):
         (torch.Tensor, torch.Tensor)
 
         """
-        new_states = []
-        conv_skip = []
+        new_states: List[torch.Tensor] = []
+        conv_skip: List[torch.Tensor] = []
 
         for idx in range(self.depth):
             if len(conv_skip) > 0:
@@ -160,7 +164,7 @@ def forward(
         sensitivity_map,
         sampling_mask,
         loglikelihood_scaling=None,
-    ):
+    ) -> torch.Tensor:
         r"""
         Defines the MRI loglikelihood assuming one noise vector for the complex images for all coils.
         $$ \frac{1}{\sigma^2} \sum_{i}^{\text{num coils}}
@@ -281,7 +285,7 @@ def __init__(
             block = [nn.Conv2d(tch, out_ch, 1, padding=0)]
             self.out_blocks.append(nn.Sequential(*block))
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
 
         features = []
         for block in self.conv_blocks:
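
A minimal construction sketch for the merged model, mirroring the new
models["model"](...) call in environment.py from PATCH 1. The Fourier operators
here are an assumption: any forward/backward pair with the call signature used
by MRILogLikelihood works, e.g. fft2/ifft2 from direct.data.transforms:

    from direct.data import transforms as T
    from direct.nn.rim.rim import RIM

    # Example values mirroring the RIMConfig defaults; normally **cfg.model supplies them.
    model = RIM(
        forward_operator=T.fft2,
        backward_operator=T.ifft2,
        hidden_channels=16,
        x_channels=2,
        length=8,
        depth=2,
    )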