diff --git a/README.md b/README.md index 7605eaf0..1fcabcee 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ See [getting_started.md](getting_started.md), check out the [documentation](http In the [projects](projects) folder examples are given on how to train models on public datasets. ## Baselines and trained models - +- [Recurrent Variational Network (RecurrentVarNet)](https://arxiv.org/abs/2111.09639) - [Recurrent Inference Machine (RIM)](https://www.sciencedirect.com/science/article/abs/pii/S1361841518306078) - [End-to-end Variational Network (VarNet)](https://arxiv.org/pdf/2004.06688.pdf) - [Learned Primal Dual Network (LDPNet)](https://arxiv.org/abs/1707.06474) diff --git a/direct/nn/recurrentvarnet/__init__.py b/direct/nn/recurrentvarnet/__init__.py new file mode 100644 index 00000000..941752c9 --- /dev/null +++ b/direct/nn/recurrentvarnet/__init__.py @@ -0,0 +1,2 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors diff --git a/direct/nn/recurrentvarnet/config.py b/direct/nn/recurrentvarnet/config.py new file mode 100644 index 00000000..c8df4f57 --- /dev/null +++ b/direct/nn/recurrentvarnet/config.py @@ -0,0 +1,20 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors + +from dataclasses import dataclass +from typing import Optional, Tuple + +from direct.config.defaults import ModelConfig + + +@dataclass +class RecurrentVarNetConfig(ModelConfig): + num_steps: int = 15 # :math:`T` + recurrent_hidden_channels: int = 64 + recurrent_num_layers: int = 4 # :math:`n_l` + no_parameter_sharing: bool = True + learned_initializer: bool = True + initializer_initialization: Optional[str] = "sense" + initializer_channels: Optional[Tuple[int, ...]] = (32, 32, 64, 64) # :math:`n_d` + initializer_dilations: Optional[Tuple[int, ...]] = (1, 1, 2, 4) # :math:`p` + initializer_multiscale: int = 1 diff --git a/direct/nn/recurrentvarnet/recurrentvarnet.py b/direct/nn/recurrentvarnet/recurrentvarnet.py new file mode 100644 index 00000000..dc466a03 --- /dev/null +++ b/direct/nn/recurrentvarnet/recurrentvarnet.py @@ -0,0 +1,385 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors + +from typing import Callable, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from direct.data.transforms import ( + conjugate, + complex_multiplication, + reduce_operator, + expand_operator, +) +from direct.nn.recurrent.recurrent import Conv2dGRU + + +class RecurrentInit(nn.Module): + """ + Recurrent State Initializer (RSI) module of Recurrent Variational Network as presented in + https://arxiv.org/abs/2111.09639. The RSI module learns to initialize the recurrent hidden state h_0, + input of the first RecurrentVarNet Block of the RecurrentVarNet. + + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + channels: Tuple[int, ...], + dilations: Tuple[int, ...], + depth: int = 2, + multiscale_depth: int = 1, + ): + """ + + Parameters + ---------- + in_channels : int + Input channels. + out_channels : int + Number of hidden channels of the recurrent unit of RecurrentVarNet Block. + channels : tuple + Channels :math:`n_d` in the convolutional layers of initializer. + dilations: tuple + Dilations :math:`p` of the convolutional layers of the initializer. + depth : int + RecurrentVarNet Block number of layers :math:`n_l`. + multiscale_depth : 1 + Number of feature layers to aggregate for the output, if 1, multi-scale context aggregation is disabled. 
+ + """ + super().__init__() + + self.conv_blocks = nn.ModuleList() + self.out_blocks = nn.ModuleList() + self.depth = depth + self.multiscale_depth = multiscale_depth + tch = in_channels + for (curr_channels, curr_dilations) in zip(channels, dilations): + block = [ + nn.ReplicationPad2d(curr_dilations), + nn.Conv2d(tch, curr_channels, 3, padding=0, dilation=curr_dilations), + ] + tch = curr_channels + self.conv_blocks.append(nn.Sequential(*block)) + tch = np.sum(channels[-multiscale_depth:]) + for idx in range(depth): + block = [nn.Conv2d(tch, out_channels, 1, padding=0)] + self.out_blocks.append(nn.Sequential(*block)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + + features = [] + for block in self.conv_blocks: + x = F.relu(block(x), inplace=True) + if self.multiscale_depth > 1: + features.append(x) + if self.multiscale_depth > 1: + x = torch.cat(features[-self.multiscale_depth :], dim=1) + output_list = [] + for block in self.out_blocks: + y = F.relu(block(x), inplace=True) + output_list.append(y) + out = torch.stack(output_list, dim=-1) + return out + + +class RecurrentVarNet(nn.Module): + """ + Recurrent Variational Network implementation as presented in https://arxiv.org/abs/2111.09639. + """ + + def __init__( + self, + forward_operator: Callable, + backward_operator: Callable, + in_channels: int = 2, + num_steps: int = 15, + recurrent_hidden_channels: int = 64, + recurrent_num_layers: int = 4, + no_parameter_sharing: bool = True, + learned_initializer: bool = False, + initializer_initialization: Optional[str] = None, + initializer_channels: Optional[Tuple[int, ...]] = (32, 32, 64, 64), + initializer_dilations: Optional[Tuple[int, ...]] = (1, 1, 2, 4), + initializer_multiscale: int = 1, + **kwargs, + ): + """ + + Parameters + ---------- + forward_operator : Callable + Forward Operator. + backward_operator : Callable + Backward Operator. + num_steps : int + Number of iterations :math:`T`. + in_channels : int + Input channel number. Default is 2 for complex data. + recurrent_hidden_channels : int + Hidden channels number for the recurrent unit of the RecurrentVarNet Blocks. Default: 64. + recurrent_num_layers : int + Number of layers for the recurrent unit of the RecurrentVarNet Block (:math:`n_l`). Default: 4. + no_parameter_sharing : bool + If False, the same RecurrentVarNet Block is used for all num_steps. Default: True. + learned_initializer : bool + If True an RSI module is used. Default: False. + initializer_initialization : str, Optional + Type of initialization for the RSI module. Can be either 'sense', 'zero-filled' or 'input-image'. + Default: None. + initializer_channels : tuple + Channels :math:`n_d` in the convolutional layers of the RSI module. Default: (32, 32, 64, 64). + initializer_dilations : tuple + Dilations :math:`p` of the convolutional layers of the RSI module. Default: (1, 1, 2, 4). + initializer_multiscale : int + RSI module number of feature layers to aggregate for the output, if 1, multi-scale context aggregation + is disabled. Default: 1. 
+ + """ + super(RecurrentVarNet, self).__init__() + + extra_keys = kwargs.keys() + for extra_key in extra_keys: + if extra_key not in [ + "model_name", + ]: + raise ValueError(f"{type(self).__name__} got key `{extra_key}` which is not supported.") + + self.initializer: Optional[nn.Module] = None + if ( + learned_initializer + and initializer_initialization is not None + and initializer_channels is not None + and initializer_dilations is not None + ): + if initializer_initialization not in [ + "sense", + "input_image", + "zero_filled", + ]: + raise ValueError( + f"Unknown initializer_initialization. Expected `sense`, `'input_image` or `zero_filled`." + f"Got {self.initializer_initialization}." + ) + self.initializer_initialization = initializer_initialization + self.initializer = RecurrentInit( + in_channels, + recurrent_hidden_channels, + channels=initializer_channels, + dilations=initializer_dilations, + depth=recurrent_num_layers, + multiscale_depth=initializer_multiscale, + ) + self.num_steps = num_steps + self.no_parameter_sharing = no_parameter_sharing + self.block_list: nn.Module = nn.ModuleList() + for _ in range(self.num_steps if self.no_parameter_sharing else 1): + self.block_list.append( + RecurrentVarNetBlock( + forward_operator=forward_operator, + backward_operator=backward_operator, + in_channels=in_channels, + hidden_channels=recurrent_hidden_channels, + num_layers=recurrent_num_layers, + ) + ) + self.forward_operator = forward_operator + self.backward_operator = backward_operator + self._coil_dim = 1 + self._complex_dim = -1 + self._spatial_dims = (2, 3) + + def compute_sense_init(self, kspace, sensitivity_map): + + input_image = complex_multiplication( + conjugate(sensitivity_map), + self.backward_operator(kspace, dim=self._spatial_dims), + ) + input_image = input_image.sum(self._coil_dim) + + return input_image + + def forward( + self, + masked_kspace: torch.Tensor, + sampling_mask: torch.Tensor, + sensitivity_map: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + """ + Parameters + ---------- + masked_kspace : torch.Tensor + Masked k-space of shape (N, coil, height, width, complex=2). + sampling_mask : torch.Tensor + Sampling mask of shape (N, 1, height, width, 1). + sensitivity_map : torch.Tensor + Coil sensitivities of shape (N, coil, height, width, complex=2). + + Returns + ------- + kspace_prediction: torch.Tensor + k-space prediction. + """ + + previous_state: Optional[torch.Tensor] = None + + if self.initializer is not None: + if self.initializer_initialization == "sense": + initializer_input_image = self.compute_sense_init( + kspace=masked_kspace, + sensitivity_map=sensitivity_map, + ).unsqueeze(self._coil_dim) + elif self.initializer_initialization == "input_image": + if "initial_image" not in kwargs: + raise ValueError( + f"`'initial_image` is required as input if initializer_initialization " + f"is {self.initializer_initialization}." 
+ ) + initializer_input_image = kwargs["initial_image"].unsqueeze(self._coil_dim) + elif self.initializer_initialization == "zero_filled": + initializer_input_image = self.backward_operator(masked_kspace, dim=self._spatial_dims) + + previous_state = self.initializer( + self.forward_operator(initializer_input_image, dim=self._spatial_dims) + .sum(self._coil_dim) + .permute(0, 3, 1, 2) + ) + + kspace_prediction = masked_kspace.clone() + + for step in range(self.num_steps): + block = self.block_list[step] if self.no_parameter_sharing else self.block_list[0] + kspace_prediction, previous_state = block( + kspace_prediction, + masked_kspace, + sampling_mask, + sensitivity_map, + previous_state, + self._coil_dim, + self._complex_dim, + self._spatial_dims, + ) + + return kspace_prediction + + +class RecurrentVarNetBlock(nn.Module): + """ + Recurrent Variational Network Block as presented in https://arxiv.org/abs/2111.09639. + """ + + def __init__( + self, + forward_operator: Callable, + backward_operator: Callable, + in_channels: int = 2, + hidden_channels: int = 64, + num_layers: int = 4, + ): + """ + Parameters: + ----------- + forward_operator: Callable + Forward Fourier Transform. + backward_operator: Callable + Backward Fourier Transform. + in_channels: int, + Input channel number. Default is 2 for complex data. + hidden_channels: int, + Hidden channels. Default: 64. + num_layers: int, + Number of layers of :math:`n_l` recurrent unit. Default: 4. + + """ + super().__init__() + self.forward_operator = forward_operator + self.backward_operator = backward_operator + + self.learning_rate = nn.Parameter(torch.tensor([1.0])) # :math:`\alpha_t` + self.regularizer = Conv2dGRU( + in_channels=in_channels, + hidden_channels=hidden_channels, + num_layers=num_layers, + replication_padding=True, + ) # Recurrent Unit of RecurrentVarNet Block :math:`\mathcal{H}_{\theta_t}` + + def forward( + self, + current_kspace: torch.Tensor, + masked_kspace: torch.Tensor, + sampling_mask: torch.Tensor, + sensitivity_map: torch.Tensor, + hidden_state: Union[None, torch.Tensor], + coil_dim: int = 1, + complex_dim: int = -1, + spatial_dims: Tuple[int, int] = (2, 3), + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Parameters + ---------- + current_kspace: torch.Tensor + Current k-space prediction of shape (N, coil, height, width, complex=2). + masked_kspace : torch.Tensor + Masked k-space of shape (N, coil, height, width, complex=2). + sampling_mask : torch.Tensor + Sampling mask of shape (N, 1, height, width, 1). + sensitivity_map : torch.Tensor + Coil sensitivities of shape (N, coil, height, width, complex=2). + hidden_state: torch.Tensor or None + ConvGRU hidden state of shape (N, hidden_channels, height, width, num_layers) if not None. Optional. + coil_dim: int + Coil dimension. Default: 1. + complex_dim: int + Channel/complex dimension. Default: -1. + spatial_dims: tuple of ints + Spatial dimensions. Default: (2, 3). + + Returns + ------- + new_kspace: torch.Tensor + New k-space prediction of shape (N, coil, height, width, complex=2). + hidden_state: torch.Tensor + Next hidden state of shape (N, hidden_channels, height, width, num_layers). 
+ """ + + kspace_error = torch.where( + sampling_mask == 0, + torch.tensor([0.0], dtype=masked_kspace.dtype).to(masked_kspace.device), + current_kspace - masked_kspace, + ) + + recurrent_term = torch.cat( + [ + reduce_operator( + self.backward_operator(kspace, dim=spatial_dims), + sensitivity_map, + dim=coil_dim, + ) + for kspace in torch.split(current_kspace, 2, complex_dim) + ], + dim=complex_dim, + ).permute(0, 3, 1, 2) + + recurrent_term, hidden_state = self.regularizer(recurrent_term, hidden_state) # :math:`w_t`, :math:`h_{t+1}` + recurrent_term = recurrent_term.permute(0, 2, 3, 1) + + recurrent_term = torch.cat( + [ + self.forward_operator( + expand_operator(image, sensitivity_map, dim=coil_dim), + dim=spatial_dims, + ) + for image in torch.split(recurrent_term, 2, complex_dim) + ], + dim=complex_dim, + ) + + new_kspace = current_kspace - self.learning_rate * kspace_error + recurrent_term + + return new_kspace, hidden_state diff --git a/direct/nn/recurrentvarnet/recurrentvarnet_engine.py b/direct/nn/recurrentvarnet/recurrentvarnet_engine.py new file mode 100644 index 00000000..61ce82d0 --- /dev/null +++ b/direct/nn/recurrentvarnet/recurrentvarnet_engine.py @@ -0,0 +1,476 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors + +import time +from collections import defaultdict +from os import PathLike +from typing import Callable, DefaultDict, Dict, List, Optional + +import numpy as np +import torch +from torch import nn +from torch.cuda.amp import autocast +from torch.nn import functional as F +from torch.utils.data import DataLoader + +import direct.data.transforms as T +from direct.config import BaseConfig +from direct.engine import DoIterationOutput, Engine +from direct.functionals import SSIMLoss +from direct.utils import ( + communication, + detach_dict, + dict_to_device, + merge_list_of_dicts, + multiply_function, + reduce_list_of_dicts, +) +from direct.utils.communication import reduce_tensor_dict + + +class RecurrentVarNetEngine(Engine): + """ + Recurrent Variational Network Engine. + """ + + def __init__( + self, + cfg: BaseConfig, + model: nn.Module, + device: int, + forward_operator: Optional[Callable] = None, + backward_operator: Optional[Callable] = None, + mixed_precision: bool = False, + **models: nn.Module, + ): + super().__init__( + cfg, + model, + device, + forward_operator=forward_operator, + backward_operator=backward_operator, + mixed_precision=mixed_precision, + **models, + ) + + self._complex_dim = -1 + self._coil_dim = 1 + self._spatial_dims = (2, 3) + + def _do_iteration( + self, + data: Dict[str, torch.Tensor], + loss_fns: Optional[Dict[str, Callable]] = None, + regularizer_fns: Optional[Dict[str, Callable]] = None, + ) -> DoIterationOutput: + + # loss_fns can be done, e.g. 
during validation + if loss_fns is None: + loss_fns = {} + + if regularizer_fns is None: + regularizer_fns = {} + + loss_dicts = [] + regularizer_dicts = [] + + data = dict_to_device(data, self.device) + + # sensitivity_map of shape (batch, coil, height, width, complex=2) + sensitivity_map = data["sensitivity_map"] + + if "sensitivity_model" in self.models: # SER Module + + # Move channels to first axis + sensitivity_map = data["sensitivity_map"].permute( + (0, 1, 4, 2, 3) + ) # shape (batch, coil, complex=2, height, width) + + sensitivity_map = self.compute_model_per_coil("sensitivity_model", sensitivity_map).permute( + (0, 1, 3, 4, 2) + ) # has channel last: shape (batch, coil, height, width, complex=2) + + # The sensitivity map needs to be normalized such that + # So \sum_{i \in \text{coils}} S_i S_i^* = 1 + + sensitivity_map_norm = torch.sqrt( + ((sensitivity_map ** 2).sum(self._complex_dim)).sum(self._coil_dim) + ) # shape (batch, height, width) + sensitivity_map_norm = sensitivity_map_norm.unsqueeze(1).unsqueeze(-1) + data["sensitivity_map"] = T.safe_divide(sensitivity_map, sensitivity_map_norm) + + with autocast(enabled=self.mixed_precision): + + output_kspace = self.model( + masked_kspace=data["masked_kspace"], + sampling_mask=data["sampling_mask"], + sensitivity_map=data["sensitivity_map"], + ) + + output_image = T.root_sum_of_squares( + self.backward_operator(output_kspace, dim=self._spatial_dims), + dim=self._coil_dim, + ) # shape (batch, height, width) + + loss_dict = {k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in loss_fns.keys()} + regularizer_dict = { + k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in regularizer_fns.keys() + } + + for key, value in loss_dict.items(): + loss_dict[key] = value + loss_fns[key]( + output_image, + **data, + reduction="mean", + ) + + for key, value in regularizer_dict.items(): + regularizer_dict[key] = value + regularizer_fns[key]( + output_image, + **data, + ) + + loss = sum(loss_dict.values()) + sum(regularizer_dict.values()) + + if self.model.training: + self._scaler.scale(loss).backward() + + loss_dicts.append(detach_dict(loss_dict)) + regularizer_dicts.append( + detach_dict(regularizer_dict) + ) # Need to detach dict as this is only used for logging. + + # Add the loss dicts. + loss_dict = reduce_list_of_dicts(loss_dicts, mode="sum") + regularizer_dict = reduce_list_of_dicts(regularizer_dicts, mode="sum") + + return DoIterationOutput( + output_image=output_image, + sensitivity_map=data["sensitivity_map"], + data_dict={**loss_dict, **regularizer_dict}, + ) + + def build_loss(self, **kwargs) -> Dict: + def get_resolution(**data): + """Be careful that this will use the cropping size of the FIRST sample in the batch.""" + return self.compute_resolution(self.cfg.training.loss.crop, data.get("reconstruction_size", None)) + + def l1_loss(source, reduction="mean", **data): + """ + Calculate L1 loss given source and target. + + Parameters: + ----------- + Source: shape (batch, complex=2, height, width) + Data: Contains key "target" with value a tensor of shape (batch, height, width) + + """ + resolution = get_resolution(**data) + l1_loss = F.l1_loss(*self.cropper(source, data["target"], resolution), reduction=reduction) + + return l1_loss + + def l2_loss(source, reduction="mean", **data): + """ + Calculate L2 loss (MSE) given source and target. 
+ + Parameters: + ----------- + Source: shape (batch, complex=2, height, width) + Data: Contains key "target" with value a tensor of shape (batch, height, width) + + """ + resolution = get_resolution(**data) + l2_loss = F.mse_loss(*self.cropper(source, data["target"], resolution), reduction=reduction) + + return l2_loss + + def ssim_loss(source, reduction="mean", **data): + """ + Calculate SSIM loss given source and target. + + Parameters: + ----------- + Source: shape (batch, complex=2, height, width) + Data: Contains key "target" with value a tensor of shape (batch, height, width) + + """ + resolution = get_resolution(**data) + if reduction != "mean": + raise AssertionError( + f"SSIM loss can only be computed with reduction == 'mean'." f" Got reduction == {reduction}." + ) + + source_abs, target_abs = self.cropper(source, data["target"], resolution) + data_range = torch.tensor([target_abs.max()], device=target_abs.device) + + ssim_loss = SSIMLoss().to(source_abs.device).forward(source_abs, target_abs, data_range=data_range) + + return ssim_loss + + # Build losses + loss_dict = {} + for curr_loss in self.cfg.training.loss.losses: # type: ignore + loss_fn = curr_loss.function + if loss_fn == "l1_loss": + loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, l1_loss) + elif loss_fn == "l2_loss": + loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, l2_loss) + elif loss_fn == "ssim_loss": + loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, ssim_loss) + else: + raise ValueError(f"{loss_fn} not permissible.") + + return loss_dict + + @torch.no_grad() + def evaluate( + self, + data_loader: DataLoader, + loss_fns: Optional[Dict[str, Callable]], + regularizer_fns: Optional[Dict[str, Callable]] = None, + crop: Optional[str] = None, + is_validation_process: bool = True, + ): + """ + Validation process. Assumes that each batch only contains slices of the same volume *AND* that these + are sequentially ordered. + + Parameters + ---------- + data_loader : DataLoader + loss_fns : Dict[str, Callable], optional + regularizer_fns : Dict[str, Callable], optional + crop : str, optional + is_validation_process : bool + + Returns + ------- + loss_dict, all_gathered_metrics, visualize_slices, visualize_target + + """ + self.models_to_device() + self.models_validation_mode() + torch.cuda.empty_cache() + + # Variables required for evaluation. + volume_metrics = self.build_metrics(self.cfg.validation.metrics) # type: ignore + + # filenames can be in the volume_indices attribute of the dataset + num_for_this_process = None + all_filenames = None + if hasattr(data_loader.dataset, "volume_indices"): + all_filenames = list(data_loader.dataset.volume_indices.keys()) + num_for_this_process = len(list(data_loader.batch_sampler.sampler.volume_indices.keys())) + self.logger.info( + f"Reconstructing a total of {len(all_filenames)} volumes. " + f"This process has {num_for_this_process} volumes (world size: {communication.get_world_size()})." + ) + + filenames_seen = 0 + reconstruction_output: DefaultDict = defaultdict(list) + if is_validation_process: + targets_output: DefaultDict = defaultdict(list) + val_losses = [] + val_volume_metrics: Dict[PathLike, Dict] = defaultdict(dict) + last_filename = None + + # Container to for the slices which can be visualized in TensorBoard. 
+ visualize_slices: List[np.ndarray] = [] + visualize_target: List[np.ndarray] = [] + # visualizations = {} + + extra_visualization_keys = ( + self.cfg.logging.log_as_image if self.cfg.logging.log_as_image else [] # type: ignore + ) + + # Loop over dataset. This requires the use of direct.data.sampler.DistributedSequentialSampler as this sampler + # splits the data over the different processes, and outputs the slices linearly. The implicit assumption here is + # that the slices are outputted from the Dataset *sequentially* for each volume one by one, and each batch only + # contains data from one volume. + time_start = time.time() + + for iter_idx, data in enumerate(data_loader): + filenames = data.pop("filename") + if len(set(filenames)) != 1: + raise ValueError( + f"Expected a batch during validation to only contain filenames of one case. " + f"Got {set(filenames)}." + ) + + slice_nos = data.pop("slice_no") + scaling_factors = data["scaling_factor"] + + resolution = self.compute_resolution( + key=self.cfg.validation.crop, # type: ignore + reconstruction_size=data.get("reconstruction_size", None), + ) + + # Compute output and loss. + iteration_output = self._do_iteration(data, loss_fns, regularizer_fns=regularizer_fns) + output = iteration_output.output_image + loss_dict = iteration_output.data_dict + + loss_dict = detach_dict(loss_dict) + output = output.detach() + val_losses.append(loss_dict) + + # Output is complex-valued, and has to be cropped. This holds for both output and target. + # Output has shape (batch, complex, height, width) + output_abs = self.process_output( + output, + scaling_factors, + resolution=resolution, + ) + + if is_validation_process: + # Target has shape (batch, height, width) + target_abs = self.process_output( + data["target"].detach(), + scaling_factors, + resolution=resolution, + ) + for key in extra_visualization_keys: + curr_data = data[key].detach() + # Here we need to discover which keys are actually normalized or not + # this requires a solution to issue #23: https://github.com/directgroup/direct/issues/23 + + del output # Explicitly call delete to clear memory. + + # Aggregate volumes to be able to compute the metrics on complete volumes. + for idx, filename in enumerate(filenames): + if last_filename is None: + last_filename = filename # First iteration last_filename is not set. + + curr_slice = output_abs[idx].detach() + slice_no = int(slice_nos[idx].numpy()) + + reconstruction_output[filename].append((slice_no, curr_slice.cpu())) + + if is_validation_process: + targets_output[filename].append((slice_no, target_abs[idx].cpu())) + + is_last_element_of_last_batch = iter_idx + 1 == len(data_loader) and idx + 1 == len(data["target"]) + reconstruction_conditions = [ + filename != last_filename, + is_last_element_of_last_batch, + ] + for condition in reconstruction_conditions: + if condition: + filenames_seen += 1 + + # Now we can ditch the reconstruction dict by reconstructing the volume, + # will take too much memory otherwise. 
+ volume = torch.stack([_[1] for _ in reconstruction_output[last_filename]]) + if is_validation_process: + target = torch.stack([_[1] for _ in targets_output[last_filename]]) + curr_metrics = { + metric_name: metric_fn(target, volume) + for metric_name, metric_fn in volume_metrics.items() + } + val_volume_metrics[last_filename] = curr_metrics + # Log the center slice of the volume + if len(visualize_slices) < self.cfg.logging.tensorboard.num_images: # type: ignore + visualize_slices.append(volume[volume.shape[0] // 2]) + visualize_target.append(target[target.shape[0] // 2]) + + # Delete outputs from memory, and recreate dictionary. + # This is not needed when not in validation as we are actually interested + # in the iteration output. + del targets_output[last_filename] + del reconstruction_output[last_filename] + + if all_filenames: + log_prefix = f"{filenames_seen} of {num_for_this_process} volumes reconstructed:" + else: + log_prefix = f"{iter_idx + 1} of {len(data_loader)} slices reconstructed:" + + self.logger.info( + f"{log_prefix} {last_filename}" + f" (shape = {list(volume.shape)}) in {time.time() - time_start:.3f}s." + ) + # restart timer + time_start = time.time() + last_filename = filename + + # Average loss dict + loss_dict = reduce_list_of_dicts(val_losses) + reduce_tensor_dict(loss_dict) + + communication.synchronize() + torch.cuda.empty_cache() + + all_gathered_metrics = merge_list_of_dicts(communication.all_gather(val_volume_metrics)) + if not is_validation_process: + return loss_dict, reconstruction_output + + return loss_dict, all_gathered_metrics, visualize_slices, visualize_target + + def process_output(self, data, scaling_factors=None, resolution=None): + # data is of shape (batch, complex=2, height, width) + if scaling_factors is not None: + data = data * scaling_factors.view(-1, *((1,) * (len(data.shape) - 1))).to(data.device) + + data = T.modulus_if_complex(data) + + if len(data.shape) == 3: # (batch, height, width) + data = data.unsqueeze(1) # Added channel dimension. + + if resolution is not None: + data = T.center_crop(data, resolution).contiguous() + + return data + + @staticmethod + def compute_resolution(key, reconstruction_size): + if key == "header": + # This will be of the form [tensor(x_0, x_1, ...), tensor(y_0, y_1,...), tensor(z_0, z_1, ...)] over + # batches. + resolution = [_.detach().cpu().numpy().tolist() for _ in reconstruction_size] + # The volume sampler should give validation indices belonging to the *same* volume, so it should be + # safe taking the first element, the matrix size are in x,y,z (we work in z,x,y). + resolution = [_[0] for _ in resolution][:-1] + elif key == "training": + resolution = key + elif not key: + resolution = None + else: + raise ValueError( + "Cropping should be either set to `header` to get the values from the header or " + "`training` to take the same value as training." + ) + return resolution + + def cropper(self, source, target, resolution): + """ + 2D source/target cropper + + Parameters: + ----------- + Source has shape (batch, height, width) + Target has shape (batch, height, width) + + """ + + if not resolution or all(_ == 0 for _ in resolution): + return source.unsqueeze(1), target.unsqueeze(1) # Added channel dimension. + + source_abs = T.center_crop(source, resolution).unsqueeze(1) # Added channel dimension. + target_abs = T.center_crop(target, resolution).unsqueeze(1) # Added channel dimension. 
+ + return source_abs, target_abs + + def compute_model_per_coil(self, model_name, data): + """ + Computes model per coil. + """ + # data is of shape (batch, coil, complex=2, height, width) + output = [] + + for idx in range(data.size(self._coil_dim)): + subselected_data = data.select(self._coil_dim, idx) + output.append(self.models[model_name](subselected_data)) + output = torch.stack(output, dim=self._coil_dim) + + # output is of shape (batch, coil, complex=2, height, width) + return output diff --git a/model_zoo.md b/model_zoo.md index f8682023..ca9c933a 100644 --- a/model_zoo.md +++ b/model_zoo.md @@ -22,21 +22,23 @@ Models were trained on the Calgary-Campinas brain dataset. Training included 47 #### Validation Set (12 coils, 20 Volumes) -| Model | Name | Acceleration | Checkpoint | SSIM | pSNR | VIF | -|--------------|--------------|--------------|-----------------------------------------------------------------------|-------|------|-------| -|LPDNet | lpd | 5x | [96000](https://s3.aiforoncology.nl/direct-project/lpdnet.zip) | 0.937 | 35.6 | 0.953 | -|LPDNet | lpd | 10x | [97000](https://s3.aiforoncology.nl/direct-project/lpdnet.zip) | 0.901 | 32.2 | 0.919 | -|RIM | rim | 5x | [89000](https://s3.aiforoncology.nl/direct-project/rim.zip) | 0.932 | 35.0 | 0.964 | -|RIM | rim | 10x | [63000](https://s3.aiforoncology.nl/direct-project/rim.zip) | 0.891 | 31.7 | 0.911 | -|VarNet | varnet | 5x | [4000](https://s3.aiforoncology.nl/direct-project/varnet.zip) | 0.917 | 33.3 | 0.937 | -|VarNet | varnet | 10x | [3000](https://s3.aiforoncology.nl/direct-project/varnet.zip) | 0.862 | 29.9 | 0.861 | -|Joint-ICNet | jointicnet | 5x | [43000](https://s3.aiforoncology.nl/direct-project/jointicnet.zip) | 0.904 | 32.0 | 0.940 | -|Joint-ICNet | jointicnet | 10x | [42500](https://s3.aiforoncology.nl/direct-project/jointicnet.zip) | 0.854 | 29.4 | 0.853 | -|XPDNet | xpdnet | 5x | [16000](https://s3.aiforoncology.nl/direct-project/xpdnet.zip) | 0.907 | 32.3 | 0.965 | -|XPDNet | xpdnet | 10x | [14000](https://s3.aiforoncology.nl/direct-project/xpdnet.zip) | 0.855 | 29.7 | 0.837 | -|KIKI-Net | kikinet | 5x | [44500](https://s3.aiforoncology.nl/direct-project/kikinet.zip) | 0.888 | 29.6 | 0.919 | -|KIKI-Net | kikinet | 10x | [44500](https://s3.aiforoncology.nl/direct-project/kikinet.zip) | 0.833 | 27.5 | 0.856 | -|MultiDomainNet|multidomainnet| 5x | [50000](https://s3.aiforoncology.nl/direct-project/multidomainnet.zip)| 0.864 | 28.7 | 0.912 | -|MultiDomainNet|multidomainnet| 10x | [50000](https://s3.aiforoncology.nl/direct-project/multidomainnet.zip)| 0.810 | 26.8 | 0.812 | -|U-Net | unet | 5x | [10000](https://s3.aiforoncology.nl/direct-project/unet.zip) | 0.871 | 29.5 | 0.895 | -|U-Net | unet | 10x | [6000](https://s3.aiforoncology.nl/direct-project/unet.zip) | 0.821 | 27.8 | 0.837 | +| Model | Name | Acceleration | Checkpoint | SSIM | pSNR | VIF | +|---------------|---------------|--------------|-----------------------------------------------------------------------|-------|------|-------| +|RecurrentVarNet|recurrentvarnet| 5x | [148500]() | 0.943 | 36.1 | 0.964 | +|RecurrentVarNet|recurrentvarnet| 10x | [107000]() | 0.911 | 33.0 | 0.926 | +|LPDNet | lpd | 5x | [96000](https://s3.aiforoncology.nl/direct-project/lpdnet.zip) | 0.937 | 35.6 | 0.953 | +|LPDNet | lpd | 10x | [97000](https://s3.aiforoncology.nl/direct-project/lpdnet.zip) | 0.901 | 32.2 | 0.919 | +|RIM | rim | 5x | [89000](https://s3.aiforoncology.nl/direct-project/rim.zip) | 0.932 | 35.0 | 0.964 | +|RIM | rim | 10x | 
[63000](https://s3.aiforoncology.nl/direct-project/rim.zip) | 0.891 | 31.7 | 0.911 |
+|VarNet | varnet | 5x | [4000](https://s3.aiforoncology.nl/direct-project/varnet.zip) | 0.917 | 33.3 | 0.937 |
+|VarNet | varnet | 10x | [3000](https://s3.aiforoncology.nl/direct-project/varnet.zip) | 0.862 | 29.9 | 0.861 |
+|Joint-ICNet | jointicnet | 5x | [43000](https://s3.aiforoncology.nl/direct-project/jointicnet.zip) | 0.904 | 32.0 | 0.940 |
+|Joint-ICNet | jointicnet | 10x | [42500](https://s3.aiforoncology.nl/direct-project/jointicnet.zip) | 0.854 | 29.4 | 0.853 |
+|XPDNet | xpdnet | 5x | [16000](https://s3.aiforoncology.nl/direct-project/xpdnet.zip) | 0.907 | 32.3 | 0.965 |
+|XPDNet | xpdnet | 10x | [14000](https://s3.aiforoncology.nl/direct-project/xpdnet.zip) | 0.855 | 29.7 | 0.837 |
+|KIKI-Net | kikinet | 5x | [44500](https://s3.aiforoncology.nl/direct-project/kikinet.zip) | 0.888 | 29.6 | 0.919 |
+|KIKI-Net | kikinet | 10x | [44500](https://s3.aiforoncology.nl/direct-project/kikinet.zip) | 0.833 | 27.5 | 0.856 |
+|MultiDomainNet |multidomainnet | 5x | [50000](https://s3.aiforoncology.nl/direct-project/multidomainnet.zip)| 0.864 | 28.7 | 0.912 |
+|MultiDomainNet |multidomainnet | 10x | [50000](https://s3.aiforoncology.nl/direct-project/multidomainnet.zip)| 0.810 | 26.8 | 0.812 |
+|U-Net | unet | 5x | [10000](https://s3.aiforoncology.nl/direct-project/unet.zip) | 0.871 | 29.5 | 0.895 |
+|U-Net | unet | 10x | [6000](https://s3.aiforoncology.nl/direct-project/unet.zip) | 0.821 | 27.8 | 0.837 |
diff --git a/projects/calgary_campinas/configs/base_recurrentvarnet.yaml b/projects/calgary_campinas/configs/base_recurrentvarnet.yaml
new file mode 100644
index 00000000..2f89901e
--- /dev/null
+++ b/projects/calgary_campinas/configs/base_recurrentvarnet.yaml
@@ -0,0 +1,139 @@
+# This model is a reproduction of the algorithm submitted in the Calgary-Campinas Multi-channel MR Reconstruction (MC-MRRec) Challenge
+# and is among the top-two solutions.
+
+physics:
+  forward_operator: fft2(centered=False)
+  backward_operator: ifft2(centered=False)
+training:
+  datasets:
+    # Two datasets, only difference is the shape, so the data can be collated for larger batches. R=5
+    - name: CalgaryCampinas
+      lists:
+        - ../lists/train/12x218x170_train.lst
+      transforms:
+        crop: null
+        estimate_sensitivity_maps: true  # Estimate the sensitivity map on the ACS
+        scaling_key: masked_kspace  # Compute the image normalization based on the masked_kspace maximum
+        image_center_crop: false
+        masking:
+          name: CalgaryCampinas
+          accelerations: [5]
+      crop_outer_slices: false
+    - name: CalgaryCampinas
+      lists:
+        - ../lists/train/12x218x180_train.lst
+      transforms:
+        crop: null
+        estimate_sensitivity_maps: true  # Estimate the sensitivity map on the ACS
+        scaling_key: masked_kspace  # Compute the image normalization based on the masked_kspace maximum
+        image_center_crop: false
+        masking:
+          name: CalgaryCampinas
+          accelerations: [5]
+      crop_outer_slices: false
+    # Two datasets, only difference is the shape, so the data can be collated for larger batches. R=10
+    - name: CalgaryCampinas
+      lists:
+        - ../lists/train/12x218x170_train.lst
+      transforms:
+        crop: null
+        estimate_sensitivity_maps: true  # Estimate the sensitivity map on the ACS
+        scaling_key: masked_kspace  # Compute the image normalization based on the masked_kspace maximum
+        image_center_crop: false
+        masking:
+          name: CalgaryCampinas
+          accelerations: [10]
+      crop_outer_slices: false
+    - name: CalgaryCampinas
+      lists:
+        - ../lists/train/12x218x180_train.lst
+      transforms:
+        crop: null
+        estimate_sensitivity_maps: true  # Estimate the sensitivity map on the ACS
+        scaling_key: masked_kspace  # Compute the image normalization based on the masked_kspace maximum
+        image_center_crop: false
+        masking:
+          name: CalgaryCampinas
+          accelerations: [10]
+      crop_outer_slices: false
+  batch_size: 2  # This is the batch size per GPU!
+  optimizer: Adam
+  lr: 0.0005
+  weight_decay: 0.0
+  lr_step_size: 50000
+  lr_gamma: 0.2
+  lr_warmup_iter: 1000
+  num_iterations: 1000000
+  gradient_steps: 1
+  gradient_clipping: 0.0
+  gradient_debug: false
+  checkpointer:
+    checkpoint_steps: 500
+  validation_steps: 500
+  loss:
+    crop: null
+    losses:
+      - function: l1_loss
+        multiplier: 1.0
+      - function: ssim_loss
+        multiplier: 1.0
+validation:
+  datasets:
+    # Twice the same dataset but a different acceleration factor
+    - name: CalgaryCampinas
+      transforms:
+        crop: null
+        estimate_sensitivity_maps: true
+        scaling_key: masked_kspace
+        masking:
+          name: CalgaryCampinas
+          accelerations: [5]
+      crop_outer_slices: true
+      text_description: 5x  # Description for logging
+    - name: CalgaryCampinas
+      transforms:
+        crop: null
+        estimate_sensitivity_maps: true
+        scaling_key: masked_kspace
+        masking:
+          name: CalgaryCampinas
+          accelerations: [10]
+      crop_outer_slices: true
+      text_description: 10x  # Description for logging
+  crop: null  # This sets the cropping for the DoIterationOutput
+  metrics:  # These are obtained from direct.functionals
+    - calgary_campinas_psnr
+    - calgary_campinas_ssim
+    - calgary_campinas_vif
+    - fastmri_nmse
+model:
+  model_name: recurrentvarnet.recurrentvarnet.RecurrentVarNet
+  num_steps: 12
+  recurrent_hidden_channels: 128
+  recurrent_num_layers: 4
+  initializer_initialization: sense
+  learned_initializer: true
+  initializer_channels: [32, 32, 64, 64]
+  initializer_dilations: [1, 1, 2, 4]
+  initializer_multiscale: 3
+additional_models:
+  sensitivity_model:
+    model_name: unet.unet_2d.UnetModel2d
+    in_channels: 2
+    out_channels: 2
+    num_filters: 8
+    num_pool_layers: 4
+    dropout_probability: 0.0
+logging:
+  tensorboard:
+    num_images: 4
+inference:
+  batch_size: 8
+  dataset:
+    name: CalgaryCampinas
+    crop_outer_slices: true
+    text_description: inference
+    transforms:
+      crop: null
+      estimate_sensitivity_maps: true
+      scaling_key: masked_kspace
diff --git a/tools/parse_metrics_log.py b/tools/parse_metrics_log.py
index 01e8963f..149f18ec 100644
--- a/tools/parse_metrics_log.py
+++ b/tools/parse_metrics_log.py
@@ -1,5 +1,6 @@
 # coding=utf-8
 # Copyright (c) DIRECT Contributors
+
 import argparse
 import json
 import pathlib
@@ -13,23 +14,27 @@ def parse_args():
         description="Find the best checkpoint for a given metric",
         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
     )
     parser.add_argument("metrics_path", type=pathlib.Path, help="Path to metrics.json")
     parser.add_argument("key", type=str, help="Key to use to find the best checkpoint.")
-
+    parser.add_argument(
+        "--max", dest="max", help="If True, this computes on maximum, else minimum value for key.", action="store_true"
+    )
     return parser.parse_args()
 
 
 def main():
     args = parse_args()
 
     with open(args.metrics_path / "metrics.json", "r") as f:
         data = f.readlines()
 
     data = [json.loads(_) for _ in data]
-    x = np.asarray([(int(_["iteration"]), -_[args.key]) for _ in data if args.key in _])
-    out = x[np.where(x[:, 1] == x[:, 1].max())][0]
-
+    x = np.asarray([(int(_["iteration"]), _[args.key]) for _ in data if args.key in _])
+    if args.max:
+        out = x[np.where(x[:, 1] == x[:, 1].max())][0]
+    else:
+        out = x[np.where(x[:, 1] == x[:, 1].min())][0]
     print(f"{args.key} - {int(out[0])}: {out[1]}")
 
 
 if __name__ == "__main__":
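
As a quick sanity check of the new module, the following minimal sketch instantiates `RecurrentVarNet` and runs a forward pass on random tensors with the shapes documented in the docstrings above. It assumes the `direct` package from this branch is importable and that `fft2`/`ifft2` in `direct.data.transforms` accept the `centered` and `dim` keywords used in the config and model code; all sizes and hyperparameters below are illustrative, not the trained settings from `base_recurrentvarnet.yaml`.

```python
import functools

import torch

from direct.data.transforms import fft2, ifft2
from direct.nn.recurrentvarnet.recurrentvarnet import RecurrentVarNet

# Non-centered FFTs, matching the physics section of base_recurrentvarnet.yaml.
forward_operator = functools.partial(fft2, centered=False)
backward_operator = functools.partial(ifft2, centered=False)

# Deliberately small model: 4 unrolled steps, 16 hidden channels, 2 ConvGRU layers,
# and a SENSE-initialized RSI module (all values chosen only to keep the test fast).
model = RecurrentVarNet(
    forward_operator=forward_operator,
    backward_operator=backward_operator,
    num_steps=4,
    recurrent_hidden_channels=16,
    recurrent_num_layers=2,
    learned_initializer=True,
    initializer_initialization="sense",
    initializer_channels=(8, 8),
    initializer_dilations=(1, 2),
)

# Shapes follow the docstrings: k-space and sensitivities are (batch, coil, height, width, complex=2),
# the sampling mask is (batch, 1, height, width, 1).
batch, coils, height, width = 1, 4, 32, 32
sampling_mask = (torch.rand(batch, 1, height, width, 1) > 0.5).float()
masked_kspace = torch.randn(batch, coils, height, width, 2) * sampling_mask
sensitivity_map = torch.randn(batch, coils, height, width, 2)

kspace_prediction = model(masked_kspace, sampling_mask, sensitivity_map)
print(kspace_prediction.shape)  # expected: torch.Size([1, 4, 32, 32, 2])
```

Each `RecurrentVarNetBlock` subtracts the data-consistency term `learning_rate * mask * (current_kspace - masked_kspace)` and adds the forward transform of the ConvGRU regularizer output, so the prediction keeps the k-space shape of the input; the engine then obtains the image by applying the backward operator followed by a root-sum-of-squares over coils.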