diff --git a/README.md b/README.md index d72aa4fe..e5e26cf4 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ [![black](https://github.com/directgroup/direct/actions/workflows/black.yml/badge.svg)](https://github.com/directgroup/direct/actions/workflows/black.yml) # DIRECT -DIRECT is the Deep Image REConstruction Toolkit that implements state-of-the-art inverse problem solvers. It includes +DIRECT is the Deep Image REConstruction Toolkit that implements state-of-the-art inverse problem solvers. It stores inverse problem solvers such as the Learned Primal Dual algorithm and Recurrent Inference Machine, which were part of the winning solution in Facebook & NYU's FastMRI challenge in 2019 and the Calgary-Campinas MRI reconstruction challenge at MIDL 2020.
@@ -29,9 +29,9 @@ If you use DIRECT in your own research, or want to refer to baseline results pub ```BibTeX @misc{DIRECTTOOLKIT, - author = {Jonas Teuwen, Nikita Moriakov, Dimitrios Karkalousos, Matthan Caan, George Yiasemis}, - title = {DIRECT}, + author = {Teuwen, Jonas and Yiasemis, George and Moriakov, Nikita and Karkalousos, Dimitrios and Caan, Matthan}, + title = {DIRECT: Deep Image REConstruction Toolkit}, howpublished = {\url{https://github.com/directgroup/direct}}, - year = {2020} + year = {2021} } ``` diff --git a/direct/data/h5_data.py b/direct/data/h5_data.py index e6489ee1..68172b4d 100644 --- a/direct/data/h5_data.py +++ b/direct/data/h5_data.py @@ -124,7 +124,7 @@ def parse_filenames_data(self, filenames, extra_h5s=None, filter_slice=None): if len(filenames) < 5 or idx % (len(filenames) // 5) == 0 or len(filenames) == (idx + 1): self.logger.info(f"Parsing: {(idx + 1) / len(filenames) * 100:.2f}%.") try: - kspace = h5py.File(filename, "r")["kspace"] + kspace = np.array(h5py.File(filename, "r")["kspace"]) self.verify_extra_h5_integrity(filename, kspace.shape, extra_h5s=extra_h5s) except OSError as exc: diff --git a/direct/data/tests/test_generator.py b/direct/data/tests/test_fake.py similarity index 100% rename from direct/data/tests/test_generator.py rename to direct/data/tests/test_fake.py diff --git a/direct/data/tests/test_transforms.py b/direct/data/tests/test_transforms.py index 3343cae5..bbf1c86b 100644 --- a/direct/data/tests/test_transforms.py +++ b/direct/data/tests/test_transforms.py @@ -107,6 +107,24 @@ def test_modulus(shape): assert np.allclose(out_torch, out_numpy) +@pytest.mark.parametrize( + "shape", + [ + [3, 3], + [4, 6], + [10, 8, 4], + [3, 4, 3, 5], + ], +) +@pytest.mark.parametrize("complex", [True, False]) +def test_modulus_if_complex(shape, complex): + if complex: + shape += [ + 2, + ] + test_modulus(shape) + + @pytest.mark.parametrize( "shape, dims", [ @@ -166,7 +184,6 @@ def test_center_crop(shape, target_shape): [[8, 4], [4, 4]], ], ) -# @pytest.mark.parametrize("named", [True, False]) def test_complex_center_crop(shape, target_shape): shape = shape + [2] input = create_input(shape) @@ -262,3 +279,53 @@ def test_ifftshift(shape): out_torch = transforms.ifftshift(torch_tensor).numpy() out_numpy = np.fft.ifftshift(data) assert np.allclose(out_torch, out_numpy) + + +@pytest.mark.parametrize( + "shape, dim", + [ + [[3, 4, 5], 0], + [[3, 3, 4, 5], 1], + [[3, 6, 4, 5], 0], + [[3, 3, 6, 4, 5], 1], + ], +) +def test_expand_operator(shape, dim): + shape = shape + [ + 2, + ] + data = create_input(shape) # noqa + shape = shape[:dim] + shape[dim + 1 :] + sens = create_input(shape) # noqa + + out_torch = tensor_to_complex_numpy(transforms.expand_operator(data, sens, dim)) + + input_numpy = np.expand_dims(tensor_to_complex_numpy(data), dim) + input_sens_numpy = tensor_to_complex_numpy(sens) + out_numpy = input_sens_numpy * input_numpy + + assert np.allclose(out_torch, out_numpy) + + +@pytest.mark.parametrize( + "shape, dim", + [ + [[3, 4, 5], 0], + [[3, 3, 4, 5], 1], + [[3, 6, 4, 5], 0], + [[3, 3, 6, 4, 5], 1], + ], +) +def test_reduce_operator(shape, dim): + shape = shape + [ + 2, + ] + coil_data = create_input(shape) # noqa + sens = create_input(shape) # noqa + out_torch = tensor_to_complex_numpy(transforms.reduce_operator(coil_data, sens, dim)) + + input_numpy = tensor_to_complex_numpy(coil_data) + input_sens_numpy = tensor_to_complex_numpy(sens) + out_numpy = (input_sens_numpy.conj() * input_numpy).sum(dim) + + assert np.allclose(out_torch, out_numpy) 
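The new `test_expand_operator` and `test_reduce_operator` cases above exercise the coil expand/reduce operators that this patch adds to `direct.data.transforms` further down. A minimal usage sketch of the pair follows; tensor sizes are made up, and the round-trip identity R(E(x)) = x only holds for sensitivity maps normalised so that \sum_i S_i^* S_i = 1:

```python
import torch

from direct.data.transforms import expand_operator, reduce_operator

# Hypothetical sizes: batch=1, coil=4, height=width=16, complex channel last.
image = torch.rand(1, 16, 16, 2)    # coil-combined image x
sens = torch.rand(1, 4, 16, 16, 2)  # coil sensitivity maps S_i, coil dim = 1

coil_images = expand_operator(image, sens, dim=1)     # E(x): (1, 4, 16, 16, 2)
combined = reduce_operator(coil_images, sens, dim=1)  # R(E(x)): (1, 16, 16, 2)

# With normalised sensitivity maps (\sum_i S_i^* S_i = 1) `combined` equals `image`;
# for the random maps above only the shapes agree.
assert combined.shape == image.shape
```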
diff --git a/direct/data/transforms.py b/direct/data/transforms.py index e0730e08..ad72d7e6 100644 --- a/direct/data/transforms.py +++ b/direct/data/transforms.py @@ -10,10 +10,9 @@ import numpy as np import torch import torch.fft -from packaging import version from direct.data.bbox import crop_to_bbox -from direct.utils import ensure_list, is_power_of_two +from direct.utils import ensure_list, is_power_of_two, is_complex_data from direct.utils.asserts import assert_complex, assert_same_shape @@ -222,7 +221,6 @@ def safe_divide(input_tensor: torch.Tensor, other_tensor: torch.Tensor) -> torch torch.tensor([0.0], dtype=input_tensor.dtype).to(input_tensor.device), input_tensor / other_tensor, ) - return data @@ -266,8 +264,7 @@ def align_as(input_tensor: torch.Tensor, other: torch.Tensor) -> torch.Tensor: input_shape = list(input_tensor.shape) other_shape = torch.tensor(other.shape, dtype=int) out_shape = torch.ones(len(other.shape), dtype=int) - # TODO(gy): Fix to ensure complex_last when [2,..., 2] or [..., N,..., N,...] in other.shape, - # "-input_shape.count(dim):" is a hack and might cause problems. + for dim in np.sort(np.unique(input_tensor.shape)): ind = torch.where(other_shape == dim)[0][-input_shape.count(dim) :] out_shape[ind] = dim @@ -292,7 +289,6 @@ def modulus(data: torch.Tensor) -> torch.Tensor: complex_axis = -1 if data.size(-1) == 2 else 1 return (data ** 2).sum(complex_axis).sqrt() # noqa - # return torch.view_as_complex(data).abs() def modulus_if_complex(data: torch.Tensor) -> torch.Tensor: @@ -307,11 +303,9 @@ def modulus_if_complex(data: torch.Tensor) -> torch.Tensor: ------- torch.Tensor """ - # TODO: This can be merged with modulus if the tensor is real. - try: + if is_complex_data(data, complex_last=False): return modulus(data) - except ValueError: - return data + return data def roll( @@ -436,7 +430,7 @@ def _complex_matrix_multiplication(input_tensor, other_tensor, mult_func): Parameters ---------- - x : torch.Tensor + input_tensor : torch.Tensor other_tensor : torch.Tensor mult_func : Callable Multiplication function e.g. torch.bmm or torch.mm @@ -466,6 +460,7 @@ def complex_mm(input_tensor, other_tensor): ---------- input_tensor : torch.Tensor other_tensor : torch.Tensor + Returns ------- torch.Tensor @@ -501,10 +496,6 @@ def conjugate(data: torch.Tensor) -> torch.Tensor: ------- torch.Tensor """ - # assert_complex(data, complex_last=True) - # data = torch.view_as_real( - # torch.view_as_complex(data).conj() - # ) assert_complex(data, complex_last=True) data = data.clone() # Clone is required as the data in the next line is changed in-place. data[..., 1] = data[..., 1] * -1.0 @@ -574,11 +565,12 @@ def tensor_to_complex_numpy(data: torch.Tensor) -> np.ndarray: return data[..., 0] + 1j * data[..., 1] -def root_sum_of_squares(data: torch.Tensor, dim: int = 0) -> torch.Tensor: - r""" +def root_sum_of_squares(data: torch.Tensor, dim: int = 0, complex_dim: int = -1) -> torch.Tensor: + """ Compute the root sum of squares (RSS) transform along a given dimension of the input tensor. - $$x_{\textrm{rss}} = \sqrt{\sum_{i \in \textrm{coil}} |x_i|^2}$$ + .. math:: + x_{\textrm{rss}} = \sqrt{\sum_{i \in \textrm{coil}} |x_i|^2} Parameters ---------- @@ -588,16 +580,16 @@ def root_sum_of_squares(data: torch.Tensor, dim: int = 0) -> torch.Tensor: dim : int Coil dimension. Default is 0 as the first dimension is always the coil dimension. + complex_dim : int + Complex channel dimension. Default is -1. If data not complex this is ignored. 
Returns ------- torch.Tensor : RSS of the input tensor. """ - try: - assert_complex(data, complex_last=True) - complex_index = -1 - return torch.sqrt((data ** 2).sum(complex_index).sum(dim)) - except ValueError: - return torch.sqrt((data ** 2).sum(dim)) + if is_complex_data(data): + return torch.sqrt((data ** 2).sum(complex_dim).sum(dim)) + + return torch.sqrt((data ** 2).sum(dim)) def center_crop(data: torch.Tensor, shape: Tuple[int, int]) -> torch.Tensor: @@ -760,5 +752,71 @@ def complex_random_crop( if len(output) == 1: return output[0] - return output + + +def reduce_operator( + coil_data: torch.Tensor, + sensitivity_map: torch.Tensor, + dim: int = 0, +) -> torch.Tensor: + """ + Given zero-filled reconstructions from multiple coils :math: \{x_i\}_{i=1}^{N_c} and coil sensitivity maps + :math: \{S_i\}_{i=1}^{N_c} it returns + .. math:: + R(x_1, .., x_{N_c}, S_1, .., S_{N_c}) = \sum_{i=1}^{N_c} {S_i}^{*} \times x_i. + + From paper End-to-End Variational Networks for Accelerated MRI Reconstruction. + + Parameters + ---------- + coil_data : torch.Tensor + Zero-filled reconstructions from coils. Should be a complex tensor (with complex dim of size 2). + sensitivity_map: torch.Tensor + Coil sensitivity maps. Should be complex tensor (with complex dim of size 2). + dim: int + Coil dimension. Default: 0. + + Returns + ------- + torch.Tensor: + Combined individual coil images. + """ + + assert_complex(coil_data, complex_last=True) + assert_complex(sensitivity_map, complex_last=True) + + return complex_multiplication(conjugate(sensitivity_map), coil_data).sum(dim) + + +def expand_operator( + data: torch.Tensor, + sensitivity_map: torch.Tensor, + dim: int = 0, +) -> torch.Tensor: + """ + Given a reconstructed image x and coil sensitivity maps :math: \{S_i\}_{i=1}^{N_c}, it returns + .. math:: + \Epsilon(x) = (S_1 \times x, .., S_{N_c} \times x) = (x_1, .., x_{N_c}). + + From paper End-to-End Variational Networks for Accelerated MRI Reconstruction. + + Parameters + ---------- + data : torch.Tensor + Image data. Should be a complex tensor (with complex dim of size 2). + sensitivity_map: torch.Tensor + Coil sensitivity maps. Should be complex tensor (with complex dim of size 2). + dim: int + Coil dimension. Default: 0. + + Returns + ------- + torch.Tensor: + Zero-filled reconstructions from each coil. 
+ """ + + assert_complex(data, complex_last=True) + assert_complex(sensitivity_map, complex_last=True) + + return complex_multiplication(sensitivity_map, data.unsqueeze(dim)) diff --git a/direct/engine.py b/direct/engine.py index 345c9a7f..533c5a5c 100644 --- a/direct/engine.py +++ b/direct/engine.py @@ -11,7 +11,7 @@ import warnings from abc import ABC, abstractmethod from collections import namedtuple -from typing import Callable, Dict, List, Optional, TypedDict, Union +from typing import Callable, Dict, List, Optional, Union import numpy as np import torch @@ -278,13 +278,6 @@ def training_loop( fail_counter = 0 for data, iter_idx in zip(data_loader, range(start_iter, total_iter)): - # 2D data is batched and contains keys: - # "filename_slice", "slice_no" - # "sampling_mask" of shape: (batch, 1, height, width, 1) - # "sensitivity_map" of shape: (batch, coil, height, width, complex=2) - # "target" of shape: (batch, height, width) - # "masked_kspace" of shape: (batch, coil, height, width, complex=2) - if iter_idx == 0: self.log_first_training_example_and_model(data) diff --git a/direct/environment.py b/direct/environment.py index efebe766..1d3bffd0 100644 --- a/direct/environment.py +++ b/direct/environment.py @@ -1,8 +1,6 @@ # coding=utf-8 # Copyright (c) DIRECT Contributors -# pylint: disable = E1101 - import argparse import logging import os @@ -17,6 +15,7 @@ from torch.utils import collect_env import direct.utils.logging +from direct.utils.logging import setup from direct.config.defaults import DefaultConfig, InferenceConfig, TrainingConfig, ValidationConfig from direct.utils import communication, count_parameters, str_to_class @@ -74,7 +73,7 @@ def setup_logging(machine_rank, output_directory, run_name, cfg_filename, cfg, d # Setup logging log_file = output_directory / f"log_{machine_rank}_{communication.get_local_rank()}.txt" - direct.utils.logging.setup( + setup( use_stdout=communication.is_main_process() or debug, filename=log_file, log_level=("INFO" if not debug else "DEBUG"), @@ -245,13 +244,13 @@ def setup_common_environment( dataset_cfg_from_file = extract_names(cfg_from_file[key].datasets) for idx, (dataset_name, dataset_config) in enumerate(dataset_cfg_from_file): cfg_from_file_new[key].datasets[idx] = dataset_config - cfg[key].datasets.append(load_dataset_config(dataset_name)) + cfg[key].datasets.append(load_dataset_config(dataset_name)) # pylint: disable = E1136 else: dataset_name, dataset_config = extract_names(cfg_from_file[key].dataset) cfg_from_file_new[key].dataset = dataset_config - cfg[key].dataset = load_dataset_config(dataset_name) + cfg[key].dataset = load_dataset_config(dataset_name) # pylint: disable = E1136 - cfg[key] = OmegaConf.merge(cfg[key], cfg_from_file_new[key]) + cfg[key] = OmegaConf.merge(cfg[key], cfg_from_file_new[key]) # pylint: disable = E1136, E1137 # sys.exit() # Make configuration read only. # TODO(jt): Does not work when indexing config lists. 
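With the switch from `try`/`except ValueError` to `is_complex_data`, `modulus_if_complex` and `root_sum_of_squares` in `transforms.py` now branch explicitly on whether the input carries a complex channel, and `root_sum_of_squares` gains a `complex_dim` argument. A short sketch of the intended behaviour, with illustrative sizes and assuming `is_complex_data` flags tensors whose last dimension has size 2:

```python
import torch

from direct.data.transforms import modulus_if_complex, root_sum_of_squares

# Hypothetical multi-coil image: (coil=8, height=64, width=64, complex=2).
coil_images = torch.rand(8, 64, 64, 2)

# Complex input: the complex channel (complex_dim, default -1) is reduced first,
# then the coil dimension, yielding a (64, 64) magnitude image.
rss = root_sum_of_squares(coil_images, dim=0)

# Real-valued input: the complex reduction is skipped.
rss_real = root_sum_of_squares(coil_images[..., 0], dim=0)  # also (64, 64)

# modulus_if_complex takes the modulus of complex tensors and passes real tensors through.
magnitudes = modulus_if_complex(coil_images)  # (8, 64, 64)
```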
diff --git a/direct/launch.py b/direct/launch.py index 17f32c42..eb28dc8a 100644 --- a/direct/launch.py +++ b/direct/launch.py @@ -139,10 +139,10 @@ def _distributed_worker( if communication._LOCAL_PROCESS_GROUP is not None: raise RuntimeError num_machines = world_size // num_gpus_per_machine - for i in range(num_machines): - ranks_on_i = list(range(i * num_gpus_per_machine, (i + 1) * num_gpus_per_machine)) + for idx in range(num_machines): + ranks_on_i = list(range(idx * num_gpus_per_machine, (idx + 1) * num_gpus_per_machine)) pg = dist.new_group(ranks_on_i) - if i == machine_rank: + if idx == machine_rank: communication._LOCAL_PROCESS_GROUP = pg main_func(*args) diff --git a/direct/nn/conv/__init__.py b/direct/nn/conv/__init__.py new file mode 100644 index 00000000..941752c9 --- /dev/null +++ b/direct/nn/conv/__init__.py @@ -0,0 +1,2 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors diff --git a/direct/nn/conv/conv.py b/direct/nn/conv/conv.py new file mode 100644 index 00000000..2da353e7 --- /dev/null +++ b/direct/nn/conv/conv.py @@ -0,0 +1,51 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors + +import torch.nn as nn + + +class Conv2d(nn.Module): + """ + Implementation of a simple cascade of 2D convolutions. If batchnorm is set to True, batch normalization + layer is applied after each convolution. + """ + + def __init__(self, in_channels, out_channels, hidden_channels, n_convs=3, activation=nn.PReLU(), batchnorm=False): + """ + + Parameters + ---------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels. + hidden_channels : int + Number of hidden channels. + n_convs : int + Number of convolutional layers. + activation : nn.Module + Activation function. + batchnorm : bool + If True a batch normalization layer is applied after every convolution. 
+ """ + super().__init__() + + self.conv = [] + for idx in range(n_convs): + self.conv.append( + nn.Conv2d( + in_channels if idx == 0 else hidden_channels, + hidden_channels if idx != n_convs - 1 else out_channels, + kernel_size=3, + padding=1, + ) + ) + if batchnorm: + self.conv.append(nn.BatchNorm2d(hidden_channels if idx != n_convs - 1 else out_channels, eps=1e-4)) + if idx != n_convs - 1: + self.conv.append(activation) + self.conv = nn.Sequential(*self.conv) + + def forward(self, x): + + return self.conv(x) diff --git a/direct/nn/conv/tests/__init__.py b/direct/nn/conv/tests/__init__.py new file mode 100644 index 00000000..941752c9 --- /dev/null +++ b/direct/nn/conv/tests/__init__.py @@ -0,0 +1,2 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors diff --git a/direct/nn/conv/tests/test_conv.py b/direct/nn/conv/tests/test_conv.py new file mode 100644 index 00000000..072c8581 --- /dev/null +++ b/direct/nn/conv/tests/test_conv.py @@ -0,0 +1,52 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors + +import pytest +import torch +import torch.nn as nn + +from direct.nn.conv.conv import Conv2d + + +def create_input(shape): + + data = torch.rand(shape).float() + + return data + + +@pytest.mark.parametrize( + "shape", + [ + [3, 2, 32, 32], + [3, 2, 16, 16], + ], +) +@pytest.mark.parametrize( + "out_channels", + [3, 5], +) +@pytest.mark.parametrize( + "hidden_channels", + [16, 8], +) +@pytest.mark.parametrize( + "n_convs", + [2, 4], +) +@pytest.mark.parametrize( + "act", + [nn.ReLU(), nn.PReLU()], +) +@pytest.mark.parametrize( + "batchnorm", + [True, False], +) +def test_conv(shape, out_channels, hidden_channels, n_convs, act, batchnorm): + model = Conv2d(shape[1], out_channels, hidden_channels, n_convs, act, batchnorm) + + data = create_input(shape).cpu() + + out = model(data) + + assert list(out.shape) == [shape[0]] + [out_channels] + shape[2:] diff --git a/direct/nn/crossdomain/__init__.py b/direct/nn/crossdomain/__init__.py new file mode 100644 index 00000000..941752c9 --- /dev/null +++ b/direct/nn/crossdomain/__init__.py @@ -0,0 +1,2 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors diff --git a/direct/nn/crossdomain/crossdomain.py b/direct/nn/crossdomain/crossdomain.py new file mode 100644 index 00000000..18129e5a --- /dev/null +++ b/direct/nn/crossdomain/crossdomain.py @@ -0,0 +1,194 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors + +from typing import Callable, Optional, Union + +import torch +import torch.nn as nn + +import direct.data.transforms as T + + +class CrossDomainNetwork(nn.Module): + """ + This performs optimisation in both, k-space ("K") and image ("I") domains according to domain_sequence. + """ + + def __init__( + self, + forward_operator: Callable, + backward_operator: Callable, + image_model_list: nn.Module, + kspace_model_list: Optional[Union[nn.Module, None]] = None, + domain_sequence: str = "KIKI", + image_buffer_size: int = 1, + kspace_buffer_size: int = 1, + normalize_image: bool = False, + **kwargs, + ): + """ + + Parameters + ---------- + forward_operator : Callable + Forward Operator. + backward_operator : Callable + Backward Operator. + image_model_list : nn.Module + Image domain model list. + kspace_model_list : Optional[nn.Module] + K-space domain model list. If set to None, a correction step is applied. Default: None. + domain_sequence : str + Domain sequence containing only "K" (k-space domain) and/or "I" (image domain). Default: "KIKI". + image_buffer_size : int + Image buffer size. Default: 1. 
+ kspace_buffer_size : int + K-space buffer size. Default: 1. + normalize_image : bool + If True, input is normalized. Default: False. + kwargs : dict + Keyword Arguments. + """ + super().__init__() + + self.forward_operator = forward_operator + self.backward_operator = backward_operator + + domain_sequence = [domain_name for domain_name in domain_sequence.strip()] + if not set(domain_sequence).issubset({"K", "I"}): + raise ValueError(f"Invalid domain sequence. Got {domain_sequence}. Should only contain 'K' and 'I'.") + + if kspace_model_list is not None: + if len(kspace_model_list) != domain_sequence.count("K"): + raise ValueError(f"K-space domain steps do not match k-space model list length.") + + if len(image_model_list) != domain_sequence.count("I"): + raise ValueError(f"Image domain steps do not match image model list length.") + + self.domain_sequence = domain_sequence + + self.kspace_model_list = kspace_model_list + self.kspace_buffer_size = kspace_buffer_size + + self.image_model_list = image_model_list + self.image_buffer_size = image_buffer_size + + self.normalize_image = normalize_image + + self._coil_dim = 1 + self._complex_dim = -1 + self._spatial_dims = (2, 3) + + def kspace_correction(self, block_idx, image_buffer, kspace_buffer, sampling_mask, sensitivity_map, masked_kspace): + + forward_buffer = [ + self._forward_operator( + image.clone(), + sampling_mask, + sensitivity_map, + ) + for image in torch.split(image_buffer, 2, self._complex_dim) + ] + + forward_buffer = torch.cat(forward_buffer, self._complex_dim) + kspace_buffer = torch.cat([kspace_buffer, forward_buffer, masked_kspace], self._complex_dim) + + if self.kspace_model_list is not None: + kspace_buffer = self.kspace_model_list[block_idx](kspace_buffer.permute(0, 1, 4, 2, 3)).permute( + 0, 1, 3, 4, 2 + ) + else: + kspace_buffer = kspace_buffer[..., :2] - kspace_buffer[..., 2:4] + + return kspace_buffer + + def image_correction(self, block_idx, image_buffer, kspace_buffer, sampling_mask, sensitivity_map): + backward_buffer = [ + self._backward_operator(kspace.clone(), sampling_mask, sensitivity_map) + for kspace in torch.split(kspace_buffer, 2, self._complex_dim) + ] + backward_buffer = torch.cat(backward_buffer, self._complex_dim) + + image_buffer = torch.cat([image_buffer, backward_buffer], self._complex_dim).permute(0, 3, 1, 2) + image_buffer = self.image_model_list[block_idx](image_buffer).permute(0, 2, 3, 1) + + return image_buffer + + def _forward_operator(self, image, sampling_mask, sensitivity_map): + forward = torch.where( + sampling_mask == 0, + torch.tensor([0.0], dtype=image.dtype).to(image.device), + self.forward_operator(T.expand_operator(image, sensitivity_map, self._coil_dim), dim=self._spatial_dims), + ) + return forward + + def _backward_operator(self, kspace, sampling_mask, sensitivity_map): + backward = T.reduce_operator( + self.backward_operator( + torch.where( + sampling_mask == 0, + torch.tensor([0.0], dtype=kspace.dtype).to(kspace.device), + kspace, + ), + self._spatial_dims, + ), + sensitivity_map, + self._coil_dim, + ) + return backward + + def forward( + self, + masked_kspace: torch.Tensor, + sampling_mask: torch.Tensor, + sensitivity_map: torch.Tensor, + scaling_factor: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + + Parameters + ---------- + masked_kspace : torch.Tensor + Masked k-space of shape (N, coil, height, width, complex=2). + sampling_mask : torch.Tensor + Sampling mask of shape (N, 1, height, width, 1). 
+ sensitivity_map : torch.Tensor + Sensitivity map of shape (N, coil, height, width, complex=2). + scaling_factor : Optional[torch.Tensor] + Scaling factor of shape (N,). If None, no scaling is applied. Default: None. + + Returns + ------- + out_image : torch.Tensor + Output image of shape (N, height, width, complex=2). + """ + input_image = self._backward_operator(masked_kspace, sampling_mask, sensitivity_map) + + if self.normalize_image and scaling_factor is not None: + input_image = input_image / scaling_factor ** 2 + masked_kspace = masked_kspace / scaling_factor ** 2 + + image_buffer = torch.cat([input_image] * self.image_buffer_size, self._complex_dim).to(masked_kspace.device) + + kspace_buffer = torch.cat([masked_kspace] * self.kspace_buffer_size, self._complex_dim).to( + masked_kspace.device + ) + + kspace_block_idx, image_block_idx = 0, 0 + for block_domain in self.domain_sequence: + if block_domain == "K": + kspace_buffer = self.kspace_correction( + kspace_block_idx, image_buffer, kspace_buffer, sampling_mask, sensitivity_map, masked_kspace + ) + kspace_block_idx += 1 + else: + image_buffer = self.image_correction( + image_block_idx, image_buffer, kspace_buffer, sampling_mask, sensitivity_map + ) + image_block_idx += 1 + + if self.normalize_image and scaling_factor is not None: + image_buffer = image_buffer * scaling_factor ** 2 + + out_image = image_buffer[..., :2] + return out_image diff --git a/direct/nn/crossdomain/multicoil.py b/direct/nn/crossdomain/multicoil.py new file mode 100644 index 00000000..f8e0fb37 --- /dev/null +++ b/direct/nn/crossdomain/multicoil.py @@ -0,0 +1,66 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors + +import torch +import torch.nn as nn + + +class MultiCoil(nn.Module): + """ + This makes the forward pass of multi-coil data of shape (N, N_coils, H, W, C) to a model. If coil_to_batch is set + to True, coil dimension is moved to the batch dimension. Otherwise, it passes to the model each coil-data + individually. + """ + + def __init__(self, model: nn.Module, coil_dim: int = 1, coil_to_batch: bool = False): + """ + + Parameters + ---------- + model : nn.Module + Any nn.Module that takes as input with 4D data (N, H, W, C). Typically a convolutional-like model. + coil_dim : int + Coil dimension. Default: 1. + coil_to_batch : bool + If True batch and coil dimensions are merged when forwarded by the model and unmerged when outputted. + Otherwise, input is forwarded to the model per coil. + """ + super().__init__() + + self.model = model + self.coil_to_batch = coil_to_batch + self._coil_dim = coil_dim + + def _compute_model_per_coil(self, data: torch.Tensor) -> torch.Tensor: + output = [] + + for idx in range(data.size(self._coil_dim)): + subselected_data = data.select(self._coil_dim, idx) + output.append(self.model(subselected_data)) + output = torch.stack(output, dim=self._coil_dim) + return output + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + + Parameters + ---------- + x : torch.Tensor + Multi-coil input of shape (N, coil, height, width, in_channels). + + Returns + ------- + out : torch.Tensor + Multi-coil output of shape (N, coil, height, width, out_channels). 
+ """ + if self.coil_to_batch: + x = x.clone() + batch, coil, height, width, channels = x.size() + + x = x.reshape(batch * coil, height, width, channels).permute(0, 3, 1, 2).contiguous() + x = self.model(x).permute(0, 2, 3, 1) + x = x.reshape(batch, coil, height, width, -1) + else: + x = self._compute_model_per_coil(x).contiguous() + + return x diff --git a/direct/nn/didn/__init__.py b/direct/nn/didn/__init__.py new file mode 100644 index 00000000..941752c9 --- /dev/null +++ b/direct/nn/didn/__init__.py @@ -0,0 +1,2 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors diff --git a/direct/nn/didn/didn.py b/direct/nn/didn/didn.py new file mode 100644 index 00000000..91d5947e --- /dev/null +++ b/direct/nn/didn/didn.py @@ -0,0 +1,263 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class Subpixel(nn.Module): + """ + Subpixel convolution layer for up-scaling of low resolution features at super-resolution as implemented + in https://ieeexplore.ieee.org/document/9025411. + """ + + def __init__(self, in_channels, out_channels, upscale_factor, kernel_size, padding=0): + super().__init__() + self.conv = nn.Conv2d( + in_channels, out_channels * upscale_factor ** 2, kernel_size=kernel_size, padding=padding + ) + self.pixelshuffle = nn.PixelShuffle(upscale_factor) + + def forward(self, x): + return self.pixelshuffle(self.conv(x)) + + +class ReconBlock(nn.Module): + """ + Reconstruction Block of DIDN model as implemented in https://ieeexplore.ieee.org/document/9025411. + """ + + def __init__(self, in_channels, num_convs): + super().__init__() + self.convs = nn.ModuleList( + [ + nn.Sequential( + *[ + nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=3, padding=1), + nn.PReLU(), + ] + ) + for _ in range(num_convs - 1) + ] + ) + self.convs.append(nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=3, padding=1)) + self.num_convs = num_convs + + def forward(self, input): + + output = input.clone() + for idx in range(self.num_convs): + output = self.convs[idx](output) + + return input + output + + +class DUB(nn.Module): + """ + Down-up block (DUB) for DIDN model as implemented in https://ieeexplore.ieee.org/document/9025411. 
+ """ + + def __init__( + self, + in_channels, + out_channels, + ): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + + # Scale 1 + self.conv1_1 = nn.Sequential(*[nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1), nn.PReLU()] * 2) + self.down1 = nn.Conv2d(in_channels, in_channels * 2, kernel_size=3, stride=2, padding=1) + # Scale 2 + self.conv2_1 = nn.Sequential( + *[nn.Conv2d(in_channels * 2, in_channels * 2, kernel_size=3, padding=1), nn.PReLU()] + ) + self.down2 = nn.Conv2d(in_channels * 2, in_channels * 4, kernel_size=3, stride=2, padding=1) + # Scale 3 + self.conv3_1 = nn.Sequential( + *[ + nn.Conv2d(in_channels * 4, in_channels * 4, kernel_size=3, padding=1), + nn.PReLU(), + ] + ) + self.up1 = nn.Sequential( + *[ + # nn.Conv2d(in_channels * 4, in_channels * 8, kernel_size=1), + Subpixel(in_channels * 4, in_channels * 2, 2, 1, 0) + ] + ) + # Scale 2 + self.conv_agg_1 = nn.Conv2d(in_channels * 4, in_channels * 2, kernel_size=1) + self.conv2_2 = nn.Sequential( + *[ + nn.Conv2d(in_channels * 2, in_channels * 2, kernel_size=3, padding=1), + nn.PReLU(), + ] + ) + self.up2 = nn.Sequential( + *[ + # nn.Conv2d(in_channels * 2, in_channels * 4, kernel_size=1), + Subpixel(in_channels * 2, in_channels, 2, 1, 0) + ] + ) + # Scale 1 + self.conv_agg_2 = nn.Conv2d(in_channels * 2, in_channels, kernel_size=1) + self.conv1_2 = nn.Sequential(*[nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1), nn.PReLU()] * 2) + self.conv_out = nn.Sequential(*[nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1), nn.PReLU()]) + + @staticmethod + def pad(x): + padding = [0, 0, 0, 0] + + if x.shape[-2] % 2 != 0: + padding[3] = 1 # Padding right - width + if x.shape[-1] % 2 != 0: + padding[1] = 1 # Padding bottom - height + if sum(padding) != 0: + x = F.pad(x, padding, "reflect") + return x + + @staticmethod + def crop_to_shape(x, shape): + h, w = x.shape[-2:] + + if h > shape[0]: + x = x[:, :, : shape[0], :] + if w > shape[1]: + x = x[:, :, :, : shape[1]] + return x + + def forward(self, x): + x1 = self.pad(x.clone()) + x1 = x1 + self.conv1_1(x1) + x2 = self.down1(x1) + x2 = x2 + self.conv2_1(x2) + out = self.down2(x2) + out = out + self.conv3_1(out) + out = self.up1(out) + out = torch.cat([x2, self.crop_to_shape(out, x2.shape[-2:])], dim=1) + out = self.conv_agg_1(out) + out = out + self.conv2_2(out) + out = self.up2(out) + out = torch.cat([x1, self.crop_to_shape(out, x1.shape[-2:])], dim=1) + out = self.conv_agg_2(out) + out = out + self.conv1_2(out) + out = x + self.crop_to_shape(self.conv_out(out), x.shape[-2:]) + return out + + +class DIDN(nn.Module): + """ + Deep Iterative Down-up convolutional Neural network (DIDN) implementation as in + https://ieeexplore.ieee.org/document/9025411. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + hidden_channels: int = 128, + num_dubs: int = 6, + num_convs_recon: int = 9, + skip_connection: bool = False, + ): + """ + + Parameters + ---------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels. + hidden_channels : int + Number of hidden channels. First convolution out_channels. Default: 128. + num_dubs : int + Number of DUB networks. Default: 6. + num_convs_recon : int + Number of ReconBlock convolutions. Default: 9. + skip_connection : bool + Use skip connection. Default: False. 
+ """ + super().__init__() + self.conv_in = nn.Sequential( + *[nn.Conv2d(in_channels=in_channels, out_channels=hidden_channels, kernel_size=3, padding=1), nn.PReLU()] + ) + self.down = nn.Conv2d( + in_channels=hidden_channels, + out_channels=hidden_channels, + kernel_size=3, + stride=2, + padding=1, + ) + self.dubs = nn.ModuleList( + [DUB(in_channels=hidden_channels, out_channels=hidden_channels) for _ in range(num_dubs)] + ) + self.recon_block = ReconBlock(in_channels=hidden_channels, num_convs=num_convs_recon) + self.recon_agg = nn.Conv2d(in_channels=hidden_channels * num_dubs, out_channels=hidden_channels, kernel_size=1) + self.conv = nn.Sequential( + *[ + nn.Conv2d( + in_channels=hidden_channels, + out_channels=hidden_channels, + kernel_size=3, + padding=1, + ), + nn.PReLU(), + ] + ) + self.up2 = Subpixel(hidden_channels, hidden_channels, 2, 1) + self.conv_out = nn.Conv2d( + in_channels=hidden_channels, + out_channels=out_channels, + kernel_size=3, + padding=1, + ) + self.num_dubs = num_dubs + self.skip_connection = (in_channels == out_channels) and skip_connection + + @staticmethod + def crop_to_shape(x, shape): + h, w = x.shape[-2:] + + if h > shape[0]: + x = x[:, :, : shape[0], :] + if w > shape[1]: + x = x[:, :, :, : shape[1]] + return x + + def forward(self, x, channel_dim=1): + """ + + Parameters + ---------- + x : torch.Tensor + Input tensor. + channel_dim : int + Channel dimension. Default: 1. + + Returns + ------- + out : torch.Tensor + Output tensor. + """ + out = self.conv_in(x) + out = self.down(out) + + dub_outs = [] + for dub in self.dubs: + out = dub(out) + dub_outs.append(out) + + out = [self.recon_block(dub_out) for dub_out in dub_outs] + out = self.recon_agg(torch.cat(out, dim=channel_dim)) + out = self.conv(out) + out = self.up2(out) + out = self.conv_out(out) + out = self.crop_to_shape(out, x.shape[-2:]) + + if self.skip_connection: + out = x + out + return out diff --git a/direct/nn/didn/tests/__init__.py b/direct/nn/didn/tests/__init__.py new file mode 100644 index 00000000..941752c9 --- /dev/null +++ b/direct/nn/didn/tests/__init__.py @@ -0,0 +1,2 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors diff --git a/direct/nn/didn/tests/test_didn.py b/direct/nn/didn/tests/test_didn.py new file mode 100644 index 00000000..80c374d7 --- /dev/null +++ b/direct/nn/didn/tests/test_didn.py @@ -0,0 +1,51 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors + +import pytest +import torch + +from direct.nn.didn.didn import DIDN + + +def create_input(shape): + + data = torch.rand(shape).float() + + return data + + +@pytest.mark.parametrize( + "shape", + [ + [3, 2, 32, 32], + [3, 2, 16, 16], + ], +) +@pytest.mark.parametrize( + "out_channels", + [3, 5], +) +@pytest.mark.parametrize( + "hidden_channels", + [16, 8], +) +@pytest.mark.parametrize( + "n_dubs", + [3, 4], +) +@pytest.mark.parametrize( + "num_convs_recon", + [3, 4], +) +@pytest.mark.parametrize( + "skip", + [True, False], +) +def test_didn(shape, out_channels, hidden_channels, n_dubs, num_convs_recon, skip): + model = DIDN(shape[1], out_channels, hidden_channels, n_dubs, num_convs_recon, skip) + + data = create_input(shape).cpu() + + out = model(data) + + assert list(out.shape) == [shape[0]] + [out_channels] + shape[2:] diff --git a/direct/nn/jointicnet/__init__.py b/direct/nn/jointicnet/__init__.py new file mode 100644 index 00000000..941752c9 --- /dev/null +++ b/direct/nn/jointicnet/__init__.py @@ -0,0 +1,2 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors diff --git a/direct/nn/jointicnet/config.py 
b/direct/nn/jointicnet/config.py new file mode 100644 index 00000000..ad9184f8 --- /dev/null +++ b/direct/nn/jointicnet/config.py @@ -0,0 +1,21 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors + +from dataclasses import dataclass + +from direct.config.defaults import ModelConfig + + +@dataclass +class JointICNetConfig(ModelConfig): + num_iter: int = 10 + use_norm_unet: bool = False + image_unet_num_filters: int = 8 + image_unet_num_pool_layers: int = 4 + image_unet_dropout: float = 0.0 + kspace_unet_num_filters: int = 8 + kspace_unet_num_pool_layers: int = 4 + kspace_unet_dropout: float = 0.0 + sens_unet_num_filters: int = 8 + sens_unet_num_pool_layers: int = 4 + sens_unet_dropout: float = 0.0 diff --git a/direct/nn/jointicnet/jointicnet.py b/direct/nn/jointicnet/jointicnet.py new file mode 100644 index 00000000..cd52f3bf --- /dev/null +++ b/direct/nn/jointicnet/jointicnet.py @@ -0,0 +1,214 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors + +from typing import Callable + +import torch +import torch.nn as nn + +import direct.data.transforms as T +from direct.nn.unet.unet_2d import UnetModel2d, NormUnetModel2d + + +class JointICNet(nn.Module): + """ + Joint-ICNet implementation as in "Joint Deep Model-based MR Image and Coil Sensitivity Reconstruction Network + (Joint-ICNet) for Fast MRI" submitted to the fastmri challenge. + + """ + + def __init__( + self, + forward_operator: Callable, + backward_operator: Callable, + num_iter: int = 10, + use_norm_unet: bool = False, + **kwargs, + ): + """ + + Parameters + ---------- + forward_operator : Callable + Forward Transform. + backward_operator : Callable + Backward Transform. + num_iter : int + Number of unrolled iterations. Default: 10. + use_norm_unet : bool + If True, a Normalized U-Net is used. Default: False. + kwargs: dict + Image, k-space and sensitivity-map U-Net models keyword-arguments. 
+ """ + super().__init__() + + self.forward_operator = forward_operator + self.backward_operator = backward_operator + self.num_iter = num_iter + + unet_architecture = NormUnetModel2d if use_norm_unet else UnetModel2d + + self.image_model = unet_architecture( + in_channels=2, + out_channels=2, + num_filters=kwargs.get("image_unet_num_filters", 8), + num_pool_layers=kwargs.get("image_unet_num_pool_layers", 4), + dropout_probability=kwargs.get("image_unet_dropout", 0.0), + ) + self.kspace_model = unet_architecture( + in_channels=2, + out_channels=2, + num_filters=kwargs.get("kspace_unet_num_filters", 8), + num_pool_layers=kwargs.get("kspace_unet_num_pool_layers", 4), + dropout_probability=kwargs.get("kspace_unet_dropout", 0.0), + ) + self.sens_model = unet_architecture( + in_channels=2, + out_channels=2, + num_filters=kwargs.get("sens_unet_num_filters", 8), + num_pool_layers=kwargs.get("sens_unet_num_pool_layers", 4), + dropout_probability=kwargs.get("sens_unet_dropout", 0.0), + ) + self.conv_out = nn.Conv2d(in_channels=2, out_channels=2, kernel_size=1) + + self.reg_param_I = nn.Parameter(torch.ones(num_iter)) + self.reg_param_F = nn.Parameter(torch.ones(num_iter)) + self.reg_param_C = nn.Parameter(torch.ones(num_iter)) + + self.lr_image = nn.Parameter(torch.ones(num_iter)) + self.lr_sens = nn.Parameter(torch.ones(num_iter)) + + self._coil_dim = 1 + self._complex_dim = -1 + self._spatial_dims = (2, 3) + + def _image_model(self, image): + image = image.permute(0, 3, 1, 2) + return self.image_model(image).permute(0, 2, 3, 1).contiguous() + + def _kspace_model(self, kspace): + kspace = kspace.permute(0, 3, 1, 2) + return self.kspace_model(kspace).permute(0, 2, 3, 1).contiguous() + + def _sens_model(self, sensitivity_map): + return ( + self._compute_model_per_coil(self.sens_model, sensitivity_map.permute(0, 1, 4, 2, 3)) + .permute(0, 1, 3, 4, 2) + .contiguous() + ) + + def _compute_model_per_coil(self, model, data): + output = [] + for idx in range(data.size(self._coil_dim)): + subselected_data = data.select(self._coil_dim, idx) + output.append(model(subselected_data)) + output = torch.stack(output, dim=self._coil_dim) + return output + + def _forward_operator(self, image, sampling_mask, sensitivity_map): + forward = torch.where( + sampling_mask == 0, + torch.tensor([0.0], dtype=image.dtype).to(image.device), + self.forward_operator(T.expand_operator(image, sensitivity_map, self._coil_dim), dim=self._spatial_dims), + ) + return forward + + def _backward_operator(self, kspace, sampling_mask, sensitivity_map): + backward = T.reduce_operator( + self.backward_operator( + torch.where( + sampling_mask == 0, + torch.tensor([0.0], dtype=kspace.dtype).to(kspace.device), + kspace, + ), + self._spatial_dims, + ), + sensitivity_map, + self._coil_dim, + ) + return backward + + def forward( + self, + masked_kspace: torch.Tensor, + sampling_mask: torch.Tensor, + sensitivity_map: torch.Tensor, + ) -> torch.Tensor: + """ + + Parameters + ---------- + masked_kspace : torch.Tensor + Masked k-space of shape (N, coil, height, width, complex=2). + sampling_mask : torch.Tensor + Sampling mask of shape (N, 1, height, width, 1). + sensitivity_map : torch.Tensor + Sensitivity map of shape (N, coil, height, width, complex=2). + + Returns + ------- + out_image : torch.Tensor + Output image of shape (N, height, width, complex=2). 
+ """ + + input_image = self._backward_operator(masked_kspace, sampling_mask, sensitivity_map) + input_image = input_image / T.modulus(input_image).unsqueeze(self._coil_dim).amax(dim=self._spatial_dims).view( + -1, 1, 1, 1 + ) + + for iter in range(self.num_iter): + step_sensitivity_map = ( + 2 + * self.lr_sens[iter] + * ( + T.complex_multiplication( + self.backward_operator( + torch.where( + sampling_mask == 0, + torch.tensor([0.0], dtype=masked_kspace.dtype).to(masked_kspace.device), + self._forward_operator(input_image, sampling_mask, sensitivity_map) - masked_kspace, + ), + self._spatial_dims, + ), + T.conjugate(input_image).unsqueeze(self._coil_dim), + ) + + self.reg_param_C[iter] + * ( + sensitivity_map + - self._sens_model(self.backward_operator(masked_kspace, dim=self._spatial_dims)) + ) + ) + ) + sensitivity_map = sensitivity_map - step_sensitivity_map + sensitivity_map_norm = torch.sqrt(((sensitivity_map ** 2).sum(self._complex_dim)).sum(self._coil_dim)) + sensitivity_map_norm = sensitivity_map_norm.unsqueeze(self._complex_dim).unsqueeze(self._coil_dim) + sensitivity_map = T.safe_divide(sensitivity_map, sensitivity_map_norm) + input_kspace = self.forward_operator(input_image, dim=tuple([d - 1 for d in self._spatial_dims])) + + step_image = ( + 2 + * self.lr_image[iter] + * ( + self._backward_operator( + self._forward_operator(input_image, sampling_mask, sensitivity_map) - masked_kspace, + sampling_mask, + sensitivity_map, + ) + + self.reg_param_I[iter] * (input_image - self._image_model(input_image)) + + self.reg_param_F[iter] + * ( + input_image + - self.backward_operator( + self._kspace_model(input_kspace), dim=tuple([d - 1 for d in self._spatial_dims]) + ) + ) + ) + ) + + input_image = input_image - step_image + input_image = input_image / T.modulus(input_image).unsqueeze(self._coil_dim).amax( + dim=self._spatial_dims + ).view(-1, 1, 1, 1) + + out_image = self.conv_out(input_image.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) + return out_image diff --git a/direct/nn/jointicnet/jointicnet_engine.py b/direct/nn/jointicnet/jointicnet_engine.py new file mode 100644 index 00000000..ae939983 --- /dev/null +++ b/direct/nn/jointicnet/jointicnet_engine.py @@ -0,0 +1,445 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors + +import time +from collections import defaultdict +from os import PathLike +from typing import Callable, DefaultDict, Dict, List, Optional + +import numpy as np +import torch +from torch import nn +from torch.cuda.amp import autocast +from torch.nn import functional as F +from torch.utils.data import DataLoader + +import direct.data.transforms as T +from direct.config import BaseConfig +from direct.engine import DoIterationOutput, Engine +from direct.functionals import SSIMLoss +from direct.utils import ( + communication, + detach_dict, + dict_to_device, + merge_list_of_dicts, + multiply_function, + reduce_list_of_dicts, +) +from direct.utils.communication import reduce_tensor_dict + + +class JointICNetEngine(Engine): + """ + Joint-ICNet Engine. 
+ """ + + def __init__( + self, + cfg: BaseConfig, + model: nn.Module, + device: int, + forward_operator: Optional[Callable] = None, + backward_operator: Optional[Callable] = None, + mixed_precision: bool = False, + **models: nn.Module, + ): + super().__init__( + cfg, + model, + device, + forward_operator=forward_operator, + backward_operator=backward_operator, + mixed_precision=mixed_precision, + **models, + ) + + self._coil_dim = 1 + self._complex_dim = -1 + self._spatial_dims = (2, 3) + + def _do_iteration( + self, + data: Dict[str, torch.Tensor], + loss_fns: Optional[Dict[str, Callable]] = None, + regularizer_fns: Optional[Dict[str, Callable]] = None, + ) -> DoIterationOutput: + + # loss_fns can be done, e.g. during validation + if loss_fns is None: + loss_fns = {} + + if regularizer_fns is None: + regularizer_fns = {} + + loss_dicts = [] + regularizer_dicts = [] + + data = dict_to_device(data, self.device) + + # sensitivity_map of shape (batch, coil, height, width, complex=2) + sensitivity_map = data["sensitivity_map"] + + # The sensitivity map needs to be normalized such that + # So \sum_{i \in \text{coils}} S_i S_i^* = 1 + + sensitivity_map_norm = torch.sqrt( + ((sensitivity_map ** 2).sum(self._complex_dim)).sum(self._coil_dim) + ) # shape (batch, height, width) + sensitivity_map_norm = sensitivity_map_norm.unsqueeze(1).unsqueeze(-1) + data["sensitivity_map"] = T.safe_divide(sensitivity_map, sensitivity_map_norm) + + with autocast(enabled=self.mixed_precision): + + output_image = self.model( + masked_kspace=data["masked_kspace"], + sampling_mask=data["sampling_mask"], + sensitivity_map=data["sensitivity_map"], + ) # shape (batch, height, width, complex=2) + + output_image = T.modulus(output_image) # shape (batch, height, width) + + loss_dict = {k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in loss_fns.keys()} + regularizer_dict = { + k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in regularizer_fns.keys() + } + + for key, value in loss_dict.items(): + loss_dict[key] = value + loss_fns[key]( + output_image, + **data, + reduction="mean", + ) + + for key, value in regularizer_dict.items(): + regularizer_dict[key] = value + regularizer_fns[key]( + output_image, + **data, + ) + + loss = sum(loss_dict.values()) + sum(regularizer_dict.values()) + + if self.model.training: + self._scaler.scale(loss).backward() + + loss_dicts.append(detach_dict(loss_dict)) + regularizer_dicts.append( + detach_dict(regularizer_dict) + ) # Need to detach dict as this is only used for logging. + + # Add the loss dicts. + loss_dict = reduce_list_of_dicts(loss_dicts, mode="sum") + regularizer_dict = reduce_list_of_dicts(regularizer_dicts, mode="sum") + + return DoIterationOutput( + output_image=output_image, + sensitivity_map=data["sensitivity_map"], + data_dict={**loss_dict, **regularizer_dict}, + ) + + def build_loss(self, **kwargs) -> Dict: + # TODO: Cropper is a processing output tool. + def get_resolution(**data): + """Be careful that this will use the cropping size of the FIRST sample in the batch.""" + return self.compute_resolution(self.cfg.training.loss.crop, data.get("reconstruction_size", None)) + + def l1_loss(source, reduction="mean", **data): + """ + Calculate L1 loss given source and target. 
+ + Parameters: + ----------- + Source: shape (batch, complex=2, height, width) + Data: Contains key "target" with value a tensor of shape (batch, height, width) + + """ + resolution = get_resolution(**data) + l1_loss = F.l1_loss(*self.cropper(source, data["target"], resolution), reduction=reduction) + + return l1_loss + + def l2_loss(source, reduction="mean", **data): + """ + Calculate L2 loss (MSE) given source and target. + + Parameters: + ----------- + Source: shape (batch, complex=2, height, width) + Data: Contains key "target" with value a tensor of shape (batch, height, width) + + """ + resolution = get_resolution(**data) + l2_loss = F.mse_loss(*self.cropper(source, data["target"], resolution), reduction=reduction) + + return l2_loss + + def ssim_loss(source, reduction="mean", **data): + """ + Calculate SSIM loss given source and target. + + Parameters: + ----------- + Source: shape (batch, complex=2, height, width) + Data: Contains key "target" with value a tensor of shape (batch, height, width) + + """ + resolution = get_resolution(**data) + if reduction != "mean": + raise AssertionError( + f"SSIM loss can only be computed with reduction == 'mean'." f" Got reduction == {reduction}." + ) + + source_abs, target_abs = self.cropper(source, data["target"], resolution) + data_range = torch.tensor([target_abs.max()], device=target_abs.device) + + ssim_loss = SSIMLoss().to(source_abs.device).forward(source_abs, target_abs, data_range=data_range) + + return ssim_loss + + # Build losses + loss_dict = {} + for curr_loss in self.cfg.training.loss.losses: # type: ignore + loss_fn = curr_loss.function + if loss_fn == "l1_loss": + loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, l1_loss) + elif loss_fn == "l2_loss": + loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, l2_loss) + elif loss_fn == "ssim_loss": + loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, ssim_loss) + else: + raise ValueError(f"{loss_fn} not permissible.") + + return loss_dict + + @torch.no_grad() + def evaluate( + self, + data_loader: DataLoader, + loss_fns: Optional[Dict[str, Callable]], + regularizer_fns: Optional[Dict[str, Callable]] = None, + crop: Optional[str] = None, + is_validation_process: bool = True, + ): + """ + Validation process. Assumes that each batch only contains slices of the same volume *AND* that these + are sequentially ordered. + + Parameters + ---------- + data_loader : DataLoader + loss_fns : Dict[str, Callable], optional + regularizer_fns : Dict[str, Callable], optional + crop : str, optional + is_validation_process : bool + + Returns + ------- + loss_dict, all_gathered_metrics, visualize_slices, visualize_target + + """ + self.models_to_device() + self.models_validation_mode() + torch.cuda.empty_cache() + + # Variables required for evaluation. + volume_metrics = self.build_metrics(self.cfg.validation.metrics) # type: ignore + + # filenames can be in the volume_indices attribute of the dataset + num_for_this_process = None + all_filenames = None + if hasattr(data_loader.dataset, "volume_indices"): + all_filenames = list(data_loader.dataset.volume_indices.keys()) + num_for_this_process = len(list(data_loader.batch_sampler.sampler.volume_indices.keys())) + self.logger.info( + f"Reconstructing a total of {len(all_filenames)} volumes. " + f"This process has {num_for_this_process} volumes (world size: {communication.get_world_size()})." 
+ ) + + filenames_seen = 0 + reconstruction_output: DefaultDict = defaultdict(list) + if is_validation_process: + targets_output: DefaultDict = defaultdict(list) + val_losses = [] + val_volume_metrics: Dict[PathLike, Dict] = defaultdict(dict) + last_filename = None + + # Container to for the slices which can be visualized in TensorBoard. + visualize_slices: List[np.ndarray] = [] + visualize_target: List[np.ndarray] = [] + # visualizations = {} + + extra_visualization_keys = ( + self.cfg.logging.log_as_image if self.cfg.logging.log_as_image else [] # type: ignore + ) + + # Loop over dataset. This requires the use of direct.data.sampler.DistributedSequentialSampler as this sampler + # splits the data over the different processes, and outputs the slices linearly. The implicit assumption here is + # that the slices are outputted from the Dataset *sequentially* for each volume one by one, and each batch only + # contains data from one volume. + time_start = time.time() + + for iter_idx, data in enumerate(data_loader): + filenames = data.pop("filename") + if len(set(filenames)) != 1: + raise ValueError( + f"Expected a batch during validation to only contain filenames of one case. " + f"Got {set(filenames)}." + ) + + slice_nos = data.pop("slice_no") + scaling_factors = data["scaling_factor"] + + resolution = self.compute_resolution( + key=self.cfg.validation.crop, # type: ignore + reconstruction_size=data.get("reconstruction_size", None), + ) + + # Compute output and loss. + iteration_output = self._do_iteration(data, loss_fns, regularizer_fns=regularizer_fns) + output = iteration_output.output_image + loss_dict = iteration_output.data_dict + + loss_dict = detach_dict(loss_dict) + output = output.detach() + val_losses.append(loss_dict) + + # Output is complex-valued, and has to be cropped. This holds for both output and target. + # Output has shape (batch, complex, height, width) + output_abs = self.process_output( + output, + scaling_factors, + resolution=resolution, + ) + + if is_validation_process: + # Target has shape (batch, height, width) + target_abs = self.process_output( + data["target"].detach(), + scaling_factors, + resolution=resolution, + ) + for key in extra_visualization_keys: + curr_data = data[key].detach() + # Here we need to discover which keys are actually normalized or not + # this requires a solution to issue #23: https://github.com/directgroup/direct/issues/23 + + del output # Explicitly call delete to clear memory. + + # Aggregate volumes to be able to compute the metrics on complete volumes. + for idx, filename in enumerate(filenames): + if last_filename is None: + last_filename = filename # First iteration last_filename is not set. + + curr_slice = output_abs[idx].detach() + slice_no = int(slice_nos[idx].numpy()) + + reconstruction_output[filename].append((slice_no, curr_slice.cpu())) + + if is_validation_process: + targets_output[filename].append((slice_no, target_abs[idx].cpu())) + + is_last_element_of_last_batch = iter_idx + 1 == len(data_loader) and idx + 1 == len(data["target"]) + reconstruction_conditions = [filename != last_filename, is_last_element_of_last_batch] + for condition in reconstruction_conditions: + if condition: + filenames_seen += 1 + + # Now we can ditch the reconstruction dict by reconstructing the volume, + # will take too much memory otherwise. 
+ volume = torch.stack([_[1] for _ in reconstruction_output[last_filename]]) + if is_validation_process: + target = torch.stack([_[1] for _ in targets_output[last_filename]]) + curr_metrics = { + metric_name: metric_fn(target, volume) + for metric_name, metric_fn in volume_metrics.items() + } + val_volume_metrics[last_filename] = curr_metrics + # Log the center slice of the volume + if len(visualize_slices) < self.cfg.logging.tensorboard.num_images: # type: ignore + visualize_slices.append(volume[volume.shape[0] // 2]) + visualize_target.append(target[target.shape[0] // 2]) + + # Delete outputs from memory, and recreate dictionary. + # This is not needed when not in validation as we are actually interested + # in the iteration output. + del targets_output[last_filename] + del reconstruction_output[last_filename] + + if all_filenames: + log_prefix = f"{filenames_seen} of {num_for_this_process} volumes reconstructed:" + else: + log_prefix = f"{iter_idx + 1} of {len(data_loader)} slices reconstructed:" + + self.logger.info( + f"{log_prefix} {last_filename}" + f" (shape = {list(volume.shape)}) in {time.time() - time_start:.3f}s." + ) + # restart timer + time_start = time.time() + last_filename = filename + + # Average loss dict + loss_dict = reduce_list_of_dicts(val_losses) + reduce_tensor_dict(loss_dict) + + communication.synchronize() + torch.cuda.empty_cache() + + all_gathered_metrics = merge_list_of_dicts(communication.all_gather(val_volume_metrics)) + if not is_validation_process: + return loss_dict, reconstruction_output + + return loss_dict, all_gathered_metrics, visualize_slices, visualize_target + + def process_output(self, data, scaling_factors=None, resolution=None): + # data is of shape (batch, complex=2, height, width) + if scaling_factors is not None: + data = data * scaling_factors.view(-1, *((1,) * (len(data.shape) - 1))).to(data.device) + + data = T.modulus_if_complex(data) + + if len(data.shape) == 3: # (batch, height, width) + data = data.unsqueeze(1) # Added channel dimension. + + if resolution is not None: + data = T.center_crop(data, resolution).contiguous() + + return data + + @staticmethod + def compute_resolution(key, reconstruction_size): + if key == "header": + # This will be of the form [tensor(x_0, x_1, ...), tensor(y_0, y_1,...), tensor(z_0, z_1, ...)] over + # batches. + resolution = [_.detach().cpu().numpy().tolist() for _ in reconstruction_size] + # The volume sampler should give validation indices belonging to the *same* volume, so it should be + # safe taking the first element, the matrix size are in x,y,z (we work in z,x,y). + resolution = [_[0] for _ in resolution][:-1] + elif key == "training": + resolution = key + elif not key: + resolution = None + else: + raise ValueError( + "Cropping should be either set to `header` to get the values from the header or " + "`training` to take the same value as training." + ) + return resolution + + def cropper(self, source, target, resolution): + """ + 2D source/target cropper + + Parameters: + ----------- + Source has shape (batch, height, width) + Target has shape (batch, height, width) + + """ + + if not resolution or all(_ == 0 for _ in resolution): + return source.unsqueeze(1), target.unsqueeze(1) # Added channel dimension. + + source_abs = T.center_crop(source, resolution).unsqueeze(1) # Added channel dimension. + target_abs = T.center_crop(target, resolution).unsqueeze(1) # Added channel dimension. 
+ + return source_abs, target_abs diff --git a/direct/nn/jointicnet/tests/__init__.py b/direct/nn/jointicnet/tests/__init__.py new file mode 100644 index 00000000..941752c9 --- /dev/null +++ b/direct/nn/jointicnet/tests/__init__.py @@ -0,0 +1,2 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors diff --git a/direct/nn/jointicnet/tests/test_jointicnet.py b/direct/nn/jointicnet/tests/test_jointicnet.py new file mode 100644 index 00000000..f68fb3cf --- /dev/null +++ b/direct/nn/jointicnet/tests/test_jointicnet.py @@ -0,0 +1,50 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors + +import pytest +import torch + +from direct.data.transforms import fft2, ifft2 +from direct.nn.jointicnet.jointicnet import JointICNet + + +def create_input(shape): + + data = torch.rand(shape).float() + + return data + + +@pytest.mark.parametrize( + "shape", + [ + [3, 3, 16, 16], + [2, 5, 16, 32], + ], +) +@pytest.mark.parametrize( + "num_iter", + [2, 4], +) +@pytest.mark.parametrize( + "use_norm_unet", + [True, False], +) +def test_jointicnet(shape, num_iter, use_norm_unet): + model = JointICNet( + fft2, + ifft2, + num_iter, + use_norm_unet, + image_unet_num_pool_layers=2, + kspace_unet_num_pool_layers=2, + sens_unet_num_pool_layers=2, + ).cpu() + + kspace = create_input(shape + [2]).cpu() + mask = create_input([shape[0]] + [1] + shape[2:] + [1]).round().int().cpu() + sens = create_input(shape + [2]).cpu() + + out = model(kspace, mask, sens) + + assert list(out.shape) == [shape[0]] + shape[2:] + [2] diff --git a/direct/nn/kikinet/__init__.py b/direct/nn/kikinet/__init__.py new file mode 100644 index 00000000..941752c9 --- /dev/null +++ b/direct/nn/kikinet/__init__.py @@ -0,0 +1,2 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors diff --git a/direct/nn/kikinet/config.py b/direct/nn/kikinet/config.py new file mode 100644 index 00000000..4f54fb67 --- /dev/null +++ b/direct/nn/kikinet/config.py @@ -0,0 +1,30 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors + +from dataclasses import dataclass + +from direct.config.defaults import ModelConfig + + +@dataclass +class KIKINetConfig(ModelConfig): + num_iter: int = 10 + image_model_architecture: str = "MWCNN" + kspace_model_architecture: str = "UNET" + image_mwcnn_hidden_channels: int = 16 + image_mwcnn_num_scales: int = 4 + image_mwcnn_bias: bool = True + image_mwcnn_batchnorm: bool = False + image_unet_num_filters: int = 8 + image_unet_num_pool_layers: int = 4 + image_unet_dropout_probability: float = 0.0 + kspace_conv_hidden_channels: int = 16 + kspace_conv_n_convs: int = 4 + kspace_conv_batchnorm: bool = False + kspace_didn_hidden_channels: int = 64 + kspace_didn_num_dubs: int = 6 + kspace_didn_num_convs_recon: int = 9 + kspace_unet_num_filters: int = 8 + kspace_unet_num_pool_layers: int = 4 + kspace_unet_dropout_probability: float = 0.0 + normalize: bool = False diff --git a/direct/nn/kikinet/kikinet.py b/direct/nn/kikinet/kikinet.py new file mode 100644 index 00000000..712d74ce --- /dev/null +++ b/direct/nn/kikinet/kikinet.py @@ -0,0 +1,181 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors + +from typing import Callable, Optional + +import torch +import torch.nn as nn + +import direct.data.transforms as T +from direct.nn.conv.conv import Conv2d +from direct.nn.didn.didn import DIDN +from direct.nn.mwcnn.mwcnn import MWCNN +from direct.nn.crossdomain.multicoil import MultiCoil +from direct.nn.unet.unet_2d import UnetModel2d, NormUnetModel2d + + +class KIKINet(nn.Module): + """ + Based on KIKINet implementation as in "KIKI-net: cross-domain 
convolutional neural networks for
+    reconstructing undersampled magnetic resonance images" by Taejoon Eo et al. Modified to work with
+    multicoil k-space data.
+    """
+
+    def __init__(
+        self,
+        forward_operator: Callable,
+        backward_operator: Callable,
+        image_model_architecture: str = "MWCNN",
+        kspace_model_architecture: str = "DIDN",
+        num_iter: int = 2,
+        normalize: bool = False,
+        **kwargs,
+    ):
+        """
+
+        Parameters
+        ----------
+        forward_operator : Callable
+            Forward Operator.
+        backward_operator : Callable
+            Backward Operator.
+        image_model_architecture : str
+            Image model architecture. Currently only implemented for MWCNN and (NORM)UNET. Default: 'MWCNN'.
+        kspace_model_architecture : str
+            K-space model architecture. Currently only implemented for CONV, DIDN and (NORM)UNET. Default: 'DIDN'.
+        num_iter : int
+            Number of unrolled iterations.
+        normalize : bool
+            If True, the input is normalized based on the input scaling_factor.
+        kwargs : dict
+            Keyword arguments for the model architectures.
+        """
+        super().__init__()
+
+        if image_model_architecture == "MWCNN":
+            image_model = MWCNN(
+                input_channels=2,
+                first_conv_hidden_channels=kwargs.get("image_mwcnn_hidden_channels", 32),
+                num_scales=kwargs.get("image_mwcnn_num_scales", 4),
+                bias=kwargs.get("image_mwcnn_bias", False),
+                batchnorm=kwargs.get("image_mwcnn_batchnorm", False),
+            )
+        elif image_model_architecture in ["UNET", "NORMUNET"]:
+            unet = UnetModel2d if image_model_architecture == "UNET" else NormUnetModel2d
+            image_model = unet(
+                in_channels=2,
+                out_channels=2,
+                num_filters=kwargs.get("image_unet_num_filters", 8),
+                num_pool_layers=kwargs.get("image_unet_num_pool_layers", 4),
+                dropout_probability=kwargs.get("image_unet_dropout_probability", 0.0),
+            )
+        else:
+            raise NotImplementedError(
+                f"KIKINet is currently implemented only with image_model_architecture == 'MWCNN', 'UNET' or 'NORMUNET'. "
+                f"Got {image_model_architecture}."
+            )
+
+        if kspace_model_architecture == "CONV":
+            kspace_model = Conv2d(
+                in_channels=2,
+                out_channels=2,
+                hidden_channels=kwargs.get("kspace_conv_hidden_channels", 16),
+                n_convs=kwargs.get("kspace_conv_n_convs", 4),
+                batchnorm=kwargs.get("kspace_conv_batchnorm", False),
+            )
+        elif kspace_model_architecture == "DIDN":
+            kspace_model = DIDN(
+                in_channels=2,
+                out_channels=2,
+                hidden_channels=kwargs.get("kspace_didn_hidden_channels", 16),
+                num_dubs=kwargs.get("kspace_didn_num_dubs", 6),
+                num_convs_recon=kwargs.get("kspace_didn_num_convs_recon", 9),
+            )
+        elif kspace_model_architecture in ["UNET", "NORMUNET"]:
+            unet = UnetModel2d if kspace_model_architecture == "UNET" else NormUnetModel2d
+            kspace_model = unet(
+                in_channels=2,
+                out_channels=2,
+                num_filters=kwargs.get("kspace_unet_num_filters", 8),
+                num_pool_layers=kwargs.get("kspace_unet_num_pool_layers", 4),
+                dropout_probability=kwargs.get("kspace_unet_dropout_probability", 0.0),
+            )
+        else:
+            raise NotImplementedError(
+                f"KIKINet is currently implemented for kspace_model_architecture == 'CONV', 'DIDN',"
+                f" 'UNET' or 'NORMUNET'. Got kspace_model_architecture == {kspace_model_architecture}."
+ ) + + self._coil_dim = 1 + self._complex_dim = -1 + self._spatial_dims = (2, 3) + + self.image_model_list = nn.ModuleList([image_model] * num_iter) + self.kspace_model_list = nn.ModuleList([MultiCoil(kspace_model, self._coil_dim)] * num_iter) + + self.forward_operator = forward_operator + self.backward_operator = backward_operator + self.num_iter = num_iter + self.normalize = normalize + + def forward( + self, + masked_kspace: torch.Tensor, + sampling_mask: torch.Tensor, + sensitivity_map: torch.Tensor, + scaling_factor: Optional[torch.Tensor] = None, + ): + """ + + Parameters + ---------- + masked_kspace : torch.Tensor + Masked k-space of shape (N, coil, height, width, complex=2). + sampling_mask : torch.Tensor + Sampling mask of shape (N, 1, height, width, 1). + sensitivity_map : torch.Tensor + Sensitivity map of shape (N, coil, height, width, complex=2). + scaling_factor : Optional[torch.Tensor] + Scaling factor of shape (N,). If None, no scaling is applied. Default: None. + + Returns + ------- + out_image : torch.Tensor + Output image of shape (N, height, width, complex=2). + """ + + kspace = masked_kspace.clone() + if self.normalize and scaling_factor is not None: + kspace = kspace / (scaling_factor ** 2).view(-1, 1, 1, 1, 1) + + for idx in range(self.num_iter): + kspace = self.kspace_model_list[idx](kspace.permute(0, 1, 4, 2, 3)).permute(0, 1, 3, 4, 2) + + image = T.reduce_operator( + self.backward_operator( + torch.where( + sampling_mask == 0, + torch.tensor([0.0], dtype=kspace.dtype).to(kspace.device), + kspace, + ), + self._spatial_dims, + ), + sensitivity_map, + self._coil_dim, + ) + + image = self.image_model_list[idx](image.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) + + if idx < self.num_iter - 1: + kspace = torch.where( + sampling_mask == 0, + torch.tensor([0.0], dtype=image.dtype).to(image.device), + self.forward_operator( + T.expand_operator(image, sensitivity_map, self._coil_dim), dim=self._spatial_dims + ), + ) + + if self.normalize and scaling_factor is not None: + image = image * (scaling_factor ** 2).view(-1, 1, 1, 1) + + return image diff --git a/direct/nn/kikinet/kikinet_engine.py b/direct/nn/kikinet/kikinet_engine.py new file mode 100644 index 00000000..976524f4 --- /dev/null +++ b/direct/nn/kikinet/kikinet_engine.py @@ -0,0 +1,470 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors + +import time +from collections import defaultdict +from os import PathLike +from typing import Callable, DefaultDict, Dict, List, Optional + +import numpy as np +import torch +from torch import nn +from torch.cuda.amp import autocast +from torch.nn import functional as F +from torch.utils.data import DataLoader + +import direct.data.transforms as T +from direct.config import BaseConfig +from direct.engine import DoIterationOutput, Engine +from direct.functionals import SSIMLoss +from direct.utils import ( + communication, + detach_dict, + dict_to_device, + merge_list_of_dicts, + multiply_function, + reduce_list_of_dicts, +) +from direct.utils.communication import reduce_tensor_dict + + +class KIKINetEngine(Engine): + """ + XPDNet Engine. 
+ """ + + def __init__( + self, + cfg: BaseConfig, + model: nn.Module, + device: int, + forward_operator: Optional[Callable] = None, + backward_operator: Optional[Callable] = None, + mixed_precision: bool = False, + **models: nn.Module, + ): + super().__init__( + cfg, + model, + device, + forward_operator=forward_operator, + backward_operator=backward_operator, + mixed_precision=mixed_precision, + **models, + ) + + self._complex_dim = -1 + self._coil_dim = 1 + self._spatial_dims = (2, 3) + + def _do_iteration( + self, + data: Dict[str, torch.Tensor], + loss_fns: Optional[Dict[str, Callable]] = None, + regularizer_fns: Optional[Dict[str, Callable]] = None, + ) -> DoIterationOutput: + + # loss_fns can be done, e.g. during validation + if loss_fns is None: + loss_fns = {} + + if regularizer_fns is None: + regularizer_fns = {} + + loss_dicts = [] + regularizer_dicts = [] + + data = dict_to_device(data, self.device) + + # sensitivity_map of shape (batch, coil, height, width, complex=2) + sensitivity_map = data["sensitivity_map"] + + if "sensitivity_model" in self.models: + + # Move channels to first axis + sensitivity_map = data["sensitivity_map"].permute((0, 1, 4, 2, 3)) + + sensitivity_map = self.compute_model_per_coil("sensitivity_model", sensitivity_map).permute( + (0, 1, 3, 4, 2) + ) + + # The sensitivity map needs to be normalized such that + # So \sum_{i \in \text{coils}} S_i S_i^* = 1 + + sensitivity_map_norm = torch.sqrt( + ((sensitivity_map ** 2).sum(self._complex_dim)).sum(self._coil_dim) + ) # shape (batch, height, width) + sensitivity_map_norm = sensitivity_map_norm.unsqueeze(1).unsqueeze(-1) + data["sensitivity_map"] = T.safe_divide(sensitivity_map, sensitivity_map_norm) + + with autocast(enabled=self.mixed_precision): + + output_image = self.model( + masked_kspace=data["masked_kspace"], + sampling_mask=data["sampling_mask"], + sensitivity_map=data["sensitivity_map"], + scaling_factor=data["scaling_factor"], + ) # shape (batch, height, width, complex=2) + + output_image = T.modulus(output_image) # shape (batch, height, width) + + loss_dict = {k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in loss_fns.keys()} + regularizer_dict = { + k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in regularizer_fns.keys() + } + + for key, value in loss_dict.items(): + loss_dict[key] = value + loss_fns[key]( + output_image, + **data, + reduction="mean", + ) + + for key, value in regularizer_dict.items(): + regularizer_dict[key] = value + regularizer_fns[key]( + output_image, + **data, + ) + + loss = sum(loss_dict.values()) + sum(regularizer_dict.values()) + + if self.model.training: + self._scaler.scale(loss).backward() + + loss_dicts.append(detach_dict(loss_dict)) + regularizer_dicts.append( + detach_dict(regularizer_dict) + ) # Need to detach dict as this is only used for logging. + + # Add the loss dicts. + loss_dict = reduce_list_of_dicts(loss_dicts, mode="sum") + regularizer_dict = reduce_list_of_dicts(regularizer_dicts, mode="sum") + + return DoIterationOutput( + output_image=output_image, + sensitivity_map=data["sensitivity_map"], + data_dict={**loss_dict, **regularizer_dict}, + ) + + def build_loss(self, **kwargs) -> Dict: + # TODO: Cropper is a processing output tool. 
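For reference, a minimal smoke test of the KIKINet module added above, mirroring the shapes exercised by the new `test_kikinet.py` (random inputs, default MWCNN image model and DIDN k-space model):

```python
import torch

from direct.data.transforms import fft2, ifft2
from direct.nn.kikinet.kikinet import KIKINet

# Shapes follow test_kikinet.py: (batch, coil, height, width) plus a trailing complex dim of 2.
batch, coil, height, width = 3, 3, 32, 32

model = KIKINet(fft2, ifft2, num_iter=2)

masked_kspace = torch.rand(batch, coil, height, width, 2)
sampling_mask = torch.rand(batch, 1, height, width, 1).round().int()
sensitivity_map = torch.rand(batch, coil, height, width, 2)

out = model(masked_kspace, sampling_mask, sensitivity_map)
assert list(out.shape) == [batch, height, width, 2]
```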
+ def get_resolution(**data): + """Be careful that this will use the cropping size of the FIRST sample in the batch.""" + return self.compute_resolution(self.cfg.training.loss.crop, data.get("reconstruction_size", None)) + + def l1_loss(source, reduction="mean", **data): + """ + Calculate L1 loss given source and target. + + Parameters: + ----------- + Source: shape (batch, complex=2, height, width) + Data: Contains key "target" with value a tensor of shape (batch, height, width) + + """ + resolution = get_resolution(**data) + l1_loss = F.l1_loss(*self.cropper(source, data["target"], resolution), reduction=reduction) + + return l1_loss + + def l2_loss(source, reduction="mean", **data): + """ + Calculate L2 loss (MSE) given source and target. + + Parameters: + ----------- + Source: shape (batch, complex=2, height, width) + Data: Contains key "target" with value a tensor of shape (batch, height, width) + + """ + resolution = get_resolution(**data) + l2_loss = F.mse_loss(*self.cropper(source, data["target"], resolution), reduction=reduction) + + return l2_loss + + def ssim_loss(source, reduction="mean", **data): + """ + Calculate SSIM loss given source and target. + + Parameters: + ----------- + Source: shape (batch, complex=2, height, width) + Data: Contains key "target" with value a tensor of shape (batch, height, width) + + """ + resolution = get_resolution(**data) + if reduction != "mean": + raise AssertionError( + f"SSIM loss can only be computed with reduction == 'mean'." f" Got reduction == {reduction}." + ) + + source_abs, target_abs = self.cropper(source, data["target"], resolution) + data_range = torch.tensor([target_abs.max()], device=target_abs.device) + + ssim_loss = SSIMLoss().to(source_abs.device).forward(source_abs, target_abs, data_range=data_range) + + return ssim_loss + + # Build losses + loss_dict = {} + for curr_loss in self.cfg.training.loss.losses: # type: ignore + loss_fn = curr_loss.function + if loss_fn == "l1_loss": + loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, l1_loss) + elif loss_fn == "l2_loss": + loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, l2_loss) + elif loss_fn == "ssim_loss": + loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, ssim_loss) + else: + raise ValueError(f"{loss_fn} not permissible.") + + return loss_dict + + @torch.no_grad() + def evaluate( + self, + data_loader: DataLoader, + loss_fns: Optional[Dict[str, Callable]], + regularizer_fns: Optional[Dict[str, Callable]] = None, + crop: Optional[str] = None, + is_validation_process: bool = True, + ): + """ + Validation process. Assumes that each batch only contains slices of the same volume *AND* that these + are sequentially ordered. + + Parameters + ---------- + data_loader : DataLoader + loss_fns : Dict[str, Callable], optional + regularizer_fns : Dict[str, Callable], optional + crop : str, optional + is_validation_process : bool + + Returns + ------- + loss_dict, all_gathered_metrics, visualize_slices, visualize_target + + """ + self.models_to_device() + self.models_validation_mode() + torch.cuda.empty_cache() + + # Variables required for evaluation. 
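The `ssim_loss` closure defined in `build_loss` above boils down to the following call pattern; a sketch with made-up magnitude images of shape (batch, 1, height, width), as produced by `cropper`:

```python
import torch

from direct.functionals import SSIMLoss

# Hypothetical center-cropped magnitude images.
source_abs = torch.rand(4, 1, 320, 320)
target_abs = torch.rand(4, 1, 320, 320)

# data_range is taken from the target, exactly as in the ssim_loss closure above.
data_range = torch.tensor([target_abs.max()], device=target_abs.device)
loss = SSIMLoss().to(source_abs.device)(source_abs, target_abs, data_range=data_range)
```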
+ volume_metrics = self.build_metrics(self.cfg.validation.metrics) # type: ignore + + # filenames can be in the volume_indices attribute of the dataset + num_for_this_process = None + all_filenames = None + if hasattr(data_loader.dataset, "volume_indices"): + all_filenames = list(data_loader.dataset.volume_indices.keys()) + num_for_this_process = len(list(data_loader.batch_sampler.sampler.volume_indices.keys())) + self.logger.info( + f"Reconstructing a total of {len(all_filenames)} volumes. " + f"This process has {num_for_this_process} volumes (world size: {communication.get_world_size()})." + ) + + filenames_seen = 0 + reconstruction_output: DefaultDict = defaultdict(list) + if is_validation_process: + targets_output: DefaultDict = defaultdict(list) + val_losses = [] + val_volume_metrics: Dict[PathLike, Dict] = defaultdict(dict) + last_filename = None + + # Container to for the slices which can be visualized in TensorBoard. + visualize_slices: List[np.ndarray] = [] + visualize_target: List[np.ndarray] = [] + # visualizations = {} + + extra_visualization_keys = ( + self.cfg.logging.log_as_image if self.cfg.logging.log_as_image else [] # type: ignore + ) + + # Loop over dataset. This requires the use of direct.data.sampler.DistributedSequentialSampler as this sampler + # splits the data over the different processes, and outputs the slices linearly. The implicit assumption here is + # that the slices are outputted from the Dataset *sequentially* for each volume one by one, and each batch only + # contains data from one volume. + time_start = time.time() + + for iter_idx, data in enumerate(data_loader): + filenames = data.pop("filename") + if len(set(filenames)) != 1: + raise ValueError( + f"Expected a batch during validation to only contain filenames of one case. " + f"Got {set(filenames)}." + ) + + slice_nos = data.pop("slice_no") + scaling_factors = data["scaling_factor"] + + resolution = self.compute_resolution( + key=self.cfg.validation.crop, # type: ignore + reconstruction_size=data.get("reconstruction_size", None), + ) + + # Compute output and loss. + iteration_output = self._do_iteration(data, loss_fns, regularizer_fns=regularizer_fns) + output = iteration_output.output_image + loss_dict = iteration_output.data_dict + + loss_dict = detach_dict(loss_dict) + output = output.detach() + val_losses.append(loss_dict) + + # Output is complex-valued, and has to be cropped. This holds for both output and target. + # Output has shape (batch, complex, height, width) + output_abs = self.process_output( + output, + scaling_factors, + resolution=resolution, + ) + + if is_validation_process: + # Target has shape (batch, height, width) + target_abs = self.process_output( + data["target"].detach(), + scaling_factors, + resolution=resolution, + ) + for key in extra_visualization_keys: + curr_data = data[key].detach() + # Here we need to discover which keys are actually normalized or not + # this requires a solution to issue #23: https://github.com/directgroup/direct/issues/23 + + del output # Explicitly call delete to clear memory. + + # Aggregate volumes to be able to compute the metrics on complete volumes. + for idx, filename in enumerate(filenames): + if last_filename is None: + last_filename = filename # First iteration last_filename is not set. 
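The per-volume aggregation in the loop that follows relies on slices arriving sequentially, one volume at a time; stripped of the engine context it amounts to this toy sketch (filename and shapes invented):

```python
from collections import defaultdict

import torch

reconstruction_output = defaultdict(list)

# Slices are appended as (slice_no, slice) tuples while a volume is being streamed in order ...
for slice_no in range(5):
    reconstruction_output["volume_0.h5"].append((slice_no, torch.rand(1, 320, 320)))

# ... and stacked into a (slices, 1, height, width) volume once the filename changes.
volume = torch.stack([slice_data for _, slice_data in reconstruction_output["volume_0.h5"]])
assert volume.shape == (5, 1, 320, 320)
```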
+ + curr_slice = output_abs[idx].detach() + slice_no = int(slice_nos[idx].numpy()) + + reconstruction_output[filename].append((slice_no, curr_slice.cpu())) + + if is_validation_process: + targets_output[filename].append((slice_no, target_abs[idx].cpu())) + + is_last_element_of_last_batch = iter_idx + 1 == len(data_loader) and idx + 1 == len(data["target"]) + reconstruction_conditions = [filename != last_filename, is_last_element_of_last_batch] + for condition in reconstruction_conditions: + if condition: + filenames_seen += 1 + + # Now we can ditch the reconstruction dict by reconstructing the volume, + # will take too much memory otherwise. + volume = torch.stack([_[1] for _ in reconstruction_output[last_filename]]) + if is_validation_process: + target = torch.stack([_[1] for _ in targets_output[last_filename]]) + curr_metrics = { + metric_name: metric_fn(target, volume) + for metric_name, metric_fn in volume_metrics.items() + } + val_volume_metrics[last_filename] = curr_metrics + # Log the center slice of the volume + if len(visualize_slices) < self.cfg.logging.tensorboard.num_images: # type: ignore + visualize_slices.append(volume[volume.shape[0] // 2]) + visualize_target.append(target[target.shape[0] // 2]) + + # Delete outputs from memory, and recreate dictionary. + # This is not needed when not in validation as we are actually interested + # in the iteration output. + del targets_output[last_filename] + del reconstruction_output[last_filename] + + if all_filenames: + log_prefix = f"{filenames_seen} of {num_for_this_process} volumes reconstructed:" + else: + log_prefix = f"{iter_idx + 1} of {len(data_loader)} slices reconstructed:" + + self.logger.info( + f"{log_prefix} {last_filename}" + f" (shape = {list(volume.shape)}) in {time.time() - time_start:.3f}s." + ) + # restart timer + time_start = time.time() + last_filename = filename + + # Average loss dict + loss_dict = reduce_list_of_dicts(val_losses) + reduce_tensor_dict(loss_dict) + + communication.synchronize() + torch.cuda.empty_cache() + + all_gathered_metrics = merge_list_of_dicts(communication.all_gather(val_volume_metrics)) + if not is_validation_process: + return loss_dict, reconstruction_output + + return loss_dict, all_gathered_metrics, visualize_slices, visualize_target + + def process_output(self, data, scaling_factors=None, resolution=None): + # data is of shape (batch, complex=2, height, width) + if scaling_factors is not None: + data = data * scaling_factors.view(-1, *((1,) * (len(data.shape) - 1))).to(data.device) + + data = T.modulus_if_complex(data) + + if len(data.shape) == 3: # (batch, height, width) + data = data.unsqueeze(1) # Added channel dimension. + + if resolution is not None: + data = T.center_crop(data, resolution).contiguous() + + return data + + @staticmethod + def compute_resolution(key, reconstruction_size): + if key == "header": + # This will be of the form [tensor(x_0, x_1, ...), tensor(y_0, y_1,...), tensor(z_0, z_1, ...)] over + # batches. + resolution = [_.detach().cpu().numpy().tolist() for _ in reconstruction_size] + # The volume sampler should give validation indices belonging to the *same* volume, so it should be + # safe taking the first element, the matrix size are in x,y,z (we work in z,x,y). + resolution = [_[0] for _ in resolution][:-1] + elif key == "training": + resolution = key + elif not key: + resolution = None + else: + raise ValueError( + "Cropping should be either set to `header` to get the values from the header or " + "`training` to take the same value as training." 
+ ) + return resolution + + def cropper(self, source, target, resolution): + """ + 2D source/target cropper + + Parameters: + ----------- + Source has shape (batch, height, width) + Target has shape (batch, height, width) + + """ + + if not resolution or all(_ == 0 for _ in resolution): + return source.unsqueeze(1), target.unsqueeze(1) # Added channel dimension. + + source_abs = T.center_crop(source, resolution).unsqueeze(1) # Added channel dimension. + target_abs = T.center_crop(target, resolution).unsqueeze(1) # Added channel dimension. + + return source_abs, target_abs + + def compute_model_per_coil(self, model_name, data): + """ + Computes model per coil. + """ + # data is of shape (batch, coil, complex=2, height, width) + output = [] + + for idx in range(data.size(self._coil_dim)): + subselected_data = data.select(self._coil_dim, idx) + output.append(self.models[model_name](subselected_data)) + output = torch.stack(output, dim=self._coil_dim) + + # output is of shape (batch, coil, complex=2, height, width) + return output diff --git a/direct/nn/kikinet/tests/__init__.py b/direct/nn/kikinet/tests/__init__.py new file mode 100644 index 00000000..941752c9 --- /dev/null +++ b/direct/nn/kikinet/tests/__init__.py @@ -0,0 +1,2 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors diff --git a/direct/nn/kikinet/tests/test_kikinet.py b/direct/nn/kikinet/tests/test_kikinet.py new file mode 100644 index 00000000..9a7e70e9 --- /dev/null +++ b/direct/nn/kikinet/tests/test_kikinet.py @@ -0,0 +1,56 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors + +import pytest +import torch + +from direct.data.transforms import fft2, ifft2 +from direct.nn.kikinet.kikinet import KIKINet + + +def create_input(shape): + + data = torch.rand(shape).float() + + return data + + +@pytest.mark.parametrize( + "shape", + [ + [3, 3, 32, 32], + ], +) +@pytest.mark.parametrize( + "num_iter", + [1, 3], +) +@pytest.mark.parametrize( + "image_model_architecture", + ["MWCNN", "UNET", "NORMUNET"], +) +@pytest.mark.parametrize( + "kspace_model_architecture", + ["CONV", "DIDN", "UNET", "NORMUNET"], +) +@pytest.mark.parametrize( + "normalize", + [True, False], +) +def test_kikinet(shape, num_iter, image_model_architecture, kspace_model_architecture, normalize): + model = KIKINet( + fft2, + ifft2, + num_iter=num_iter, + image_model_architecture=image_model_architecture, + kspace_model_architecture=kspace_model_architecture, + normalize=normalize, + ).cpu() + + kspace = create_input(shape + [2]).cpu() + mask = create_input([shape[0]] + [1] + shape[2:] + [1]).round().int().cpu() + sens = create_input(shape + [2]).cpu() + + out = model(kspace, mask, sens) + + assert list(out.shape) == [shape[0]] + shape[2:] + [2] diff --git a/direct/nn/lpd/config.py b/direct/nn/lpd/config.py new file mode 100644 index 00000000..aab98eb0 --- /dev/null +++ b/direct/nn/lpd/config.py @@ -0,0 +1,31 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors + +from dataclasses import dataclass + +from direct.config.defaults import ModelConfig + + +@dataclass +class LPDNetConfig(ModelConfig): + num_iter: int = 25 + num_primal: int = 5 + num_dual: int = 5 + primal_model_architecture: str = "MWCNN" + dual_model_architecture: str = "DIDN" + primal_mwcnn_hidden_channels: int = 16 + primal_mwcnn_num_scales: int = 4 + primal_mwcnn_bias: bool = True + primal_mwcnn_batchnorm: bool = False + primal_unet_num_filters: int = 8 + primal_unet_num_pool_layers: int = 4 + primal_unet_dropout_probability: float = 0.0 + dual_conv_hidden_channels: int = 16 + 
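These `*_mwcnn_*` / `*_conv_*` / `*_didn_*` / `*_unet_*` config fields (here and in `KIKINetConfig` above) are forwarded to the model as keyword arguments and read back via `kwargs.get(...)`, whose hard-coded fallbacks do not always match the dataclass defaults. A hedged illustration using KIKINet, passing the config values explicitly:

```python
from direct.data.transforms import fft2, ifft2
from direct.nn.kikinet.kikinet import KIKINet

# If a key were omitted, KIKINet would fall back to the defaults in its constructor
# (e.g. 32 hidden channels for MWCNN, 16 for DIDN), not to the KIKINetConfig values.
model = KIKINet(
    fft2,
    ifft2,
    num_iter=10,
    image_model_architecture="MWCNN",
    kspace_model_architecture="DIDN",
    image_mwcnn_hidden_channels=16,  # KIKINetConfig default
    image_mwcnn_bias=True,           # KIKINetConfig default; constructor fallback is False
    kspace_didn_hidden_channels=64,  # KIKINetConfig default; constructor fallback is 16
)
```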
dual_conv_n_convs: int = 4 + dual_conv_batchnorm: bool = False + dual_didn_hidden_channels: int = 64 + dual_didn_num_dubs: int = 6 + dual_didn_num_convs_recon: int = 9 + dual_unet_num_filters: int = 8 + dual_unet_num_pool_layers: int = 4 + dual_unet_dropout_probability: float = 0.0 diff --git a/direct/nn/lpd/lpd.py b/direct/nn/lpd/lpd.py index 941752c9..3826408e 100644 --- a/direct/nn/lpd/lpd.py +++ b/direct/nn/lpd/lpd.py @@ -1,2 +1,263 @@ # coding=utf-8 # Copyright (c) DIRECT Contributors + +from typing import Callable + +import direct.data.transforms as T +from direct.nn.conv.conv import Conv2d +from direct.nn.didn.didn import DIDN +from direct.nn.mwcnn.mwcnn import MWCNN +from direct.nn.unet.unet_2d import UnetModel2d, NormUnetModel2d + +import torch +import torch.nn as nn + + +class DualNet(nn.Module): + """ + Dual Network for Learned Primal Dual Network. + """ + + def __init__(self, num_dual, **kwargs): + super().__init__() + + if kwargs.get("dual_architectue") is None: + n_hidden = kwargs.get("n_hidden") + if n_hidden is None: + raise ValueError("Missing argument n_hidden.") + self.dual_block = nn.Sequential( + *[ + nn.Conv2d(2 * (num_dual + 2), n_hidden, kernel_size=3, padding=1), + nn.PReLU(), + nn.Conv2d(n_hidden, n_hidden, kernel_size=3, padding=1), + nn.PReLU(), + nn.Conv2d(n_hidden, 2 * num_dual, kernel_size=3, padding=1), + ] + ) + else: + self.dual_block = kwargs.get("dual_architectue") + + @staticmethod + def compute_model_per_coil(model, data): + """ + Computes model per coil. + """ + output = [] + for idx in range(data.size(1)): + subselected_data = data.select(1, idx) + output.append(model(subselected_data)) + output = torch.stack(output, dim=1) + return output + + def forward(self, h, forward_f, g): + inp = torch.cat([h, forward_f, g], dim=-1).permute(0, 1, 4, 2, 3) + return self.compute_model_per_coil(self.dual_block, inp).permute(0, 1, 3, 4, 2) + + +class PrimalNet(nn.Module): + """ + Primal Network for Learned Primal Dual Network. + """ + + def __init__(self, num_primal, **kwargs): + super().__init__() + + if kwargs.get("primal_architectue") is None: + n_hidden = kwargs.get("n_hidden") + if n_hidden is None: + raise ValueError("Missing argument n_hidden.") + self.primal_block = nn.Sequential( + *[ + nn.Conv2d(2 * (num_primal + 1), n_hidden, kernel_size=3, padding=1), + nn.PReLU(), + nn.Conv2d(n_hidden, n_hidden, kernel_size=3, padding=1), + nn.PReLU(), + nn.Conv2d(n_hidden, 2 * num_primal, kernel_size=3, padding=1), + ] + ) + else: + self.primal_block = kwargs.get("primal_architectue") + + def forward(self, f, backward_h): + inp = torch.cat([f, backward_h], dim=-1).permute(0, 3, 1, 2) + return self.primal_block(inp).permute(0, 2, 3, 1) + + +class LPDNet(nn.Module): + """ + Learned Primal Dual network implementation as in https://arxiv.org/abs/1707.06474. + """ + + def __init__( + self, + forward_operator: Callable, + backward_operator: Callable, + num_iter: int, + num_primal: int, + num_dual: int, + primal_model_architecture: str = "MWCNN", + dual_model_architecture: str = "DIDN", + **kwargs, + ): + """ + + Parameters + ---------- + forward_operator : Callable + Forward Operator. + backward_operator : Callable + Backward Operator. + num_iter : int + Number of unrolled iterations. + num_primal : int + Number of primal networks. + num_dual : int + Number of dual networks. + primal_model_architecture : str + Primal model architecture. Currently only implemented for MWCNN and (NORM)UNET. Default: 'MWCNN'. + dual_model_architecture : str + Dual model architecture. 
Currently only implemented for CONV, DIDN and (NORM)UNET. Default: 'DIDN'.
+        kwargs : dict
+            Keyword arguments for the model architectures.
+        """
+        super().__init__()
+
+        self.forward_operator = forward_operator
+        self.backward_operator = backward_operator
+        self.num_iter = num_iter
+        self.num_primal = num_primal
+        self.num_dual = num_dual
+
+        if primal_model_architecture == "MWCNN":
+            primal_model = nn.Sequential(
+                *[
+                    MWCNN(
+                        input_channels=2 * (num_primal + 1),
+                        first_conv_hidden_channels=kwargs.get("primal_mwcnn_hidden_channels", 32),
+                        num_scales=kwargs.get("primal_mwcnn_num_scales", 4),
+                        bias=kwargs.get("primal_mwcnn_bias", False),
+                        batchnorm=kwargs.get("primal_mwcnn_batchnorm", False),
+                    ),
+                    nn.Conv2d(2 * (num_primal + 1), 2 * num_primal, kernel_size=1),
+                ]
+            )
+        elif primal_model_architecture in ["UNET", "NORMUNET"]:
+            unet = UnetModel2d if primal_model_architecture == "UNET" else NormUnetModel2d
+            primal_model = unet(
+                in_channels=2 * (num_primal + 1),
+                out_channels=2 * num_primal,
+                num_filters=kwargs.get("primal_unet_num_filters", 8),
+                num_pool_layers=kwargs.get("primal_unet_num_pool_layers", 4),
+                dropout_probability=kwargs.get("primal_unet_dropout_probability", 0.0),
+            )
+        else:
+            raise NotImplementedError(
+                f"LPDNet is currently implemented only with primal_model_architecture == 'MWCNN', 'UNET' or 'NORMUNET'. "
+                f"Got {primal_model_architecture}."
+            )
+
+        if dual_model_architecture == "CONV":
+            dual_model = Conv2d(
+                in_channels=2 * (num_dual + 2),
+                out_channels=2 * num_dual,
+                hidden_channels=kwargs.get("dual_conv_hidden_channels", 16),
+                n_convs=kwargs.get("dual_conv_n_convs", 4),
+                batchnorm=kwargs.get("dual_conv_batchnorm", False),
+            )
+        elif dual_model_architecture == "DIDN":
+            dual_model = DIDN(
+                in_channels=2 * (num_dual + 2),
+                out_channels=2 * num_dual,
+                hidden_channels=kwargs.get("dual_didn_hidden_channels", 16),
+                num_dubs=kwargs.get("dual_didn_num_dubs", 6),
+                num_convs_recon=kwargs.get("dual_didn_num_convs_recon", 9),
+            )
+        elif dual_model_architecture in ["UNET", "NORMUNET"]:
+            unet = UnetModel2d if dual_model_architecture == "UNET" else NormUnetModel2d
+            dual_model = unet(
+                in_channels=2 * (num_dual + 2),
+                out_channels=2 * num_dual,
+                num_filters=kwargs.get("dual_unet_num_filters", 8),
+                num_pool_layers=kwargs.get("dual_unet_num_pool_layers", 4),
+                dropout_probability=kwargs.get("dual_unet_dropout_probability", 0.0),
+            )
+        else:
+            raise NotImplementedError(
+                f"LPDNet is currently implemented for dual_model_architecture == 'CONV', 'DIDN',"
+                f" 'UNET' or 'NORMUNET'. Got dual_model_architecture == {dual_model_architecture}."
+ ) + + self._coil_dim = 1 + self._complex_dim = -1 + self._spatial_dims = (2, 3) + + self.primal_net = nn.ModuleList( + [PrimalNet(num_primal, primal_architectue=primal_model) for _ in range(num_iter)] + ) + self.dual_net = nn.ModuleList([DualNet(num_dual, dual_architectue=dual_model) for _ in range(num_iter)]) + + def _forward_operator(self, image, sampling_mask, sensitivity_map): + forward = torch.where( + sampling_mask == 0, + torch.tensor([0.0], dtype=image.dtype).to(image.device), + self.forward_operator(T.expand_operator(image, sensitivity_map, self._coil_dim), dim=self._spatial_dims), + ) + return forward + + def _backward_operator(self, kspace, sampling_mask, sensitivity_map): + backward = T.reduce_operator( + self.backward_operator( + torch.where( + sampling_mask == 0, + torch.tensor([0.0], dtype=kspace.dtype).to(kspace.device), + kspace, + ), + self._spatial_dims, + ), + sensitivity_map, + self._coil_dim, + ) + return backward + + def forward( + self, + masked_kspace: torch.Tensor, + sensitivity_map: torch.Tensor, + sampling_mask: torch.Tensor, + ) -> torch.Tensor: + """ + + Parameters + ---------- + masked_kspace : torch.Tensor + Masked k-space of shape (N, coil, height, width, complex=2). + sensitivity_map : torch.Tensor + Sensitivity map of shape (N, coil, height, width, complex=2). + sampling_mask : torch.Tensor + Sampling mask of shape (N, 1, height, width, 1). + + Returns + ------- + output : torch.Tensor + Output image of shape (N, height, width, complex=2). + """ + input_image = self._backward_operator(masked_kspace, sampling_mask, sensitivity_map) + dual_buffer = torch.cat([masked_kspace] * self.num_dual, self._complex_dim).to(masked_kspace.device) + primal_buffer = torch.cat([input_image] * self.num_primal, self._complex_dim).to(masked_kspace.device) + + for iter in range(self.num_iter): + + # Dual + f_2 = primal_buffer[..., 2:4].clone() + dual_buffer = self.dual_net[iter]( + dual_buffer, self._forward_operator(f_2, sampling_mask, sensitivity_map), masked_kspace + ) + + # Primal + h_1 = dual_buffer[..., 0:2].clone() + primal_buffer = self.primal_net[iter]( + primal_buffer, self._backward_operator(h_1, sampling_mask, sensitivity_map) + ) + + output = primal_buffer[..., 0:2] + return output diff --git a/direct/nn/lpd/lpd_engine.py b/direct/nn/lpd/lpd_engine.py new file mode 100644 index 00000000..50c72874 --- /dev/null +++ b/direct/nn/lpd/lpd_engine.py @@ -0,0 +1,469 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors + +import time +from collections import defaultdict +from os import PathLike +from typing import Callable, DefaultDict, Dict, List, Optional + +import numpy as np +import torch +from torch import nn +from torch.cuda.amp import autocast +from torch.nn import functional as F +from torch.utils.data import DataLoader + +import direct.data.transforms as T +from direct.config import BaseConfig +from direct.engine import DoIterationOutput, Engine +from direct.functionals import SSIMLoss +from direct.utils import ( + communication, + detach_dict, + dict_to_device, + merge_list_of_dicts, + multiply_function, + reduce_list_of_dicts, +) +from direct.utils.communication import reduce_tensor_dict + + +class LPDNetEngine(Engine): + """ + LPDNet Engine. 
+ """ + + def __init__( + self, + cfg: BaseConfig, + model: nn.Module, + device: int, + forward_operator: Optional[Callable] = None, + backward_operator: Optional[Callable] = None, + mixed_precision: bool = False, + **models: nn.Module, + ): + super().__init__( + cfg, + model, + device, + forward_operator=forward_operator, + backward_operator=backward_operator, + mixed_precision=mixed_precision, + **models, + ) + + self._complex_dim = -1 + self._coil_dim = 1 + self._spatial_dims = (2, 3) + + def _do_iteration( + self, + data: Dict[str, torch.Tensor], + loss_fns: Optional[Dict[str, Callable]] = None, + regularizer_fns: Optional[Dict[str, Callable]] = None, + ) -> DoIterationOutput: + + # loss_fns can be done, e.g. during validation + if loss_fns is None: + loss_fns = {} + + if regularizer_fns is None: + regularizer_fns = {} + + loss_dicts = [] + regularizer_dicts = [] + + data = dict_to_device(data, self.device) + + # sensitivity_map of shape (batch, coil, height, width, complex=2) + sensitivity_map = data["sensitivity_map"] + + if "sensitivity_model" in self.models: + + # Move channels to first axis + sensitivity_map = data["sensitivity_map"].permute((0, 1, 4, 2, 3)) + + sensitivity_map = self.compute_model_per_coil("sensitivity_model", sensitivity_map).permute( + (0, 1, 3, 4, 2) + ) + + # The sensitivity map needs to be normalized such that + # So \sum_{i \in \text{coils}} S_i S_i^* = 1 + + sensitivity_map_norm = torch.sqrt( + ((sensitivity_map ** 2).sum(self._complex_dim)).sum(self._coil_dim) + ) # shape (batch, height, width) + sensitivity_map_norm = sensitivity_map_norm.unsqueeze(1).unsqueeze(-1) + data["sensitivity_map"] = T.safe_divide(sensitivity_map, sensitivity_map_norm) + + with autocast(enabled=self.mixed_precision): + + output_image = self.model( + masked_kspace=data["masked_kspace"], + sampling_mask=data["sampling_mask"], + sensitivity_map=data["sensitivity_map"], + ) # shape (batch, height, width, complex=2) + + output_image = T.modulus(output_image) # shape (batch, height, width) + + loss_dict = {k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in loss_fns.keys()} + regularizer_dict = { + k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in regularizer_fns.keys() + } + + for key, value in loss_dict.items(): + loss_dict[key] = value + loss_fns[key]( + output_image, + **data, + reduction="mean", + ) + + for key, value in regularizer_dict.items(): + regularizer_dict[key] = value + regularizer_fns[key]( + output_image, + **data, + ) + + loss = sum(loss_dict.values()) + sum(regularizer_dict.values()) + + if self.model.training: + self._scaler.scale(loss).backward() + + loss_dicts.append(detach_dict(loss_dict)) + regularizer_dicts.append( + detach_dict(regularizer_dict) + ) # Need to detach dict as this is only used for logging. + + # Add the loss dicts. + loss_dict = reduce_list_of_dicts(loss_dicts, mode="sum") + regularizer_dict = reduce_list_of_dicts(regularizer_dicts, mode="sum") + + return DoIterationOutput( + output_image=output_image, + sensitivity_map=data["sensitivity_map"], + data_dict={**loss_dict, **regularizer_dict}, + ) + + def build_loss(self, **kwargs) -> Dict: + # TODO: Cropper is a processing output tool. 
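Analogous to the KIKINet example earlier, a minimal smoke test for the LPDNet module added above, mirroring the shapes and architectures used in the new `test_lpd.py`:

```python
import torch

from direct.data.transforms import fft2, ifft2
from direct.nn.lpd.lpd import LPDNet

batch, coil, height, width = 3, 3, 32, 32

model = LPDNet(
    fft2,
    ifft2,
    num_iter=2,
    num_primal=3,
    num_dual=3,
    primal_model_architecture="MWCNN",
    dual_model_architecture="DIDN",
)

masked_kspace = torch.rand(batch, coil, height, width, 2)
sensitivity_map = torch.rand(batch, coil, height, width, 2)
sampling_mask = torch.rand(batch, 1, height, width, 1).round().int()

# Note the argument order: (masked_kspace, sensitivity_map, sampling_mask), unlike KIKINet.
out = model(masked_kspace, sensitivity_map, sampling_mask)
assert list(out.shape) == [batch, height, width, 2]
```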
+ def get_resolution(**data): + """Be careful that this will use the cropping size of the FIRST sample in the batch.""" + return self.compute_resolution(self.cfg.training.loss.crop, data.get("reconstruction_size", None)) + + def l1_loss(source, reduction="mean", **data): + """ + Calculate L1 loss given source and target. + + Parameters: + ----------- + Source: shape (batch, complex=2, height, width) + Data: Contains key "target" with value a tensor of shape (batch, height, width) + + """ + resolution = get_resolution(**data) + l1_loss = F.l1_loss(*self.cropper(source, data["target"], resolution), reduction=reduction) + + return l1_loss + + def l2_loss(source, reduction="mean", **data): + """ + Calculate L2 loss (MSE) given source and target. + + Parameters: + ----------- + Source: shape (batch, complex=2, height, width) + Data: Contains key "target" with value a tensor of shape (batch, height, width) + + """ + resolution = get_resolution(**data) + l2_loss = F.mse_loss(*self.cropper(source, data["target"], resolution), reduction=reduction) + + return l2_loss + + def ssim_loss(source, reduction="mean", **data): + """ + Calculate SSIM loss given source and target. + + Parameters: + ----------- + Source: shape (batch, complex=2, height, width) + Data: Contains key "target" with value a tensor of shape (batch, height, width) + + """ + resolution = get_resolution(**data) + if reduction != "mean": + raise AssertionError( + f"SSIM loss can only be computed with reduction == 'mean'." f" Got reduction == {reduction}." + ) + + source_abs, target_abs = self.cropper(source, data["target"], resolution) + data_range = torch.tensor([target_abs.max()], device=target_abs.device) + + ssim_loss = SSIMLoss().to(source_abs.device).forward(source_abs, target_abs, data_range=data_range) + + return ssim_loss + + # Build losses + loss_dict = {} + for curr_loss in self.cfg.training.loss.losses: # type: ignore + loss_fn = curr_loss.function + if loss_fn == "l1_loss": + loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, l1_loss) + elif loss_fn == "l2_loss": + loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, l2_loss) + elif loss_fn == "ssim_loss": + loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, ssim_loss) + else: + raise ValueError(f"{loss_fn} not permissible.") + + return loss_dict + + @torch.no_grad() + def evaluate( + self, + data_loader: DataLoader, + loss_fns: Optional[Dict[str, Callable]], + regularizer_fns: Optional[Dict[str, Callable]] = None, + crop: Optional[str] = None, + is_validation_process: bool = True, + ): + """ + Validation process. Assumes that each batch only contains slices of the same volume *AND* that these + are sequentially ordered. + + Parameters + ---------- + data_loader : DataLoader + loss_fns : Dict[str, Callable], optional + regularizer_fns : Dict[str, Callable], optional + crop : str, optional + is_validation_process : bool + + Returns + ------- + loss_dict, all_gathered_metrics, visualize_slices, visualize_target + + """ + self.models_to_device() + self.models_validation_mode() + torch.cuda.empty_cache() + + # Variables required for evaluation. 
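The sensitivity-map normalization applied in `_do_iteration` above (shared by all of the new engines) reduces to the following sketch; the shapes are made up:

```python
import torch

import direct.data.transforms as T

# Sensitivity maps of shape (batch, coil, height, width, complex=2).
sensitivity_map = torch.rand(2, 8, 64, 64, 2)

# Normalize so that, per pixel, sum_i S_i * conj(S_i) is (approximately) one.
norm = torch.sqrt(((sensitivity_map ** 2).sum(-1)).sum(1))  # (batch, height, width)
norm = norm.unsqueeze(1).unsqueeze(-1)                      # broadcast over coil and complex dims
normalized_map = T.safe_divide(sensitivity_map, norm)

# Sanity check: the per-pixel coil energy is now ~1 everywhere.
assert torch.allclose((normalized_map ** 2).sum(-1).sum(1), torch.ones(2, 64, 64), atol=1e-4)
```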
+ volume_metrics = self.build_metrics(self.cfg.validation.metrics) # type: ignore + + # filenames can be in the volume_indices attribute of the dataset + num_for_this_process = None + all_filenames = None + if hasattr(data_loader.dataset, "volume_indices"): + all_filenames = list(data_loader.dataset.volume_indices.keys()) + num_for_this_process = len(list(data_loader.batch_sampler.sampler.volume_indices.keys())) + self.logger.info( + f"Reconstructing a total of {len(all_filenames)} volumes. " + f"This process has {num_for_this_process} volumes (world size: {communication.get_world_size()})." + ) + + filenames_seen = 0 + reconstruction_output: DefaultDict = defaultdict(list) + if is_validation_process: + targets_output: DefaultDict = defaultdict(list) + val_losses = [] + val_volume_metrics: Dict[PathLike, Dict] = defaultdict(dict) + last_filename = None + + # Container to for the slices which can be visualized in TensorBoard. + visualize_slices: List[np.ndarray] = [] + visualize_target: List[np.ndarray] = [] + # visualizations = {} + + extra_visualization_keys = ( + self.cfg.logging.log_as_image if self.cfg.logging.log_as_image else [] # type: ignore + ) + + # Loop over dataset. This requires the use of direct.data.sampler.DistributedSequentialSampler as this sampler + # splits the data over the different processes, and outputs the slices linearly. The implicit assumption here is + # that the slices are outputted from the Dataset *sequentially* for each volume one by one, and each batch only + # contains data from one volume. + time_start = time.time() + + for iter_idx, data in enumerate(data_loader): + filenames = data.pop("filename") + if len(set(filenames)) != 1: + raise ValueError( + f"Expected a batch during validation to only contain filenames of one case. " + f"Got {set(filenames)}." + ) + + slice_nos = data.pop("slice_no") + scaling_factors = data["scaling_factor"] + + resolution = self.compute_resolution( + key=self.cfg.validation.crop, # type: ignore + reconstruction_size=data.get("reconstruction_size", None), + ) + + # Compute output and loss. + iteration_output = self._do_iteration(data, loss_fns, regularizer_fns=regularizer_fns) + output = iteration_output.output_image + loss_dict = iteration_output.data_dict + + loss_dict = detach_dict(loss_dict) + output = output.detach() + val_losses.append(loss_dict) + + # Output is complex-valued, and has to be cropped. This holds for both output and target. + # Output has shape (batch, complex, height, width) + output_abs = self.process_output( + output, + scaling_factors, + resolution=resolution, + ) + + if is_validation_process: + # Target has shape (batch, height, width) + target_abs = self.process_output( + data["target"].detach(), + scaling_factors, + resolution=resolution, + ) + for key in extra_visualization_keys: + curr_data = data[key].detach() + # Here we need to discover which keys are actually normalized or not + # this requires a solution to issue #23: https://github.com/directgroup/direct/issues/23 + + del output # Explicitly call delete to clear memory. + + # Aggregate volumes to be able to compute the metrics on complete volumes. + for idx, filename in enumerate(filenames): + if last_filename is None: + last_filename = filename # First iteration last_filename is not set. 
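The `compute_resolution` call above resolves the validation crop size; for `key == "header"` it reduces to the following (the `reconstruction_size` values are illustrative):

```python
import torch

# One tensor per axis, each holding the value for every sample in the batch. Only the first
# sample is used, and the trailing (z) entry is dropped since volumes are handled as z-stacks
# of (x, y) slices.
reconstruction_size = [torch.tensor([320, 320]), torch.tensor([320, 320]), torch.tensor([1, 1])]

resolution = [axis.detach().cpu().numpy().tolist() for axis in reconstruction_size]
resolution = [axis[0] for axis in resolution][:-1]
assert resolution == [320, 320]
```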
+ + curr_slice = output_abs[idx].detach() + slice_no = int(slice_nos[idx].numpy()) + + reconstruction_output[filename].append((slice_no, curr_slice.cpu())) + + if is_validation_process: + targets_output[filename].append((slice_no, target_abs[idx].cpu())) + + is_last_element_of_last_batch = iter_idx + 1 == len(data_loader) and idx + 1 == len(data["target"]) + reconstruction_conditions = [filename != last_filename, is_last_element_of_last_batch] + for condition in reconstruction_conditions: + if condition: + filenames_seen += 1 + + # Now we can ditch the reconstruction dict by reconstructing the volume, + # will take too much memory otherwise. + volume = torch.stack([_[1] for _ in reconstruction_output[last_filename]]) + if is_validation_process: + target = torch.stack([_[1] for _ in targets_output[last_filename]]) + curr_metrics = { + metric_name: metric_fn(target, volume) + for metric_name, metric_fn in volume_metrics.items() + } + val_volume_metrics[last_filename] = curr_metrics + # Log the center slice of the volume + if len(visualize_slices) < self.cfg.logging.tensorboard.num_images: # type: ignore + visualize_slices.append(volume[volume.shape[0] // 2]) + visualize_target.append(target[target.shape[0] // 2]) + + # Delete outputs from memory, and recreate dictionary. + # This is not needed when not in validation as we are actually interested + # in the iteration output. + del targets_output[last_filename] + del reconstruction_output[last_filename] + + if all_filenames: + log_prefix = f"{filenames_seen} of {num_for_this_process} volumes reconstructed:" + else: + log_prefix = f"{iter_idx + 1} of {len(data_loader)} slices reconstructed:" + + self.logger.info( + f"{log_prefix} {last_filename}" + f" (shape = {list(volume.shape)}) in {time.time() - time_start:.3f}s." + ) + # restart timer + time_start = time.time() + last_filename = filename + + # Average loss dict + loss_dict = reduce_list_of_dicts(val_losses) + reduce_tensor_dict(loss_dict) + + communication.synchronize() + torch.cuda.empty_cache() + + all_gathered_metrics = merge_list_of_dicts(communication.all_gather(val_volume_metrics)) + if not is_validation_process: + return loss_dict, reconstruction_output + + return loss_dict, all_gathered_metrics, visualize_slices, visualize_target + + def process_output(self, data, scaling_factors=None, resolution=None): + # data is of shape (batch, complex=2, height, width) + if scaling_factors is not None: + data = data * scaling_factors.view(-1, *((1,) * (len(data.shape) - 1))).to(data.device) + + data = T.modulus_if_complex(data) + + if len(data.shape) == 3: # (batch, height, width) + data = data.unsqueeze(1) # Added channel dimension. + + if resolution is not None: + data = T.center_crop(data, resolution).contiguous() + + return data + + @staticmethod + def compute_resolution(key, reconstruction_size): + if key == "header": + # This will be of the form [tensor(x_0, x_1, ...), tensor(y_0, y_1,...), tensor(z_0, z_1, ...)] over + # batches. + resolution = [_.detach().cpu().numpy().tolist() for _ in reconstruction_size] + # The volume sampler should give validation indices belonging to the *same* volume, so it should be + # safe taking the first element, the matrix size are in x,y,z (we work in z,x,y). + resolution = [_[0] for _ in resolution][:-1] + elif key == "training": + resolution = key + elif not key: + resolution = None + else: + raise ValueError( + "Cropping should be either set to `header` to get the values from the header or " + "`training` to take the same value as training." 
+ ) + return resolution + + def cropper(self, source, target, resolution): + """ + 2D source/target cropper + + Parameters: + ----------- + Source has shape (batch, height, width) + Target has shape (batch, height, width) + + """ + + if not resolution or all(_ == 0 for _ in resolution): + return source.unsqueeze(1), target.unsqueeze(1) # Added channel dimension. + + source_abs = T.center_crop(source, resolution).unsqueeze(1) # Added channel dimension. + target_abs = T.center_crop(target, resolution).unsqueeze(1) # Added channel dimension. + + return source_abs, target_abs + + def compute_model_per_coil(self, model_name, data): + """ + Computes model per coil. + """ + # data is of shape (batch, coil, complex=2, height, width) + output = [] + + for idx in range(data.size(self._coil_dim)): + subselected_data = data.select(self._coil_dim, idx) + output.append(self.models[model_name](subselected_data)) + output = torch.stack(output, dim=self._coil_dim) + + # output is of shape (batch, coil, complex=2, height, width) + return output diff --git a/direct/nn/lpd/tests/__init__.py b/direct/nn/lpd/tests/__init__.py new file mode 100644 index 00000000..941752c9 --- /dev/null +++ b/direct/nn/lpd/tests/__init__.py @@ -0,0 +1,2 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors diff --git a/direct/nn/lpd/tests/test_lpd.py b/direct/nn/lpd/tests/test_lpd.py new file mode 100644 index 00000000..b86849f2 --- /dev/null +++ b/direct/nn/lpd/tests/test_lpd.py @@ -0,0 +1,68 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors + +import pytest +import torch + +from direct.data.transforms import fft2, ifft2 +from direct.nn.lpd.lpd import LPDNet + + +def create_input(shape): + + data = torch.rand(shape).float() + + return data + + +@pytest.mark.parametrize( + "shape", + [ + [3, 3, 32, 32], + ], +) +@pytest.mark.parametrize( + "num_iter", + [2, 3], +) +@pytest.mark.parametrize( + "num_primal", + [2, 3], +) +@pytest.mark.parametrize( + "num_dual", + [3], +) +@pytest.mark.parametrize( + "primal_model_architecture", + ["MWCNN", "UNET", "NORMUNET"], +) +@pytest.mark.parametrize( + "dual_model_architecture", + ["CONV", "DIDN", "UNET", "NORMUNET"], +) +def test_lpd( + shape, + num_iter, + num_primal, + num_dual, + primal_model_architecture, + dual_model_architecture, +): + model = LPDNet( + fft2, + ifft2, + num_iter=num_iter, + num_primal=num_primal, + num_dual=num_dual, + primal_model_architecture=primal_model_architecture, + dual_model_architecture=dual_model_architecture, + ).cpu() + + kspace = create_input(shape + [2]).cpu() + sens = create_input(shape + [2]).cpu() + mask = create_input([shape[0]] + [1] + shape[2:] + [1]).round().int().cpu() + + out = model(kspace, sens, mask) + + assert list(out.shape) == [shape[0]] + shape[2:] + [2] diff --git a/direct/nn/mobilenet/mobilenet.py b/direct/nn/mobilenet/mobilenet.py index f7cb5d90..44930e31 100644 --- a/direct/nn/mobilenet/mobilenet.py +++ b/direct/nn/mobilenet/mobilenet.py @@ -18,10 +18,6 @@ def _make_divisible(v, divisor, min_value=None): It ensures that all layers have a channel number that is divisible by 8 It can be seen here: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py - :param v: - :param divisor: - :param min_value: - :return: """ if min_value is None: min_value = divisor @@ -37,7 +33,7 @@ def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, nor padding = (kernel_size - 1) // 2 if norm_layer is None: norm_layer = nn.BatchNorm2d - super(ConvBNReLU, self).__init__( + 
super().__init__( nn.Conv2d( in_planes, out_planes, @@ -54,7 +50,8 @@ def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, nor class InvertedResidual(nn.Module): def __init__(self, inp, oup, stride, expand_ratio, norm_layer=None): - super(InvertedResidual, self).__init__() + super().__init__() + self.stride = stride if stride not in [1, 2]: raise AssertionError @@ -123,7 +120,8 @@ def __init__( norm_layer : str Module specifying the normalization layer to use. """ - super(MobileNetV2, self).__init__() + + super().__init__() if block is None: block = InvertedResidual @@ -163,8 +161,8 @@ def __init__( # building inverted residual blocks for t, c, n, s in inverted_residual_setting: output_channel = _make_divisible(c * width_mult, round_nearest) - for i in range(n): - stride = s if i == 0 else 1 + for idx in range(n): + stride = s if idx == 0 else 1 features.append( block( input_channel, diff --git a/direct/nn/multidomainnet/__init__.py b/direct/nn/multidomainnet/__init__.py new file mode 100644 index 00000000..941752c9 --- /dev/null +++ b/direct/nn/multidomainnet/__init__.py @@ -0,0 +1,2 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors diff --git a/direct/nn/multidomainnet/config.py b/direct/nn/multidomainnet/config.py new file mode 100644 index 00000000..c46ab526 --- /dev/null +++ b/direct/nn/multidomainnet/config.py @@ -0,0 +1,14 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors +from dataclasses import dataclass +from typing import Optional, Tuple + +from direct.config.defaults import ModelConfig + + +@dataclass +class MultiDomainNetConfig(ModelConfig): + standardization: bool = True + num_filters: int = 16 + num_pool_layers: int = 4 + dropout_probability: float = 0.0 diff --git a/direct/nn/multidomainnet/multidomain.py b/direct/nn/multidomainnet/multidomain.py new file mode 100644 index 00000000..51505285 --- /dev/null +++ b/direct/nn/multidomainnet/multidomain.py @@ -0,0 +1,309 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors + +from typing import Callable + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class MultiDomainConv2d(nn.Module): + def __init__( + self, + forward_operator, + backward_operator, + in_channels, + out_channels, + **kwargs, + ): + super().__init__() + + self.image_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels // 2, **kwargs) + self.kspace_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels // 2, **kwargs) + self.forward_operator = forward_operator + self.backward_operator = backward_operator + self._channels_dim = 1 + self._spatial_dims = (1, 2) + + def forward(self, image): + kspace = [ + self.forward_operator( + im, + dim=self._spatial_dims, + ) + for im in torch.split(image.permute(0, 2, 3, 1).contiguous(), 2, -1) + ] + kspace = torch.cat(kspace, -1).permute(0, 3, 1, 2) + kspace = self.kspace_conv(kspace) + + backward = [ + self.backward_operator( + ks, + dim=self._spatial_dims, + ) + for ks in torch.split(kspace.permute(0, 2, 3, 1).contiguous(), 2, -1) + ] + backward = torch.cat(backward, -1).permute(0, 3, 1, 2) + + image = self.image_conv(image) + image = torch.cat([image, backward], dim=self._channels_dim) + return image + + +class MultiDomainConvTranspose2d(nn.Module): + def __init__( + self, + forward_operator, + backward_operator, + in_channels, + out_channels, + **kwargs, + ): + super().__init__() + + self.image_conv = nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels // 2, **kwargs) + self.kspace_conv = 
nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels // 2, **kwargs) + self.forward_operator = forward_operator + self.backward_operator = backward_operator + self._channels_dim = 1 + self._spatial_dims = (1, 2) + + def forward(self, image): + kspace = [ + self.forward_operator( + im, + dim=self._spatial_dims, + ) + for im in torch.split(image.permute(0, 2, 3, 1).contiguous(), 2, -1) + ] + kspace = torch.cat(kspace, -1).permute(0, 3, 1, 2) + kspace = self.kspace_conv(kspace) + + backward = [ + self.backward_operator( + ks, + dim=self._spatial_dims, + ) + for ks in torch.split(kspace.permute(0, 2, 3, 1).contiguous(), 2, -1) + ] + backward = torch.cat(backward, -1).permute(0, 3, 1, 2) + + image = self.image_conv(image) + return torch.cat([image, backward], dim=self._channels_dim) + + +class MultiDomainConvBlock(nn.Module): + """ + A multi-domain convolutional block that consists of two multi-domain convolution layers each followed by + instance normalization, LeakyReLU activation and dropout. + """ + + def __init__( + self, forward_operator, backward_operator, in_channels: int, out_channels: int, dropout_probability: float + ): + """ + + Parameters + ---------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels. + dropout_probability : float + Dropout probability. + """ + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.dropout_probability = dropout_probability + + self.layers = nn.Sequential( + MultiDomainConv2d( + forward_operator, backward_operator, in_channels, out_channels, kernel_size=3, padding=1, bias=False + ), + nn.InstanceNorm2d(out_channels), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(dropout_probability), + MultiDomainConv2d( + forward_operator, backward_operator, out_channels, out_channels, kernel_size=3, padding=1, bias=False + ), + nn.InstanceNorm2d(out_channels), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(dropout_probability), + ) + + def forward(self, input: torch.Tensor): + """ + + Parameters + ---------- + input : torch.Tensor + + Returns + ------- + torch.Tensor + """ + return self.layers(input) + + def __repr__(self): + return ( + f"MultiDomainConvBlock(in_channels={self.in_channels}, out_channels={self.out_channels}, " + f"dropout_probability={self.dropout_probability})" + ) + + +class TransposeMultiDomainConvBlock(nn.Module): + """ + A Transpose Convolutional Block that consists of one convolution transpose layers followed by + instance normalization and LeakyReLU activation. + """ + + def __init__(self, forward_operator, backward_operator, in_channels: int, out_channels: int): + """ + Parameters + ---------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels. 
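The multi-domain convolutions above split their output channels between an image-space branch and a k-space branch (forward FFT, convolution, inverse FFT) and concatenate the two halves. A small shape check, assuming 64x64 inputs:

```python
import torch

from direct.data.transforms import fft2, ifft2
from direct.nn.multidomainnet.multidomain import MultiDomainConv2d

# in_channels must be even: the input is interpreted as stacked (real, imaginary) pairs so the
# k-space branch can Fourier-transform each 2-channel pair before convolving it.
conv = MultiDomainConv2d(fft2, ifft2, in_channels=4, out_channels=8, kernel_size=3, padding=1)

features = torch.rand(2, 4, 64, 64)  # (batch, channels, height, width)
out = conv(features)

# Half of the 8 output channels come from the image-space conv, half from the k-space conv.
assert out.shape == (2, 8, 64, 64)
```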
+ """ + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.layers = nn.Sequential( + MultiDomainConvTranspose2d( + forward_operator, backward_operator, in_channels, out_channels, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_channels), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, input: torch.Tensor): + """ + + Parameters + ---------- + input : torch.Tensor + + Returns + ------- + torch.Tensor + """ + return self.layers(input) + + def __repr__(self): + return f"MultiDomainConvBlock(in_channels={self.in_channels}, out_channels={self.out_channels})" + + +class MultiDomainUnet2d(nn.Module): + """ + Unet modification to be used with Multi-domain network as in AIRS Medical submission to the Fast MRI 2020 challenge. + """ + + def __init__( + self, + forward_operator: Callable, + backward_operator: Callable, + in_channels: int, + out_channels: int, + num_filters: int, + num_pool_layers: int, + dropout_probability: float, + ): + """ + + Parameters + ---------- + forward_operator : Callable + Forward Operator. + backward_operator : Callable + Backward Operator. + in_channels : int + Number of input channels to the u-net. + out_channels : int + Number of output channels to the u-net. + num_filters : int + Number of output channels of the first convolutional layer. + num_pool_layers : int + Number of down-sampling and up-sampling layers (depth). + dropout_probability : float + Dropout probability. + """ + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.num_filters = num_filters + self.num_pool_layers = num_pool_layers + self.dropout_probability = dropout_probability + + self.down_sample_layers = nn.ModuleList( + [MultiDomainConvBlock(forward_operator, backward_operator, in_channels, num_filters, dropout_probability)] + ) + ch = num_filters + for _ in range(num_pool_layers - 1): + self.down_sample_layers += [ + MultiDomainConvBlock(forward_operator, backward_operator, ch, ch * 2, dropout_probability) + ] + ch *= 2 + self.conv = MultiDomainConvBlock(forward_operator, backward_operator, ch, ch * 2, dropout_probability) + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv += [TransposeMultiDomainConvBlock(forward_operator, backward_operator, ch * 2, ch)] + self.up_conv += [ + MultiDomainConvBlock(forward_operator, backward_operator, ch * 2, ch, dropout_probability) + ] + ch //= 2 + + self.up_transpose_conv += [TransposeMultiDomainConvBlock(forward_operator, backward_operator, ch * 2, ch)] + self.up_conv += [ + nn.Sequential( + MultiDomainConvBlock(forward_operator, backward_operator, ch * 2, ch, dropout_probability), + nn.Conv2d(ch, self.out_channels, kernel_size=1, stride=1), + ) + ] + + def forward(self, input: torch.Tensor): + """ + + Parameters + ---------- + input : torch.Tensor + + Returns + ------- + torch.Tensor + """ + stack = [] + output = input + + # Apply down-sampling layers + for _, layer in enumerate(self.down_sample_layers): + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + output = self.conv(output) + + # Apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # Reflect pad on the right/bottom if needed to handle odd input dimensions. 
+ padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # Padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # Padding bottom + if sum(padding) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output diff --git a/direct/nn/multidomainnet/multidomainnet.py b/direct/nn/multidomainnet/multidomainnet.py new file mode 100644 index 00000000..38969c22 --- /dev/null +++ b/direct/nn/multidomainnet/multidomainnet.py @@ -0,0 +1,131 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors + +from typing import Callable + +import direct.data.transforms as T +from direct.nn.multidomainnet.multidomain import MultiDomainUnet2d + +import torch +import torch.nn as nn + + +class StandardizationLayer(nn.Module): + """ + Multi-channel data standardization method. Inspired by AIRS model submission to the Fast MRI 2020 challenge. + Given individual coil images :math: {x_i}_{i=1}^{N_c} and sensitivity coil maps :math: {S_i}_{i=1}^{N_c} + it returns + .. math:: + {xres_i}_{i=1}^{N_c}, + where :math: xres_i = [x_{sense}, xi - S_i \times x_{sense}] + and :math: x_{sense} = \sum_{i=1}^{N_c} {S_i}^{*} \times x_i. + + """ + + def __init__(self, coil_dim=1, channel_dim=-1): + super().__init__() + self.coil_dim = coil_dim + self.channel_dim = channel_dim + + def forward(self, coil_images: torch.Tensor, sensitivity_map: torch.Tensor) -> torch.Tensor: + combined_image = T.reduce_operator(coil_images, sensitivity_map, self.coil_dim) + residual_image = combined_image.unsqueeze(self.coil_dim) - T.complex_multiplication( + sensitivity_map, combined_image.unsqueeze(self.coil_dim) + ) + concat = torch.cat( + [ + torch.cat([combined_image, residual_image.select(self.coil_dim, idx)], self.channel_dim).unsqueeze( + self.coil_dim + ) + for idx in range(coil_images.size(self.coil_dim)) + ], + self.coil_dim, + ) + return concat + + +class MultiDomainNet(nn.Module): + """ + Feature-level multi-domain module. Inspired by AIRS Medical submission to the Fast MRI 2020 challenge. + """ + + def __init__( + self, + forward_operator: Callable, + backward_operator: Callable, + standardization: bool = True, + num_filters: int = 16, + num_pool_layers: int = 4, + dropout_probability: float = 0.0, + **kwargs, + ): + """ + + Parameters + ---------- + forward_operator : Callable + Forward Operator. + backward_operator : Callable + Backward Operator. + standardization : bool + If True standardization is used. Default: True. + num_filters : int + Number of filters for the MultiDomainUnet module. Default: 16. + num_pool_layers : int + Number of pooling layers for the MultiDomainUnet module. Default: 4. + dropout_probability : float + Dropout probability for the MultiDomainUnet module. Default: 0.0. + """ + super().__init__() + self.forward_operator = forward_operator + self.backward_operator = backward_operator + self._coil_dim = 1 + self._complex_dim = -1 + self._spatial_dims = (2, 3) + + if standardization: + self.standardization = StandardizationLayer(self._coil_dim, self._complex_dim) + + self.unet = MultiDomainUnet2d( + forward_operator, + backward_operator, + in_channels=4 if standardization else 2, # if standardization, in_channels is 4 due to standardized input + out_channels=2, + num_filters=num_filters, + num_pool_layers=num_pool_layers, + dropout_probability=dropout_probability, + ) + + def _compute_model_per_coil(self, model, data): + """ + Computes model per coil. 
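+
+        Parameters
+        ----------
+        model : nn.Module
+            Model to apply coil-wise.
+        data : torch.Tensor
+            Multi-coil input of shape (N, coil, channels, height, width).
+
+        Returns
+        -------
+        torch.Tensor
+            Output with the per-coil results stacked along the coil dimension.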
+ """ + output = [] + for idx in range(data.size(self._coil_dim)): + subselected_data = data.select(self._coil_dim, idx) + output.append(model(subselected_data)) + output = torch.stack(output, dim=self._coil_dim) + return output + + def forward(self, masked_kspace: torch.Tensor, sensitivity_map: torch.Tensor) -> torch.Tensor: + """ + + Parameters + ---------- + masked_kspace : torch.Tensor + Masked k-space of shape (N, coil, height, width, complex=2). + sensitivity_map : torch.Tensor + Sensitivity map of shape (N, coil, height, width, complex=2). + + Returns + ------- + output_image : torch.Tensor + Multi-coil output image of shape (N, coil, height, width, complex=2). + """ + input_image = self.backward_operator(masked_kspace, dim=self._spatial_dims) + if hasattr(self, "standardization"): + input_image = self.standardization(input_image, sensitivity_map) + output_image = self._compute_model_per_coil(self.unet, input_image.permute(0, 1, 4, 2, 3)).permute( + 0, 1, 3, 4, 2 + ) + return output_image diff --git a/direct/nn/multidomainnet/multidomainnet_engine.py b/direct/nn/multidomainnet/multidomainnet_engine.py new file mode 100644 index 00000000..0f4c8b41 --- /dev/null +++ b/direct/nn/multidomainnet/multidomainnet_engine.py @@ -0,0 +1,470 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors + +import time +from collections import defaultdict +from os import PathLike +from typing import Callable, DefaultDict, Dict, List, Optional + +import numpy as np +import torch +from torch import nn +from torch.cuda.amp import autocast +from torch.nn import functional as F +from torch.utils.data import DataLoader + +import direct.data.transforms as T +from direct.config import BaseConfig +from direct.engine import DoIterationOutput, Engine +from direct.functionals import SSIMLoss +from direct.utils import ( + communication, + detach_dict, + dict_to_device, + merge_list_of_dicts, + multiply_function, + reduce_list_of_dicts, +) +from direct.utils.communication import reduce_tensor_dict + + +class MultiDomainNetEngine(Engine): + """ + Multi Domain Network Engine. + """ + + def __init__( + self, + cfg: BaseConfig, + model: nn.Module, + device: int, + forward_operator: Optional[Callable] = None, + backward_operator: Optional[Callable] = None, + mixed_precision: bool = False, + **models: nn.Module, + ): + super().__init__( + cfg, + model, + device, + forward_operator=forward_operator, + backward_operator=backward_operator, + mixed_precision=mixed_precision, + **models, + ) + + self._complex_dim = -1 + self._coil_dim = 1 + self._spatial_dims = (2, 3) + + def _do_iteration( + self, + data: Dict[str, torch.Tensor], + loss_fns: Optional[Dict[str, Callable]] = None, + regularizer_fns: Optional[Dict[str, Callable]] = None, + ) -> DoIterationOutput: + + # loss_fns can be done, e.g. 
during validation + if loss_fns is None: + loss_fns = {} + + if regularizer_fns is None: + regularizer_fns = {} + + loss_dicts = [] + regularizer_dicts = [] + + data = dict_to_device(data, self.device) + + # sensitivity_map of shape (batch, coil, height, width, complex=2) + sensitivity_map = data["sensitivity_map"] + + if "sensitivity_model" in self.models: + + # Move channels to first axis + sensitivity_map = data["sensitivity_map"].permute( + (0, 1, 4, 2, 3) + ) # shape (batch, coil, complex=2, height, width) + + sensitivity_map = self.compute_model_per_coil("sensitivity_model", sensitivity_map).permute( + (0, 1, 3, 4, 2) + ) # has channel last: shape (batch, coil, height, width, complex=2) + + # The sensitivity map needs to be normalized such that + # So \sum_{i \in \text{coils}} S_i S_i^* = 1 + + sensitivity_map_norm = torch.sqrt( + ((sensitivity_map ** 2).sum(self._complex_dim)).sum(self._coil_dim) + ) # shape (batch, height, width) + sensitivity_map_norm = sensitivity_map_norm.unsqueeze(1).unsqueeze(-1) + data["sensitivity_map"] = T.safe_divide(sensitivity_map, sensitivity_map_norm) + + with autocast(enabled=self.mixed_precision): + + output_multicoil_image = self.model( + masked_kspace=data["masked_kspace"], + sensitivity_map=data["sensitivity_map"], + ) + + output_image = T.root_sum_of_squares( + output_multicoil_image, self._coil_dim, self._complex_dim + ) # shape (batch, height, width) + + loss_dict = {k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in loss_fns.keys()} + regularizer_dict = { + k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in regularizer_fns.keys() + } + + for key, value in loss_dict.items(): + loss_dict[key] = value + loss_fns[key]( + output_image, + **data, + reduction="mean", + ) + + for key, value in regularizer_dict.items(): + regularizer_dict[key] = value + regularizer_fns[key]( + output_image, + **data, + ) + + loss = sum(loss_dict.values()) + sum(regularizer_dict.values()) + + if self.model.training: + self._scaler.scale(loss).backward() + + loss_dicts.append(detach_dict(loss_dict)) + regularizer_dicts.append( + detach_dict(regularizer_dict) + ) # Need to detach dict as this is only used for logging. + + # Add the loss dicts. + loss_dict = reduce_list_of_dicts(loss_dicts, mode="sum") + regularizer_dict = reduce_list_of_dicts(regularizer_dicts, mode="sum") + + return DoIterationOutput( + output_image=output_image, + sensitivity_map=data["sensitivity_map"], + data_dict={**loss_dict, **regularizer_dict}, + ) + + def build_loss(self, **kwargs) -> Dict: + # TODO: Cropper is a processing output tool. + def get_resolution(**data): + """Be careful that this will use the cropping size of the FIRST sample in the batch.""" + return self.compute_resolution(self.cfg.training.loss.crop, data.get("reconstruction_size", None)) + + def l1_loss(source, reduction="mean", **data): + """ + Calculate L1 loss given source and target. + + Parameters: + ----------- + Source: shape (batch, complex=2, height, width) + Data: Contains key "target" with value a tensor of shape (batch, height, width) + + """ + resolution = get_resolution(**data) + l1_loss = F.l1_loss(*self.cropper(source, data["target"], resolution), reduction=reduction) + + return l1_loss + + def l2_loss(source, reduction="mean", **data): + """ + Calculate L2 loss (MSE) given source and target. 
+ + Parameters: + ----------- + Source: shape (batch, complex=2, height, width) + Data: Contains key "target" with value a tensor of shape (batch, height, width) + + """ + resolution = get_resolution(**data) + l2_loss = F.mse_loss(*self.cropper(source, data["target"], resolution), reduction=reduction) + + return l2_loss + + def ssim_loss(source, reduction="mean", **data): + """ + Calculate SSIM loss given source and target. + + Parameters: + ----------- + Source: shape (batch, complex=2, height, width) + Data: Contains key "target" with value a tensor of shape (batch, height, width) + + """ + resolution = get_resolution(**data) + if reduction != "mean": + raise AssertionError( + f"SSIM loss can only be computed with reduction == 'mean'." f" Got reduction == {reduction}." + ) + + source_abs, target_abs = self.cropper(source, data["target"], resolution) + data_range = torch.tensor([target_abs.max()], device=target_abs.device) + + ssim_loss = SSIMLoss().to(source_abs.device).forward(source_abs, target_abs, data_range=data_range) + + return ssim_loss + + # Build losses + loss_dict = {} + for curr_loss in self.cfg.training.loss.losses: # type: ignore + loss_fn = curr_loss.function + if loss_fn == "l1_loss": + loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, l1_loss) + elif loss_fn == "l2_loss": + loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, l2_loss) + elif loss_fn == "ssim_loss": + loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, ssim_loss) + else: + raise ValueError(f"{loss_fn} not permissible.") + + return loss_dict + + @torch.no_grad() + def evaluate( + self, + data_loader: DataLoader, + loss_fns: Optional[Dict[str, Callable]], + regularizer_fns: Optional[Dict[str, Callable]] = None, + crop: Optional[str] = None, + is_validation_process: bool = True, + ): + """ + Validation process. Assumes that each batch only contains slices of the same volume *AND* that these + are sequentially ordered. + + Parameters + ---------- + data_loader : DataLoader + loss_fns : Dict[str, Callable], optional + regularizer_fns : Dict[str, Callable], optional + crop : str, optional + is_validation_process : bool + + Returns + ------- + loss_dict, all_gathered_metrics, visualize_slices, visualize_target + + """ + self.models_to_device() + self.models_validation_mode() + torch.cuda.empty_cache() + + # Variables required for evaluation. + volume_metrics = self.build_metrics(self.cfg.validation.metrics) # type: ignore + + # filenames can be in the volume_indices attribute of the dataset + num_for_this_process = None + all_filenames = None + if hasattr(data_loader.dataset, "volume_indices"): + all_filenames = list(data_loader.dataset.volume_indices.keys()) + num_for_this_process = len(list(data_loader.batch_sampler.sampler.volume_indices.keys())) + self.logger.info( + f"Reconstructing a total of {len(all_filenames)} volumes. " + f"This process has {num_for_this_process} volumes (world size: {communication.get_world_size()})." + ) + + filenames_seen = 0 + reconstruction_output: DefaultDict = defaultdict(list) + if is_validation_process: + targets_output: DefaultDict = defaultdict(list) + val_losses = [] + val_volume_metrics: Dict[PathLike, Dict] = defaultdict(dict) + last_filename = None + + # Container to for the slices which can be visualized in TensorBoard. 
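+        # Only the center slice of each reconstructed volume is stored, and at most
+        # cfg.logging.tensorboard.num_images slices are kept.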
+ visualize_slices: List[np.ndarray] = [] + visualize_target: List[np.ndarray] = [] + # visualizations = {} + + extra_visualization_keys = ( + self.cfg.logging.log_as_image if self.cfg.logging.log_as_image else [] # type: ignore + ) + + # Loop over dataset. This requires the use of direct.data.sampler.DistributedSequentialSampler as this sampler + # splits the data over the different processes, and outputs the slices linearly. The implicit assumption here is + # that the slices are outputted from the Dataset *sequentially* for each volume one by one, and each batch only + # contains data from one volume. + time_start = time.time() + + for iter_idx, data in enumerate(data_loader): + filenames = data.pop("filename") + if len(set(filenames)) != 1: + raise ValueError( + f"Expected a batch during validation to only contain filenames of one case. " + f"Got {set(filenames)}." + ) + + slice_nos = data.pop("slice_no") + scaling_factors = data["scaling_factor"] + + resolution = self.compute_resolution( + key=self.cfg.validation.crop, # type: ignore + reconstruction_size=data.get("reconstruction_size", None), + ) + + # Compute output and loss. + iteration_output = self._do_iteration(data, loss_fns, regularizer_fns=regularizer_fns) + output = iteration_output.output_image + loss_dict = iteration_output.data_dict + + loss_dict = detach_dict(loss_dict) + output = output.detach() + val_losses.append(loss_dict) + + # Output is complex-valued, and has to be cropped. This holds for both output and target. + # Output has shape (batch, complex, height, width) + output_abs = self.process_output( + output, + scaling_factors, + resolution=resolution, + ) + + if is_validation_process: + # Target has shape (batch, height, width) + target_abs = self.process_output( + data["target"].detach(), + scaling_factors, + resolution=resolution, + ) + for key in extra_visualization_keys: + curr_data = data[key].detach() + # Here we need to discover which keys are actually normalized or not + # this requires a solution to issue #23: https://github.com/directgroup/direct/issues/23 + + del output # Explicitly call delete to clear memory. + + # Aggregate volumes to be able to compute the metrics on complete volumes. + for idx, filename in enumerate(filenames): + if last_filename is None: + last_filename = filename # First iteration last_filename is not set. + + curr_slice = output_abs[idx].detach() + slice_no = int(slice_nos[idx].numpy()) + + reconstruction_output[filename].append((slice_no, curr_slice.cpu())) + + if is_validation_process: + targets_output[filename].append((slice_no, target_abs[idx].cpu())) + + is_last_element_of_last_batch = iter_idx + 1 == len(data_loader) and idx + 1 == len(data["target"]) + reconstruction_conditions = [filename != last_filename, is_last_element_of_last_batch] + for condition in reconstruction_conditions: + if condition: + filenames_seen += 1 + + # Now we can ditch the reconstruction dict by reconstructing the volume, + # will take too much memory otherwise. 
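+                        # Slices were appended in order (the sampler outputs them sequentially per
+                        # volume), so stacking them directly reconstructs the volume of the
+                        # previous filename.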
+ volume = torch.stack([_[1] for _ in reconstruction_output[last_filename]]) + if is_validation_process: + target = torch.stack([_[1] for _ in targets_output[last_filename]]) + curr_metrics = { + metric_name: metric_fn(target, volume) + for metric_name, metric_fn in volume_metrics.items() + } + val_volume_metrics[last_filename] = curr_metrics + # Log the center slice of the volume + if len(visualize_slices) < self.cfg.logging.tensorboard.num_images: # type: ignore + visualize_slices.append(volume[volume.shape[0] // 2]) + visualize_target.append(target[target.shape[0] // 2]) + + # Delete outputs from memory, and recreate dictionary. + # This is not needed when not in validation as we are actually interested + # in the iteration output. + del targets_output[last_filename] + del reconstruction_output[last_filename] + + if all_filenames: + log_prefix = f"{filenames_seen} of {num_for_this_process} volumes reconstructed:" + else: + log_prefix = f"{iter_idx + 1} of {len(data_loader)} slices reconstructed:" + + self.logger.info( + f"{log_prefix} {last_filename}" + f" (shape = {list(volume.shape)}) in {time.time() - time_start:.3f}s." + ) + # restart timer + time_start = time.time() + last_filename = filename + + # Average loss dict + loss_dict = reduce_list_of_dicts(val_losses) + reduce_tensor_dict(loss_dict) + + communication.synchronize() + torch.cuda.empty_cache() + + all_gathered_metrics = merge_list_of_dicts(communication.all_gather(val_volume_metrics)) + if not is_validation_process: + return loss_dict, reconstruction_output + + return loss_dict, all_gathered_metrics, visualize_slices, visualize_target + + def process_output(self, data, scaling_factors=None, resolution=None): + # data is of shape (batch, complex=2, height, width) + if scaling_factors is not None: + data = data * scaling_factors.view(-1, *((1,) * (len(data.shape) - 1))).to(data.device) + + data = T.modulus_if_complex(data) + + if len(data.shape) == 3: # (batch, height, width) + data = data.unsqueeze(1) # Added channel dimension. + + if resolution is not None: + data = T.center_crop(data, resolution).contiguous() + + return data + + @staticmethod + def compute_resolution(key, reconstruction_size): + if key == "header": + # This will be of the form [tensor(x_0, x_1, ...), tensor(y_0, y_1,...), tensor(z_0, z_1, ...)] over + # batches. + resolution = [_.detach().cpu().numpy().tolist() for _ in reconstruction_size] + # The volume sampler should give validation indices belonging to the *same* volume, so it should be + # safe taking the first element, the matrix size are in x,y,z (we work in z,x,y). + resolution = [_[0] for _ in resolution][:-1] + elif key == "training": + resolution = key + elif not key: + resolution = None + else: + raise ValueError( + "Cropping should be either set to `header` to get the values from the header or " + "`training` to take the same value as training." + ) + return resolution + + def cropper(self, source, target, resolution): + """ + 2D source/target cropper + + Parameters: + ----------- + Source has shape (batch, height, width) + Target has shape (batch, height, width) + + """ + + if not resolution or all(_ == 0 for _ in resolution): + return source.unsqueeze(1), target.unsqueeze(1) # Added channel dimension. + + source_abs = T.center_crop(source, resolution).unsqueeze(1) # Added channel dimension. + target_abs = T.center_crop(target, resolution).unsqueeze(1) # Added channel dimension. 
+ + return source_abs, target_abs + + def compute_model_per_coil(self, model_name, data): + """ + Computes model per coil. + """ + output = [] + + for idx in range(data.size(self._coil_dim)): + subselected_data = data.select(self._coil_dim, idx) + output.append(self.models[model_name](subselected_data)) + output = torch.stack(output, dim=self._coil_dim) + + return output diff --git a/direct/nn/multidomainnet/tests/__init__.py b/direct/nn/multidomainnet/tests/__init__.py new file mode 100644 index 00000000..941752c9 --- /dev/null +++ b/direct/nn/multidomainnet/tests/__init__.py @@ -0,0 +1,2 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors diff --git a/direct/nn/multidomainnet/tests/test_multidomainnet.py b/direct/nn/multidomainnet/tests/test_multidomainnet.py new file mode 100644 index 00000000..fecbdea4 --- /dev/null +++ b/direct/nn/multidomainnet/tests/test_multidomainnet.py @@ -0,0 +1,83 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors + +import pytest +import torch + +from direct.data.transforms import fft2, ifft2 +from direct.nn.multidomainnet.multidomain import MultiDomainUnet2d +from direct.nn.multidomainnet.multidomainnet import MultiDomainNet + + +def create_input(shape): + + data = torch.rand(shape).float() + + return data + + +@pytest.mark.parametrize( + "shape", + [ + [2, 2, 16, 16], + [4, 2, 16, 32], + [3, 2, 32, 32], + [3, 2, 40, 20], + ], +) +@pytest.mark.parametrize( + "num_filters", + [4, 8, 16], # powers of 2 +) +@pytest.mark.parametrize( + "num_pool_layers", + [2, 3], +) +def test_multidomainunet2d(shape, num_filters, num_pool_layers): + model = MultiDomainUnet2d( + fft2, + ifft2, + shape[1], + shape[1], + num_filters=num_filters, + num_pool_layers=num_pool_layers, + dropout_probability=0.05, + ).cpu() + + data = create_input(shape).cpu() + + out = model(data) + + assert list(out.shape) == shape + + +@pytest.mark.parametrize( + "shape", + [ + [2, 2, 16, 16], + [4, 2, 16, 32], + [3, 2, 32, 32], + [3, 2, 40, 20], + ], +) +@pytest.mark.parametrize("standardization", [True, False]) +@pytest.mark.parametrize( + "num_filters", + [4, 8], # powers of 2 +) +@pytest.mark.parametrize( + "num_pool_layers", + [2, 3], +) +def test_multidomainnet(shape, standardization, num_filters, num_pool_layers): + + model = MultiDomainNet(fft2, ifft2, standardization, num_filters, num_pool_layers) + + shape = shape + [2] + + kspace = create_input(shape).cpu() + sens = create_input(shape).cpu() + + out = model(kspace, sens) + + assert list(out.shape) == shape diff --git a/direct/nn/mwcnn/__init__.py b/direct/nn/mwcnn/__init__.py new file mode 100644 index 00000000..941752c9 --- /dev/null +++ b/direct/nn/mwcnn/__init__.py @@ -0,0 +1,2 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors diff --git a/direct/nn/mwcnn/mwcnn.py b/direct/nn/mwcnn/mwcnn.py new file mode 100644 index 00000000..d6a27c73 --- /dev/null +++ b/direct/nn/mwcnn/mwcnn.py @@ -0,0 +1,328 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors + +from collections import OrderedDict +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class DWT(nn.Module): + """ + 2D Discrete Wavelet Transform as implemented in https://arxiv.org/abs/1805.07071. 
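+
+    A single-level Haar decomposition: an input of shape (N, C, H, W) is split into the
+    LL, HL, LH and HH subbands, which are concatenated along the channel dimension,
+    giving an output of shape (N, 4*C, H/2, W/2).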
+ """ + + def __init__(self): + super().__init__() + self.requires_grad = False + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x01 = x[:, :, 0::2, :] / 2 + x02 = x[:, :, 1::2, :] / 2 + x1 = x01[:, :, :, 0::2] + x2 = x02[:, :, :, 0::2] + x3 = x01[:, :, :, 1::2] + x4 = x02[:, :, :, 1::2] + x_LL = x1 + x2 + x3 + x4 + x_HL = -x1 - x2 + x3 + x4 + x_LH = -x1 + x2 - x3 + x4 + x_HH = x1 - x2 - x3 + x4 + + return torch.cat((x_LL, x_HL, x_LH, x_HH), 1) + + +class IWT(nn.Module): + """ + 2D Inverse Wavelet Transform as implemented in https://arxiv.org/abs/1805.07071. + """ + + def __init__(self): + super().__init__() + self.requires_grad = False + self._r = 2 + + def forward(self, x: torch.Tensor) -> torch.Tensor: + + batch, in_channel, in_height, in_width = x.size() + out_channel, out_height, out_width = int(in_channel / (self._r ** 2)), self._r * in_height, self._r * in_width + + x1 = x[:, 0:out_channel, :, :] / 2 + x2 = x[:, out_channel : out_channel * 2, :, :] / 2 + x3 = x[:, out_channel * 2 : out_channel * 3, :, :] / 2 + x4 = x[:, out_channel * 3 : out_channel * 4, :, :] / 2 + + h = torch.zeros([batch, out_channel, out_height, out_width], dtype=x.dtype).to(x.device) + + h[:, :, 0::2, 0::2] = x1 - x2 - x3 + x4 + h[:, :, 1::2, 0::2] = x1 - x2 + x3 - x4 + h[:, :, 0::2, 1::2] = x1 + x2 - x3 - x4 + h[:, :, 1::2, 1::2] = x1 + x2 + x3 + x4 + + return h + + +class ConvBlock(nn.Module): + """ + Convolution Block for MWCNN as implemented in https://arxiv.org/abs/1805.07071. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + bias: bool = True, + batchnorm: bool = False, + activation: nn.Module = nn.ReLU(True), + scale: Optional[float] = 1.0, + ): + super().__init__() + + net = [] + net.append( + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + bias=bias, + padding=kernel_size // 2, + ) + ) + if batchnorm: + net.append(nn.BatchNorm2d(num_features=out_channels, eps=1e-4, momentum=0.95)) + net.append(activation) + + self.net = nn.Sequential(*net) + self.scale = scale + + def forward(self, input: torch.Tensor) -> torch.Tensor: + output = self.net(input) * self.scale + return output + + +class DilatedConvBlock(nn.Module): + """ + Double dilated Convolution Block fpr MWCNN as implemented in https://arxiv.org/abs/1805.07071. 
+ """ + + def __init__( + self, + in_channels: int, + dilations: Tuple[int, int], + kernel_size: int, + out_channels: Optional[int] = None, + bias: bool = True, + batchnorm: bool = False, + activation: nn.Module = nn.ReLU(True), + scale: Optional[float] = 1.0, + ): + super().__init__() + net = [] + net.append( + nn.Conv2d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=kernel_size, + bias=bias, + dilation=dilations[0], + padding=kernel_size // 2 + dilations[0] - 1, + ) + ) + if batchnorm: + net.append(nn.BatchNorm2d(num_features=in_channels, eps=1e-4, momentum=0.95)) + net.append(activation) + if out_channels is None: + out_channels = in_channels + net.append( + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + bias=bias, + dilation=dilations[1], + padding=kernel_size // 2 + dilations[1] - 1, + ) + ) + if batchnorm: + net.append(nn.BatchNorm2d(num_features=in_channels, eps=1e-4, momentum=0.95)) + net.append(activation) + + self.net = nn.Sequential(*net) + self.scale = scale + + def forward(self, input: torch.Tensor) -> torch.Tensor: + output = self.net(input) * self.scale + return output + + +class MWCNN(nn.Module): + """ + Multi-level Wavelet CNN (MWCNN) implementation as implemented in https://arxiv.org/abs/1805.07071. + """ + + def __init__( + self, + input_channels: int, + first_conv_hidden_channels: int, + num_scales: int = 4, + bias: bool = True, + batchnorm: bool = False, + activation: nn.Module = nn.ReLU(True), + ): + """ + + Parameters + ---------- + input_channels : int + Input channels dimension. + first_conv_hidden_channels : int + First convolution output channels dimension. + num_scales : int + Number of scales. Default: 4. + bias : bool + Convolution bias. If True, adds a learnable bias to the output. Default: True. + batchnorm : bool + If True, a batchnorm layer is added after each convolution. Default: False. + activation : nn.Module + Activation function applied after each convolution. Default: nn.ReLU(). 
+ """ + super().__init__() + self._kernel_size = 3 + self.DWT = DWT() + self.IWT = IWT() + + self.down = nn.ModuleList() + for idx in range(0, num_scales): + + in_channels = input_channels if idx == 0 else first_conv_hidden_channels * 2 ** (idx + 1) + out_channels = first_conv_hidden_channels * 2 ** idx + dilations = (2, 1) if idx != num_scales - 1 else (2, 3) + self.down.append( + nn.Sequential( + OrderedDict( + [ + ( + f"convblock{idx}", + ConvBlock( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=self._kernel_size, + bias=bias, + batchnorm=batchnorm, + activation=activation, + ), + ), + ( + f"dilconvblock{idx}", + DilatedConvBlock( + in_channels=out_channels, + dilations=dilations, + kernel_size=self._kernel_size, + bias=bias, + batchnorm=batchnorm, + activation=activation, + ), + ), + ] + ) + ) + ) + self.up = nn.ModuleList() + for idx in range(num_scales)[::-1]: + + in_channels = first_conv_hidden_channels * 2 ** idx + out_channels = input_channels if idx == 0 else first_conv_hidden_channels * 2 ** (idx + 1) + dilations = (2, 1) if idx != num_scales - 1 else (3, 2) + self.up.append( + nn.Sequential( + OrderedDict( + [ + ( + f"invdilconvblock{num_scales - 2 - idx}", + DilatedConvBlock( + in_channels=in_channels, + dilations=dilations, + kernel_size=self._kernel_size, + bias=bias, + batchnorm=batchnorm, + activation=activation, + ), + ), + ( + f"invconvblock{num_scales - 2 - idx}", + ConvBlock( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=self._kernel_size, + bias=bias, + batchnorm=batchnorm, + activation=activation, + ), + ), + ] + ) + ) + ) + self.num_scales = num_scales + + @staticmethod + def pad(x): + padding = [0, 0, 0, 0] + + if x.shape[-2] % 2 != 0: + padding[3] = 1 # Padding right - width + if x.shape[-1] % 2 != 0: + padding[1] = 1 # Padding bottom - height + if sum(padding) != 0: + x = F.pad(x, padding, "reflect") + return x + + @staticmethod + def crop_to_shape(x, shape): + h, w = x.shape[-2:] + + if h > shape[0]: + x = x[:, :, : shape[0], :] + if w > shape[1]: + x = x[:, :, :, : shape[1]] + return x + + def forward(self, input: torch.Tensor, res: bool = False) -> torch.Tensor: + """ + + Parameters + ---------- + input : torch.Tensor + Input tensor. + res : bool + If True, residual connection is applied to the output. Default: False. 
+ + Returns + ------- + torch.Tensor + """ + res_values = [] + x = self.pad(input.clone()) + for idx in range(self.num_scales): + if idx == 0: + x = self.pad(self.down[idx](x)) + res_values.append(x) + elif idx == self.num_scales - 1: + x = self.down[idx](self.DWT(x)) + else: + x = self.pad(self.down[idx](self.DWT(x))) + res_values.append(x) + + for idx in range(self.num_scales): + if idx != self.num_scales - 1: + x = ( + self.crop_to_shape(self.IWT(self.up[idx](x)), res_values[self.num_scales - 2 - idx].shape[-2:]) + + res_values[self.num_scales - 2 - idx] + ) + else: + x = self.crop_to_shape(self.up[idx](x), input.shape[-2:]) + if res: + x += input + return x diff --git a/direct/nn/mwcnn/tests/__init__.py b/direct/nn/mwcnn/tests/__init__.py new file mode 100644 index 00000000..941752c9 --- /dev/null +++ b/direct/nn/mwcnn/tests/__init__.py @@ -0,0 +1,2 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors diff --git a/direct/nn/mwcnn/tests/test_mwcnn.py b/direct/nn/mwcnn/tests/test_mwcnn.py new file mode 100644 index 00000000..dd05f7ec --- /dev/null +++ b/direct/nn/mwcnn/tests/test_mwcnn.py @@ -0,0 +1,52 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors + +import pytest +import torch +import torch.nn as nn + +from direct.nn.mwcnn.mwcnn import MWCNN + + +def create_input(shape): + + data = torch.rand(shape).float() + + return data + + +@pytest.mark.parametrize( + "shape", + [ + [3, 2, 32, 32], + [3, 2, 20, 34], + ], +) +@pytest.mark.parametrize( + "first_conv_hidden_channels", + [4, 8], +) +@pytest.mark.parametrize( + "n_scales", + [2, 3], +) +@pytest.mark.parametrize( + "bias", + [True, False], +) +@pytest.mark.parametrize( + "batchnorm", + [True, False], +) +@pytest.mark.parametrize( + "act", + [nn.ReLU(), nn.PReLU()], +) +def test_mwcnn(shape, first_conv_hidden_channels, n_scales, bias, batchnorm, act): + model = MWCNN(shape[1], first_conv_hidden_channels, n_scales, bias, batchnorm, act) + + data = create_input(shape).cpu() + + out = model(data) + + assert list(out.shape) == shape diff --git a/direct/nn/recurrent/__init__.py b/direct/nn/recurrent/__init__.py new file mode 100644 index 00000000..941752c9 --- /dev/null +++ b/direct/nn/recurrent/__init__.py @@ -0,0 +1,2 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors diff --git a/direct/nn/recurrent/recurrent.py b/direct/nn/recurrent/recurrent.py new file mode 100644 index 00000000..05234934 --- /dev/null +++ b/direct/nn/recurrent/recurrent.py @@ -0,0 +1,163 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors + +from typing import List, Optional, Tuple +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class Conv2dGRU(nn.Module): + """ + 2D Convolutional GRU Network. + """ + + def __init__( + self, + in_channels: int, + hidden_channels: int, + out_channels: Optional[int] = None, + num_layers: int = 2, + gru_kernel_size=1, + orthogonal_initialization: bool = True, + instance_norm: bool = False, + dense_connect: int = 0, + replication_padding: bool = True, + ): + """ + + Parameters + ---------- + in_channels : int + Number of input channels. + hidden_channels : int + Number of hidden channels. + out_channels : Optional[int] + Number of output channels. If None, same as in_channels. Default: None. + num_layers : int + Number of layers. Default: 2. + gru_kernel_size : int + Size of the GRU kernel. Default: 1. + orthogonal_initialization : bool + Orthogonal initialization is used if set to True. Default: True. + instance_norm : bool + Instance norm is used if set to True. Default: False. 
+ dense_connect : int + Number of dense connections. + replication_padding : bool + If set to true replication padding is applied. + """ + super().__init__() + + if out_channels is None: + out_channels = in_channels + + self.num_layers = num_layers + self.hidden_channels = hidden_channels + self.dense_connect = dense_connect + + self.reset_gates = nn.ModuleList([]) + self.update_gates = nn.ModuleList([]) + self.out_gates = nn.ModuleList([]) + self.conv_blocks = nn.ModuleList([]) + + # Create convolutional blocks + for idx in range(num_layers + 1): + in_ch = in_channels if idx == 0 else (1 + min(idx, dense_connect)) * hidden_channels + out_ch = hidden_channels if idx < num_layers else out_channels + padding = 0 if replication_padding else (2 if idx == 0 else 1) + block = [] + if replication_padding: + if idx == 1: + block.append(nn.ReplicationPad2d(2)) + else: + block.append(nn.ReplicationPad2d(2 if idx == 0 else 1)) + block.append( + nn.Conv2d( + in_channels=in_ch, + out_channels=out_ch, + kernel_size=5 if idx == 0 else 3, + dilation=(2 if idx == 1 else 1), + padding=padding, + ) + ) + self.conv_blocks.append(nn.Sequential(*block)) + + # Create GRU blocks + for idx in range(num_layers): + for gru_part in [self.reset_gates, self.update_gates, self.out_gates]: + block = [] + if instance_norm: + block.append(nn.InstanceNorm2d(2 * hidden_channels)) + block.append( + nn.Conv2d( + in_channels=2 * hidden_channels, + out_channels=hidden_channels, + kernel_size=gru_kernel_size, + padding=gru_kernel_size // 2, + ) + ) + gru_part.append(nn.Sequential(*block)) + + if orthogonal_initialization: + for reset_gate, update_gate, out_gate in zip(self.reset_gates, self.update_gates, self.out_gates): + nn.init.orthogonal_(reset_gate[-1].weight) + nn.init.orthogonal_(update_gate[-1].weight) + nn.init.orthogonal_(out_gate[-1].weight) + nn.init.constant_(reset_gate[-1].bias, -1.0) + nn.init.constant_(update_gate[-1].bias, 0.0) + nn.init.constant_(out_gate[-1].bias, 0.0) + + def forward( + self, + cell_input: torch.Tensor, + previous_state: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + + Parameters + ---------- + cell_input : torch.Tensor + Reconstruction input + previous_state : torch.Tensor + Tensor of previous states. 
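+            Hidden state of shape (N, hidden_channels, height, width, num_layers), or None to
+            start from a zero-initialized state.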
+ + Returns + ------- + (torch.Tensor, torch.Tensor) + """ + new_states: List[torch.Tensor] = [] + conv_skip: List[torch.Tensor] = [] + + if previous_state is None: + batch_size, spatial_size = cell_input.size(0), (cell_input.size(2), cell_input.size(3)) + state_size = [batch_size, self.hidden_channels] + list(spatial_size) + [self.num_layers] + previous_state = torch.zeros(*state_size, dtype=cell_input.dtype).to(cell_input.device) + + for idx in range(self.num_layers): + if len(conv_skip) > 0: + cell_input = F.relu( + self.conv_blocks[idx](torch.cat([*conv_skip[-self.dense_connect :], cell_input], dim=1)), + inplace=True, + ) + else: + cell_input = F.relu(self.conv_blocks[idx](cell_input), inplace=True) + if self.dense_connect > 0: + conv_skip.append(cell_input) + + stacked_inputs = torch.cat([cell_input, previous_state[:, :, :, :, idx]], dim=1) + + update = torch.sigmoid(self.update_gates[idx](stacked_inputs)) + reset = torch.sigmoid(self.reset_gates[idx](stacked_inputs)) + delta = torch.tanh( + self.out_gates[idx](torch.cat([cell_input, previous_state[:, :, :, :, idx] * reset], dim=1)) + ) + cell_input = previous_state[:, :, :, :, idx] * (1 - update) + delta * update + new_states.append(cell_input) + cell_input = F.relu(cell_input, inplace=False) + if len(conv_skip) > 0: + out = self.conv_blocks[self.num_layers](torch.cat([*conv_skip[-self.dense_connect :], cell_input], dim=1)) + else: + out = self.conv_blocks[self.num_layers](cell_input) + + return out, torch.stack(new_states, dim=-1) diff --git a/direct/nn/recurrent/tests/__init__.py b/direct/nn/recurrent/tests/__init__.py new file mode 100644 index 00000000..941752c9 --- /dev/null +++ b/direct/nn/recurrent/tests/__init__.py @@ -0,0 +1,2 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors diff --git a/direct/nn/recurrent/tests/test_recurrent.py b/direct/nn/recurrent/tests/test_recurrent.py new file mode 100644 index 00000000..aad78200 --- /dev/null +++ b/direct/nn/recurrent/tests/test_recurrent.py @@ -0,0 +1,34 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors + +import pytest +import torch + +from direct.nn.recurrent.recurrent import Conv2dGRU + + +def create_input(shape): + + data = torch.rand(shape).float() + + return data + + +@pytest.mark.parametrize( + "shape", + [ + [3, 2, 32, 32], + [3, 2, 16, 16], + ], +) +@pytest.mark.parametrize( + "hidden_channels", + [4, 8], +) +def test_conv2dgru(shape, hidden_channels): + model = Conv2dGRU(shape[1], hidden_channels, shape[1]) + data = create_input(shape).cpu() + + out = model(data, None)[0] + + assert list(out.shape) == shape diff --git a/direct/nn/rim/rim.py b/direct/nn/rim/rim.py index 6c20522c..6d903744 100644 --- a/direct/nn/rim/rim.py +++ b/direct/nn/rim/rim.py @@ -2,7 +2,7 @@ # Copyright (c) DIRECT Contributors import warnings -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable, Optional, Tuple import numpy as np import torch @@ -11,136 +11,18 @@ from direct.data import transforms as T from direct.utils.asserts import assert_positive_integer +from direct.nn.recurrent.recurrent import Conv2dGRU -class ConvGRUCell(nn.Module): +class MRILogLikelihood(nn.Module): """ - Convolutional GRU Cell to be used with RIM (Recurrent Inference Machines). + Defines the MRI loglikelihood assuming one noise vector for the complex images for all coils. + .. math:: + \frac{1}{\sigma^2} \sum_{i}^{\text{num coils}} + {S}_i^\{text{H}} \mathcal{F}^{-1} P^T (P \mathcal{F} S_i x_\tau - y_\tau) + for each time step :math: \tau. 
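+
+    Here :math:`P` denotes the sampling operator, :math:`\mathcal{F}` the forward (Fourier)
+    operator, :math:`S_i` the sensitivity map of coil :math:`i`, :math:`x_\tau` the current
+    image estimate and :math:`y_\tau` the masked k-space measurements.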
""" - def __init__( - self, - x_channels: int, - hidden_channels, - depth=2, - gru_kernel_size=1, - ortho_init: bool = True, - instance_norm: bool = False, - dense_connect=0, - replication_padding=False, - ): - super().__init__() - self.depth = depth - self.x_channels = x_channels - self.hidden_channels = hidden_channels - self.instance_norm = instance_norm - self.dense_connect = dense_connect - self.repl_pad = replication_padding - - self.reset_gates = nn.ModuleList([]) - self.update_gates = nn.ModuleList([]) - self.out_gates = nn.ModuleList([]) - self.conv_blocks = nn.ModuleList([]) - - # Create convolutional blocks of RIM cell - for idx in range(depth + 1): - in_ch = x_channels + 2 if idx == 0 else (1 + min(idx, dense_connect)) * hidden_channels - out_ch = hidden_channels if idx < depth else x_channels - pad = 0 if replication_padding else (2 if idx == 0 else 1) - block = [] - if replication_padding: - if idx == 1: - block.append(nn.ReplicationPad2d(2)) - else: - block.append(nn.ReplicationPad2d(2 if idx == 0 else 1)) - block.append( - nn.Conv2d( - in_ch, - out_ch, - 5 if idx == 0 else 3, - dilation=(2 if idx == 1 else 1), - padding=pad, - ) - ) - self.conv_blocks.append(nn.Sequential(*block)) - - # Create GRU blocks of RIM cell - for idx in range(depth): - for gru_part in [self.reset_gates, self.update_gates, self.out_gates]: - block = [] - if instance_norm: - block.append(nn.InstanceNorm2d(2 * hidden_channels)) - block.append( - nn.Conv2d( - 2 * hidden_channels, - hidden_channels, - gru_kernel_size, - padding=gru_kernel_size // 2, - ) - ) - gru_part.append(nn.Sequential(*block)) - - if ortho_init: - for reset_gate, update_gate, out_gate in zip(self.reset_gates, self.update_gates, self.out_gates): - nn.init.orthogonal_(reset_gate[-1].weight) - nn.init.orthogonal_(update_gate[-1].weight) - nn.init.orthogonal_(out_gate[-1].weight) - nn.init.constant_(reset_gate[-1].bias, -1.0) - nn.init.constant_(update_gate[-1].bias, 0.0) - nn.init.constant_(out_gate[-1].bias, 0.0) - - def forward( - self, - cell_input: torch.Tensor, - previous_state: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - - Parameters - ---------- - cell_input : torch.Tensor - Reconstruction input - previous_state : torch.Tensor - Tensor of previous stats. 
- - Returns - ------- - (torch.Tensor, torch.Tensor) - """ - - new_states: List[torch.Tensor] = [] - conv_skip: List[torch.Tensor] = [] - - for idx in range(self.depth): - if len(conv_skip) > 0: - cell_input = F.relu( - self.conv_blocks[idx](torch.cat([*conv_skip[-self.dense_connect :], cell_input], dim=1)), - inplace=True, - ) - else: - cell_input = F.relu(self.conv_blocks[idx](cell_input), inplace=True) - if self.dense_connect > 0: - conv_skip.append(cell_input) - - stacked_inputs = torch.cat([cell_input, previous_state[:, :, :, :, idx]], dim=1) - - update = torch.sigmoid(self.update_gates[idx](stacked_inputs)) - reset = torch.sigmoid(self.reset_gates[idx](stacked_inputs)) - delta = torch.tanh( - self.out_gates[idx](torch.cat([cell_input, previous_state[:, :, :, :, idx] * reset], dim=1)) - ) - cell_input = previous_state[:, :, :, :, idx] * (1 - update) + delta * update - new_states.append(cell_input) - cell_input = F.relu(cell_input, inplace=False) - if len(conv_skip) > 0: - out = self.conv_blocks[self.depth](torch.cat([*conv_skip[-self.dense_connect :], cell_input], dim=1)) - else: - out = self.conv_blocks[self.depth](cell_input) - - return out, torch.stack(new_states, dim=-1) - - -class MRILogLikelihood(nn.Module): def __init__( self, forward_operator: Callable, @@ -151,11 +33,8 @@ def __init__( self.forward_operator = forward_operator self.backward_operator = backward_operator - # TODO UGLY - self.ndim = 2 - self._coil_dim = 1 - self._spatial_dims = (2, 3) if self.ndim == 2 else (2, 3, 4) + self._spatial_dims = (2, 3) def forward( self, @@ -165,21 +44,17 @@ def forward( sampling_mask, loglikelihood_scaling=None, ) -> torch.Tensor: - r""" - Defines the MRI loglikelihood assuming one noise vector for the complex images for all coils. - $$ \frac{1}{\sigma^2} \sum_{i}^{\text{num coils}} - {S}_i^\{text{H}} \mathcal{F}^{-1} P^T (P \mathcal{F} S_i x_\tau - y_\tau)$$ - for each time step $\tau$ + """ Parameters ---------- input_image : torch.tensor Initial or previous iteration of image with complex first - of shape (batch, complex, [slice,] height, width). + of shape (N, complex, height, width). masked_kspace : torch.tensor - Masked k-space of shape (batch, coil, [slice,] height, width, complex). + Masked k-space of shape (N, coil, height, width, complex). sensitivity_map : torch.tensor - Sensitivity Map of shape (batch, coil, [slice,] height, width, complex). + Sensitivity Map of shape (N, coil, height, width, complex). sampling_mask : torch.tensor loglikelihood_scaling : torch.tensor Multiplier for loglikelihood, for instance for the k-space noise, of shape (1,). @@ -188,54 +63,57 @@ def forward( ------- torch.Tensor """ - if input_image.ndim == 5: - self.ndim = 3 - input_image = input_image.permute( - (0, 2, 3, 1) if self.ndim == 2 else (0, 2, 3, 4, 1) - ) # shape (batch, [slice,] height, width, complex) + input_image = input_image.permute(0, 2, 3, 1) # shape (N, height, width, complex) + if loglikelihood_scaling is not None: + loglikelihood_scaling = loglikelihood_scaling + else: + loglikelihood_scaling = torch.tensor([1.0], dtype=masked_kspace.dtype).to(masked_kspace.device) loglikelihood_scaling = loglikelihood_scaling.reshape( list(torch.ones(len(sensitivity_map.shape)).int()) - ) # shape (1, 1, 1, [1,] 1, 1) + ) # shape (1, 1, 1, 1, 1) # We multiply by the loglikelihood_scaling here to prevent fp16 information loss, # as this value is typically <<1, and the operators are linear. 
mul = loglikelihood_scaling * T.complex_multiplication( - sensitivity_map, input_image.unsqueeze(1) # (batch, 1, [slice,] height, width, complex) - ) # shape (batch, coil, [slice,] height, width, complex) + sensitivity_map, input_image.unsqueeze(1) # (N, 1, height, width, complex) + ) # shape (N, coil, height, width, complex) mr_forward = torch.where( sampling_mask == 0, torch.tensor([0.0], dtype=masked_kspace.dtype).to(masked_kspace.device), self.forward_operator(mul, dim=self._spatial_dims), - ) # shape (batch, coil, [slice], height, width, complex) + ) # shape (N, coil, height, width, complex) error = mr_forward - loglikelihood_scaling * torch.where( sampling_mask == 0, torch.tensor([0.0], dtype=masked_kspace.dtype).to(masked_kspace.device), masked_kspace, - ) # shape (batch, coil, [slice], height, width, complex) + ) # shape (N, coil, height, width, complex) - mr_backward = self.backward_operator( - error, dim=self._spatial_dims - ) # shape (batch, coil, [slice], height, width, complex) + mr_backward = self.backward_operator(error, dim=self._spatial_dims) # shape (N, coil, height, width, complex) if sensitivity_map is not None: out = T.complex_multiplication(T.conjugate(sensitivity_map), mr_backward).sum(self._coil_dim) else: out = mr_backward.sum(self._coil_dim) - # out has shape (batch, complex=2, [slice], height, width) + # out has shape (N, complex=2, height, width) - out = ( - out.permute(0, 3, 1, 2) if self.ndim == 2 else out.permute(0, 4, 1, 2, 3) - ) # complex first: shape (batch, [slice], height, width, complex=2) + out = out.permute(0, 3, 1, 2) # complex first: shape (N, height, width, complex=2) return out class RIMInit(nn.Module): + """ + Learned initializer for RIM, based on multi-scale context aggregation with dilated convolutions, that replaces + zero initializer for the RIM hidden vector. + + Inspired by "Multi-Scale Context Aggregation by Dilated Convolutions" (https://arxiv.org/abs/1511.07122) + """ + def __init__( self, x_ch: int, @@ -246,10 +124,6 @@ def __init__( multiscale_depth: int = 1, ): """ - Learned initializer for RIM, based on multi-scale context aggregation with dilated convolutions, that replaces - zero initializer for the RIM hidden vector. 
- - Inspired by "Multi-Scale Context Aggregation by Dilated Convolutions" (https://arxiv.org/abs/1511.07122) Parameters ---------- @@ -268,6 +142,7 @@ def __init__( """ super().__init__() + self.conv_blocks = nn.ModuleList() self.out_blocks = nn.ModuleList() self.depth = depth @@ -373,10 +248,11 @@ def __init__( self.no_parameter_sharing = no_parameter_sharing for _ in range(length if no_parameter_sharing else 1): self.cell_list.append( - ConvGRUCell( - x_channels, - hidden_channels, - depth=depth, + Conv2dGRU( + in_channels=x_channels * 2, # double channels as input is concatenated image and gradient + out_channels=x_channels, + hidden_channels=hidden_channels, + num_layers=depth, instance_norm=instance_norm, dense_connect=dense_connect, replication_padding=replication_padding, @@ -386,18 +262,21 @@ def __init__( self.length = length self.depth = depth - def compute_sense_init(self, kspace, sensitivity_map, spatial_dims=(2, 3), coil_dim=1): - # kspace is of shape: (batch, coil, [slice,] height, width, complex) - # sensitivity_map is of shape (batch, coil, [slice,] height, width, complex) + self._coil_dim = 1 + self._spatial_dims = (2, 3) + + def compute_sense_init(self, kspace, sensitivity_map): + # kspace is of shape: (N, coil, height, width, complex) + # sensitivity_map is of shape (N, coil, height, width, complex) input_image = T.complex_multiplication( T.conjugate(sensitivity_map), - self.backward_operator(kspace, dim=spatial_dims), - ) # shape (batch, coil, [slice,] height, width, complex=2) + self.backward_operator(kspace, dim=self._spatial_dims), + ) # shape (N, coil, height, width, complex=2) - input_image = input_image.sum(coil_dim) + input_image = input_image.sum(self._coil_dim) - # shape (batch, [slice,] height, width, complex=2) + # shape (N, height, width, complex=2) return input_image def forward( @@ -414,13 +293,13 @@ def forward( Parameters ---------- input_image : torch.Tensor - Initial or intermediate guess of input. Has shape (batch, [slice,] height, width, complex=2). + Initial or intermediate guess of input. Has shape (N, height, width, complex=2). masked_kspace : torch.Tensor - Kspace masked by the sampling mask. + Masked k-space of shape (N, coil, height, width, complex=2). sensitivity_map : torch.Tensor - Coil sensitivities. + Sensitivity map of shape (N, coil, height, width, complex=2). sampling_mask : torch.Tensor - Sampling mask. + Sampling mask of shape (N, 1, height, width, 1). previous_state : torch.Tensor loglikelihood_scaling : torch.Tensor Float tensor of shape (1,). @@ -429,13 +308,11 @@ def forward( ------- torch.Tensor """ - if input_image is None: if self.image_initialization == "sense": input_image = self.compute_sense_init( kspace=masked_kspace, sensitivity_map=sensitivity_map, - spatial_dims=(3, 4) if masked_kspace.ndim == 6 else (2, 3), ) elif self.image_initialization == "input_kspace": if "initial_kspace" not in kwargs: @@ -445,7 +322,6 @@ def forward( input_image = self.compute_sense_init( kspace=kwargs["initial_kspace"], sensitivity_map=sensitivity_map, - spatial_dims=(3, 4) if kwargs["initial_kspace"].ndim == 6 else (2, 3), ) elif self.image_initialization == "input_image": if "initial_image" not in kwargs: @@ -462,34 +338,25 @@ def forward( f"Unknown image_initialization. Expected `sense`, `input_kspace`, `'input_image` or `zero_filled`. " f"Got {self.image_initialization}." ) - # Provide an initialization for the first hidden state. 
if (self.initializer is not None) and (previous_state is None): previous_state = self.initializer( - input_image.permute((0, 4, 1, 2, 3) if input_image.ndim == 5 else (0, 3, 1, 2)) - ) # permute to (batch, complex, [slice], height, width), + input_image.permute(0, 3, 1, 2) + ) # permute to (N, complex, height, width), # TODO: This has to be made contiguous - # TODO(gy): Do 3D data pass from here? If not remove if statements below and [slice,] from comments. - input_image = input_image.permute( - (0, 4, 1, 2, 3) if input_image.ndim == 5 else (0, 3, 1, 2) - ).contiguous() # shape (batch, , complex=2, [slice,] height, width) + input_image = input_image.permute(0, 3, 1, 2).contiguous() # shape (N, complex=2, height, width) batch_size = input_image.size(0) - spatial_shape = ( - [input_image.size(-3), input_image.size(-2), input_image.size(-1)] - if input_image.ndim == 5 - else [input_image.size(-2), input_image.size(-1)] - ) - + spatial_shape = [input_image.size(self._spatial_dims[0]), input_image.size(self._spatial_dims[1])] # Initialize zero state for RIM state_size = [batch_size, self.hidden_channels] + list(spatial_shape) + [self.depth] if previous_state is None: - # shape (batch, hidden_channels, [slice,] height, width, depth) + # shape (N, hidden_channels, height, width, depth) previous_state = torch.zeros(*state_size, dtype=input_image.dtype).to(input_image.device) cell_outputs = [] - intermediate_image = input_image # shape (batch, , complex=2, [slice,] height, width) + intermediate_image = input_image # shape (N, complex=2, height, width) for cell_idx in range(self.length): cell = self.cell_list[cell_idx] if self.no_parameter_sharing else self.cell_list[0] @@ -500,7 +367,7 @@ def forward( sensitivity_map, sampling_mask, loglikelihood_scaling, - ) # shape (batch, , complex=2, [slice,] height, width) + ) # shape (N, complex=2, height, width) if grad_loglikelihood.abs().max() > 150.0: warnings.warn( @@ -511,16 +378,16 @@ def forward( cell_input = torch.cat( [intermediate_image, grad_loglikelihood], dim=1, - ) # shape (batch, , complex=4, [slice,] height, width) + ) # shape (N, complex=4, height, width) cell_output, previous_state = cell(cell_input, previous_state) - # shapes (batch, complex=2, [slice,] height, width), (batch, hidden_channels, [slice,] height, width, depth) + # shapes (N, complex=2, height, width), (N, hidden_channels, height, width, depth) if self.skip_connections: - # shape (batch, complex=2, [slice,] height, width) + # shape (N, complex=2, height, width) intermediate_image = intermediate_image + cell_output else: - # shape (batch, complex=2, [slice,] height, width) + # shape (N, complex=2, height, width) intermediate_image = cell_output if not self.training: diff --git a/direct/nn/rim/rim_engine.py b/direct/nn/rim/rim_engine.py index d30a70e9..522d24be 100644 --- a/direct/nn/rim/rim_engine.py +++ b/direct/nn/rim/rim_engine.py @@ -232,7 +232,7 @@ def ssim_loss(source, reduction="mean", **data): source_abs, target_abs = self.cropper(source, data["target"], resolution) data_range = torch.tensor([target_abs.max()], device=target_abs.device) - ssim_loss = SSIMLoss().to(source_abs.device)(source_abs, target_abs, data_range=data_range) + ssim_loss = SSIMLoss().to(source_abs.device).forward(source_abs, target_abs, data_range=data_range) return ssim_loss diff --git a/direct/nn/rim/tests/test_mri_models.py b/direct/nn/rim/tests/test_mri_models.py deleted file mode 100644 index 7301dc3a..00000000 --- a/direct/nn/rim/tests/test_mri_models.py +++ /dev/null @@ -1,93 +0,0 @@ -# 
coding=utf-8 -# Copyright (c) DIRECT Contributors - -import warnings - -import numpy as np -import torch - -from direct.data import transforms as T -from direct.data.transforms import tensor_to_complex_numpy - -warnings.filterwarnings("ignore") - - -def numpy_fft(data, dims=(-2, -1)): - """ - Fast Fourier Transform. - """ - data = np.fft.ifftshift(data, dims) - out = np.fft.fft2(data, norm="ortho") - out = np.fft.fftshift(out, dims) - - return out - - -def numpy_ifft(data, dims=(-2, -1)): - """ - Inverse Fast Fourier Transform. - """ - data = np.fft.ifftshift(data, dims) - out = np.fft.ifft2(data, norm="ortho") - out = np.fft.fftshift(out, dims) - - return out - - -def create_input(shape): - data = np.arange(np.product(shape)).reshape(shape).copy() - data = torch.from_numpy(data).float() - - return data - - -batch, coil, height, width, complex = 3, 15, 100, 80, 2 - -input_image = create_input([batch, height, width, complex]) -sensitivity_map = create_input([batch, coil, height, width, complex]) * 0.1 -masked_kspace = create_input([batch, coil, height, width, complex]) + 0.33 -sampling_mask = torch.from_numpy(np.random.binomial(size=(batch, 1, height, width, 1), n=1, p=0.5)) - -input_image_numpy = tensor_to_complex_numpy(input_image) -sensitivity_map_numpy = tensor_to_complex_numpy(sensitivity_map) -masked_kspace_numpy = tensor_to_complex_numpy(masked_kspace) -sampling_mask_numpy = sampling_mask.numpy()[..., 0] - -mul = T.complex_multiplication(sensitivity_map, input_image.unsqueeze(1)) - -dims = (2, 3) -mr_forward = torch.where( - sampling_mask == 0, - torch.tensor([0.0], dtype=masked_kspace.dtype).to(masked_kspace.device), - T.fft2(mul, dim=dims), -) - -error = mr_forward - torch.where( - sampling_mask == 0, - torch.tensor([0.0], dtype=masked_kspace.dtype).to(masked_kspace.device), - masked_kspace, -) - -mr_backward = T.ifft2(error, dim=dims) - -coil_dim = 1 -out = T.complex_multiplication(T.conjugate(sensitivity_map), mr_backward).sum(coil_dim) - - -# numpy -mul_numpy = sensitivity_map_numpy * input_image_numpy.reshape(batch, 1, height, width) -mr_forward_numpy = sampling_mask_numpy * numpy_fft(mul_numpy) -error_numpy = mr_forward_numpy - sampling_mask_numpy * masked_kspace_numpy -mr_backward_numpy = numpy_ifft(error_numpy) -out_numpy = (sensitivity_map_numpy.conjugate() * mr_backward_numpy).sum(1) - -np.allclose(tensor_to_complex_numpy(out), out_numpy) - -# numpy 2 -mr_backward_numpy = numpy_ifft( - sampling_mask_numpy * numpy_fft(sensitivity_map_numpy * input_image_numpy[:, np.newaxis, ...]) - - sampling_mask_numpy * masked_kspace_numpy -) -out_numpy = (sensitivity_map_numpy.conjugate() * mr_backward_numpy).sum(1) - -np.allclose(tensor_to_complex_numpy(out), out_numpy) diff --git a/direct/nn/rim/tests/test_rim.py b/direct/nn/rim/tests/test_rim.py new file mode 100644 index 00000000..3e5cfba8 --- /dev/null +++ b/direct/nn/rim/tests/test_rim.py @@ -0,0 +1,92 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors + +import pytest +import torch + +from direct.data.transforms import fft2, ifft2 +from direct.nn.rim.rim import RIM + + +def create_input(shape): + + data = torch.rand(shape).float() + + return data + + +@pytest.mark.parametrize( + "shape", + [ + [3, 3, 16, 16], + [2, 5, 16, 32], + ], +) +@pytest.mark.parametrize( + "hidden_channels", + [4, 8], +) +@pytest.mark.parametrize( + "length", + [3], +) +@pytest.mark.parametrize( + "depth", + [1, 2], +) +@pytest.mark.parametrize( + "no_parameter_sharing", + [True, False], +) +@pytest.mark.parametrize( + "instance_norm", + [True, False], 
+) +@pytest.mark.parametrize( + "dense_connect", + [True, False], +) +@pytest.mark.parametrize( + "skip_connections", + [True, False], +) +@pytest.mark.parametrize( + "image_init", + [ + "zero-filled", + "sense", + "input-kspace", + ], +) +def test_rim( + shape, + hidden_channels, + length, + depth, + no_parameter_sharing, + instance_norm, + dense_connect, + skip_connections, + image_init, +): + model = RIM( + fft2, + ifft2, + hidden_channels=hidden_channels, + length=length, + depth=depth, + no_parameter_sharing=no_parameter_sharing, + instance_norm=instance_norm, + dense_connect=dense_connect, + skip_connections=skip_connections, + image_initialization=image_init, + ).cpu() + + img = create_input([shape[0]] + shape[2:] + [2]).cpu() + kspace = create_input(shape + [2]).cpu() + sens = create_input(shape + [2]).cpu() + mask = create_input([shape[0]] + [1] + shape[2:] + [1]).round().int().cpu() + + out = model(img, kspace, mask, sens)[0][-1] + + assert list(out.shape) == [shape[0]] + [2] + shape[2:] diff --git a/direct/nn/unet/config.py b/direct/nn/unet/config.py index 4e3f7683..aa3e46d8 100644 --- a/direct/nn/unet/config.py +++ b/direct/nn/unet/config.py @@ -12,3 +12,13 @@ class UnetModel2dConfig(ModelConfig): num_filters: int = 16 num_pool_layers: int = 4 dropout_probability: float = 0.0 + + +@dataclass +class Unet2dConfig(ModelConfig): + num_filters: int = 16 + num_pool_layers: int = 4 + dropout_probability: float = 0.0 + skip_connection: bool = False + normalized: bool = False + image_initialization: str = "zero_filled" diff --git a/direct/nn/unet/tests/__init__.py b/direct/nn/unet/tests/__init__.py new file mode 100644 index 00000000..941752c9 --- /dev/null +++ b/direct/nn/unet/tests/__init__.py @@ -0,0 +1,2 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors diff --git a/direct/nn/unet/tests/test_unet_2d.py b/direct/nn/unet/tests/test_unet_2d.py new file mode 100644 index 00000000..bf912643 --- /dev/null +++ b/direct/nn/unet/tests/test_unet_2d.py @@ -0,0 +1,61 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors + + +import numpy as np +import pytest +import torch + +from direct.data.transforms import fft2, ifft2 +from direct.nn.unet.unet_2d import Unet2d, NormUnetModel2d + + +def create_input(shape): + data = np.random.randn(*shape).copy() + data = torch.from_numpy(data).float() + + return data + + +@pytest.mark.parametrize( + "shape", + [ + [2, 3, 16, 16], + [4, 5, 16, 32], + [3, 4, 32, 32], + [3, 4, 40, 20], + ], +) +@pytest.mark.parametrize( + "num_filters", + [4, 6, 8], +) +@pytest.mark.parametrize( + "num_pool_layers", + [2, 3], +) +@pytest.mark.parametrize( + "skip", + [True, False], +) +@pytest.mark.parametrize( + "normalized", + [True, False], +) +def test_unet_2d(shape, num_filters, num_pool_layers, skip, normalized): + model = Unet2d( + fft2, + ifft2, + num_filters=num_filters, + num_pool_layers=num_pool_layers, + skip_connection=skip, + normalized=normalized, + dropout_probability=0.05, + ).cpu() + + data = create_input(shape + [2]).cpu() + sens = create_input(shape + [2]).cpu() + + out = model(data, sens) + + assert list(out.shape) == [shape[0]] + shape[2:] + [2] diff --git a/direct/nn/unet/unet_2d.py b/direct/nn/unet/unet_2d.py index 859c1916..18d27c88 100644 --- a/direct/nn/unet/unet_2d.py +++ b/direct/nn/unet/unet_2d.py @@ -1,12 +1,16 @@ # coding=utf-8 # Copyright (c) DIRECT Contributors -# Code borrowed / edited from: https://github.com/facebookresearch/fastMRI/blob/master/models/unet/unet_model.py +# Code borrowed / edited from: 
https://github.com/facebookresearch/fastMRI/blob/ +import math +from typing import Callable, List, Optional, Tuple import torch from torch import nn from torch.nn import functional as F +from direct.data import transforms as T + class ConvBlock(nn.Module): """ @@ -156,7 +160,7 @@ def __init__( self.up_conv = nn.ModuleList() self.up_transpose_conv = nn.ModuleList() - for i in range(num_pool_layers - 1): + for _ in range(num_pool_layers - 1): self.up_transpose_conv += [TransposeConvBlock(ch * 2, ch)] self.up_conv += [ConvBlock(ch * 2, ch, dropout_probability)] ch //= 2 @@ -209,3 +213,228 @@ def forward(self, input: torch.Tensor): output = conv(output) return output + + +class NormUnetModel2d(nn.Module): + """ + Implementation of a Normalized U-Net model. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + num_filters: int, + num_pool_layers: int, + dropout_probability: float, + norm_groups: int = 2, + ): + """ + + Parameters + ---------- + in_channels : int + Number of input channels to the u-net. + out_channels : int + Number of output channels to the u-net. + num_filters : int + Number of output channels of the first convolutional layer. + num_pool_layers : int + Number of down-sampling and up-sampling layers (depth). + dropout_probability : float + Dropout probability. + norm_groups : int, + Number of normalization groups. + """ + super().__init__() + + self.unet2d = UnetModel2d( + in_channels=in_channels, + out_channels=out_channels, + num_filters=num_filters, + num_pool_layers=num_pool_layers, + dropout_probability=dropout_probability, + ) + + self.norm_groups = norm_groups + + @staticmethod + def norm(input: torch.Tensor, groups: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # group norm + b, c, h, w = input.shape + input = input.reshape(b, groups, -1) + + mean = input.mean(-1, keepdim=True) + std = input.std(-1, keepdim=True) + + output = (input - mean) / std + output = output.reshape(b, c, h, w) + + return output, mean, std + + @staticmethod + def unnorm(input: torch.Tensor, mean: torch.Tensor, std: torch.Tensor, groups: int) -> torch.Tensor: + b, c, h, w = input.shape + input = input.reshape(b, groups, -1) + return (input * std + mean).reshape(b, c, h, w) + + @staticmethod + def pad(input: torch.Tensor) -> Tuple[torch.Tensor, Tuple[List[int], List[int], int, int]]: + _, _, h, w = input.shape + w_mult = ((w - 1) | 15) + 1 + h_mult = ((h - 1) | 15) + 1 + w_pad = [math.floor((w_mult - w) / 2), math.ceil((w_mult - w) / 2)] + h_pad = [math.floor((h_mult - h) / 2), math.ceil((h_mult - h) / 2)] + + output = F.pad(input, w_pad + h_pad) + return output, (h_pad, w_pad, h_mult, w_mult) + + @staticmethod + def unpad( + input: torch.Tensor, + h_pad: List[int], + w_pad: List[int], + h_mult: int, + w_mult: int, + ) -> torch.Tensor: + + return input[..., h_pad[0] : h_mult - h_pad[1], w_pad[0] : w_mult - w_pad[1]] + + def forward(self, input: torch.Tensor) -> torch.Tensor: + """ + + Parameters + ---------- + input : torch.Tensor + + Returns + ------- + torch.Tensor + """ + + output, mean, std = self.norm(input, self.norm_groups) + output, pad_sizes = self.pad(output) + output = self.unet2d(output) + + output = self.unpad(output, *pad_sizes) + output = self.unnorm(output, mean, std, self.norm_groups) + + return output + + +class Unet2d(nn.Module): + """ + PyTorch implementation of a U-Net model for MRI Reconstruction. 
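+
+    The masked k-space is first mapped to an image estimate: a zero-filled adjoint
+    reconstruction summed over the coil dimension, or a SENSE coil combination with
+    the conjugate sensitivity maps when image_initialization is "sense". The two
+    complex channels are then moved to the channel axis and passed through a 2D
+    U-Net (UnetModel2d, or NormUnetModel2d when normalized=True); with
+    skip_connection=True the initial image estimate is added back to the output.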
+ """ + + def __init__( + self, + forward_operator: Callable, + backward_operator: Callable, + num_filters: int, + num_pool_layers: int, + dropout_probability: float, + skip_connection: bool = False, + normalized: bool = False, + image_initialization: str = "zero_filled", + **kwargs, + ): + """ + + Parameters + ---------- + forward_operator : Callable + Forward Operator. + backward_operator : Callable + Backward Operator. + num_filters : int + Number of first layer filters. + num_pool_layers : int + Number of pooling layers. + dropout_probability : float + Dropout probability. + skip_connection : bool + If True, skip connection is used for the output. Default: False. + normalized : bool + If True, Normalized Unet is used. Default: False. + image_initialization : str + Type of image initialization. Default: "zero-filled". + kwargs: dict + """ + super().__init__() + extra_keys = kwargs.keys() + for extra_key in extra_keys: + if extra_key not in [ + "sensitivity_map_model", + "model_name", + ]: + raise ValueError(f"{type(self).__name__} got key `{extra_key}` which is not supported.") + if normalized: + self.unet = NormUnetModel2d( + in_channels=2, + out_channels=2, + num_filters=num_filters, + num_pool_layers=num_pool_layers, + dropout_probability=dropout_probability, + ) + else: + self.unet = UnetModel2d( + in_channels=2, + out_channels=2, + num_filters=num_filters, + num_pool_layers=num_pool_layers, + dropout_probability=dropout_probability, + ) + self.forward_operator = forward_operator + self.backward_operator = backward_operator + self.skip_connection = skip_connection + self.image_initialization = image_initialization + self._coil_dim = 1 + self._spatial_dims = (2, 3) + + def compute_sense_init(self, kspace, sensitivity_map): + input_image = T.complex_multiplication( + T.conjugate(sensitivity_map), + self.backward_operator(kspace, dim=self._spatial_dims), + ) + input_image = input_image.sum(self._coil_dim) + return input_image + + def forward( + self, + masked_kspace: torch.Tensor, + sensitivity_map: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + + Parameters + ---------- + masked_kspace : torch.Tensor + Masked k-space of shape (N, coil, height, width, complex=2). + sensitivity_map : torch.Tensor + Sensitivity map of shape (N, coil, height, width, complex=2). Default: None. + + Returns + ------- + torch.Tensor + Output image of shape (N, height, width, complex=2). + """ + if self.image_initialization == "sense": + if sensitivity_map is None: + raise ValueError("Expected sensitivity_map not to be None with 'sense' image_initialization.") + input_image = self.compute_sense_init( + kspace=masked_kspace, + sensitivity_map=sensitivity_map, + ) + elif self.image_initialization == "zero_filled": + input_image = self.backward_operator(masked_kspace).sum(self._coil_dim) + else: + raise ValueError( + f"Unknown image_initialization. Expected `sense` or `zero_filled`. " + f"Got {self.image_initialization}." 
+ ) + + output = self.unet(input_image.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) + if self.skip_connection: + output += input_image + return output diff --git a/direct/nn/unet/unet_engine.py b/direct/nn/unet/unet_engine.py new file mode 100644 index 00000000..2a4b7ef8 --- /dev/null +++ b/direct/nn/unet/unet_engine.py @@ -0,0 +1,469 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors + +import time +from collections import defaultdict +from os import PathLike +from typing import Callable, DefaultDict, Dict, List, Optional + +import numpy as np +import torch +from torch import nn +from torch.cuda.amp import autocast +from torch.nn import functional as F +from torch.utils.data import DataLoader + +import direct.data.transforms as T +from direct.config import BaseConfig +from direct.engine import DoIterationOutput, Engine +from direct.functionals import SSIMLoss +from direct.utils import ( + communication, + detach_dict, + dict_to_device, + merge_list_of_dicts, + multiply_function, + reduce_list_of_dicts, +) +from direct.utils.communication import reduce_tensor_dict + + +class Unet2dEngine(Engine): + """ + Unet2d Model Engine. + """ + + def __init__( + self, + cfg: BaseConfig, + model: nn.Module, + device: int, + forward_operator: Optional[Callable] = None, + backward_operator: Optional[Callable] = None, + mixed_precision: bool = False, + **models: nn.Module, + ): + super().__init__( + cfg, + model, + device, + forward_operator=forward_operator, + backward_operator=backward_operator, + mixed_precision=mixed_precision, + **models, + ) + + self._complex_dim = -1 + self._coil_dim = 1 + self._spatial_dims = (2, 3) + + def _do_iteration( + self, + data: Dict[str, torch.Tensor], + loss_fns: Optional[Dict[str, Callable]] = None, + regularizer_fns: Optional[Dict[str, Callable]] = None, + ) -> DoIterationOutput: + + # loss_fns can be done, e.g. during validation + if loss_fns is None: + loss_fns = {} + + if regularizer_fns is None: + regularizer_fns = {} + + loss_dicts = [] + regularizer_dicts = [] + + data = dict_to_device(data, self.device) + + if self.cfg.model.image_initialization == "sense": + # sensitivity_map of shape (batch, coil, height, width, complex=2) + sensitivity_map = data["sensitivity_map"] + + # Some things can be done with the sensitivity map here, e.g. 
apply a u-net + if "sensitivity_model" in self.models: + # Move channels to first axis + sensitivity_map = data["sensitivity_map"].permute( + (0, 1, 4, 2, 3) + ) # shape (batch, coil, complex=2, height, width) + + sensitivity_map = self.compute_model_per_coil("sensitivity_model", sensitivity_map).permute( + (0, 1, 3, 4, 2) + ) # has channel last: shape (batch, coil, height, width, complex=2) + + # The sensitivity map needs to be normalized such that + # So \sum_{i \in \text{coils}} S_i S_i^* = 1 + + sensitivity_map_norm = torch.sqrt(((sensitivity_map ** 2).sum(self._complex_dim)).sum(self._coil_dim)) + # shape (batch, 1, height, width, 1) + sensitivity_map_norm = sensitivity_map_norm.unsqueeze(self._coil_dim).unsqueeze(self._complex_dim) + data["sensitivity_map"] = T.safe_divide(sensitivity_map, sensitivity_map_norm) + + with autocast(enabled=self.mixed_precision): + + output_image = self.model( + masked_kspace=data["masked_kspace"], + sensitivity_map=data["sensitivity_map"] if self.cfg.model.image_initialization == "sense" else None, + ) + output_image = T.modulus(output_image) + + loss_dict = {k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in loss_fns.keys()} + regularizer_dict = { + k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in regularizer_fns.keys() + } + + for key, value in loss_dict.items(): + loss_dict[key] = value + loss_fns[key]( + output_image, + **data, + reduction="mean", + ) + + for key, value in regularizer_dict.items(): + regularizer_dict[key] = value + regularizer_fns[key]( + output_image, + **data, + ) + + loss = sum(loss_dict.values()) + sum(regularizer_dict.values()) + + if self.model.training: + self._scaler.scale(loss).backward() + + loss_dicts.append(detach_dict(loss_dict)) + regularizer_dicts.append( + detach_dict(regularizer_dict) + ) # Need to detach dict as this is only used for logging. + + # Add the loss dicts. + loss_dict = reduce_list_of_dicts(loss_dicts, mode="sum") + regularizer_dict = reduce_list_of_dicts(regularizer_dicts, mode="sum") + + return DoIterationOutput( + output_image=output_image, + sensitivity_map=data["sensitivity_map"], + data_dict={**loss_dict, **regularizer_dict}, + ) + + def build_loss(self, **kwargs) -> Dict: + # TODO: Cropper is a processing output tool. + def get_resolution(**data): + """Be careful that this will use the cropping size of the FIRST sample in the batch.""" + return self.compute_resolution(self.cfg.training.loss.crop, data.get("reconstruction_size", None)) + + def l1_loss(source, reduction="mean", **data): + """ + Calculate L1 loss given source and target. + + Parameters: + ----------- + Source: shape (batch, complex=2, height, width) + Data: Contains key "target" with value a tensor of shape (batch, height, width) + + """ + resolution = get_resolution(**data) + l1_loss = F.l1_loss(*self.cropper(source, data["target"], resolution), reduction=reduction) + + return l1_loss + + def l2_loss(source, reduction="mean", **data): + """ + Calculate L2 loss (MSE) given source and target. + + Parameters: + ----------- + Source: shape (batch, complex=2, height, width) + Data: Contains key "target" with value a tensor of shape (batch, height, width) + + """ + resolution = get_resolution(**data) + l2_loss = F.mse_loss(*self.cropper(source, data["target"], resolution), reduction=reduction) + + return l2_loss + + def ssim_loss(source, reduction="mean", **data): + """ + Calculate SSIM loss given source and target. 
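+
+            SSIM is evaluated on the center-cropped source/target pair returned by
+            self.cropper, with data_range taken from the maximum of the cropped
+            target; only reduction == "mean" is supported.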
+ + Parameters: + ----------- + Source: shape (batch, complex=2, height, width) + Data: Contains key "target" with value a tensor of shape (batch, height, width) + + """ + resolution = get_resolution(**data) + if reduction != "mean": + raise AssertionError( + f"SSIM loss can only be computed with reduction == 'mean'." f" Got reduction == {reduction}." + ) + + source_abs, target_abs = self.cropper(source, data["target"], resolution) + data_range = torch.tensor([target_abs.max()], device=target_abs.device) + + ssim_loss = SSIMLoss().to(source_abs.device).forward(source_abs, target_abs, data_range=data_range) + + return ssim_loss + + # Build losses + loss_dict = {} + for curr_loss in self.cfg.training.loss.losses: # type: ignore + loss_fn = curr_loss.function + if loss_fn == "l1_loss": + loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, l1_loss) + elif loss_fn == "l2_loss": + loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, l2_loss) + elif loss_fn == "ssim_loss": + loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, ssim_loss) + else: + raise ValueError(f"{loss_fn} not permissible.") + + return loss_dict + + @torch.no_grad() + def evaluate( + self, + data_loader: DataLoader, + loss_fns: Optional[Dict[str, Callable]], + regularizer_fns: Optional[Dict[str, Callable]] = None, + crop: Optional[str] = None, + is_validation_process: bool = True, + ): + """ + Validation process. Assumes that each batch only contains slices of the same volume *AND* that these + are sequentially ordered. + + Parameters + ---------- + data_loader : DataLoader + loss_fns : Dict[str, Callable], optional + regularizer_fns : Dict[str, Callable], optional + crop : str, optional + is_validation_process : bool + + Returns + ------- + loss_dict, all_gathered_metrics, visualize_slices, visualize_target + + """ + self.models_to_device() + self.models_validation_mode() + torch.cuda.empty_cache() + + # Variables required for evaluation. + volume_metrics = self.build_metrics(self.cfg.validation.metrics) # type: ignore + + # filenames can be in the volume_indices attribute of the dataset + num_for_this_process = None + all_filenames = None + if hasattr(data_loader.dataset, "volume_indices"): + all_filenames = list(data_loader.dataset.volume_indices.keys()) + num_for_this_process = len(list(data_loader.batch_sampler.sampler.volume_indices.keys())) + self.logger.info( + f"Reconstructing a total of {len(all_filenames)} volumes. " + f"This process has {num_for_this_process} volumes (world size: {communication.get_world_size()})." + ) + + filenames_seen = 0 + reconstruction_output: DefaultDict = defaultdict(list) + if is_validation_process: + targets_output: DefaultDict = defaultdict(list) + val_losses = [] + val_volume_metrics: Dict[PathLike, Dict] = defaultdict(dict) + last_filename = None + + # Container to for the slices which can be visualized in TensorBoard. + visualize_slices: List[np.ndarray] = [] + visualize_target: List[np.ndarray] = [] + # visualizations = {} + + extra_visualization_keys = ( + self.cfg.logging.log_as_image if self.cfg.logging.log_as_image else [] # type: ignore + ) + + # Loop over dataset. This requires the use of direct.data.sampler.DistributedSequentialSampler as this sampler + # splits the data over the different processes, and outputs the slices linearly. The implicit assumption here is + # that the slices are outputted from the Dataset *sequentially* for each volume one by one, and each batch only + # contains data from one volume. 
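+        # As a toy illustration of the aggregation below (the filename is hypothetical):
+        # slices arrive as ("case_A.h5", slice 0), ("case_A.h5", slice 1), ..., so
+        # reconstruction_output["case_A.h5"] grows as [(0, s0), (1, s1), ...] and the
+        # full volume is recovered with torch.stack([s for _, s in ...]) once the
+        # filename changes or the loader is exhausted.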
+ time_start = time.time() + + for iter_idx, data in enumerate(data_loader): + filenames = data.pop("filename") + if len(set(filenames)) != 1: + raise ValueError( + f"Expected a batch during validation to only contain filenames of one case. " + f"Got {set(filenames)}." + ) + + slice_nos = data.pop("slice_no") + scaling_factors = data["scaling_factor"] + + resolution = self.compute_resolution( + key=self.cfg.validation.crop, # type: ignore + reconstruction_size=data.get("reconstruction_size", None), + ) + + # Compute output and loss. + iteration_output = self._do_iteration(data, loss_fns, regularizer_fns=regularizer_fns) + output = iteration_output.output_image + loss_dict = iteration_output.data_dict + + loss_dict = detach_dict(loss_dict) + output = output.detach() + val_losses.append(loss_dict) + + # Output is complex-valued, and has to be cropped. This holds for both output and target. + # Output has shape (batch, complex, height, width) + output_abs = self.process_output( + output, + scaling_factors, + resolution=resolution, + ) + + if is_validation_process: + # Target has shape (batch, height, width) + target_abs = self.process_output( + data["target"].detach(), + scaling_factors, + resolution=resolution, + ) + for key in extra_visualization_keys: + curr_data = data[key].detach() + # Here we need to discover which keys are actually normalized or not + # this requires a solution to issue #23: https://github.com/directgroup/direct/issues/23 + + del output # Explicitly call delete to clear memory. + + # Aggregate volumes to be able to compute the metrics on complete volumes. + for idx, filename in enumerate(filenames): + if last_filename is None: + last_filename = filename # First iteration last_filename is not set. + + curr_slice = output_abs[idx].detach() + slice_no = int(slice_nos[idx].numpy()) + + reconstruction_output[filename].append((slice_no, curr_slice.cpu())) + + if is_validation_process: + targets_output[filename].append((slice_no, target_abs[idx].cpu())) + + is_last_element_of_last_batch = iter_idx + 1 == len(data_loader) and idx + 1 == len(data["target"]) + reconstruction_conditions = [filename != last_filename, is_last_element_of_last_batch] + for condition in reconstruction_conditions: + if condition: + filenames_seen += 1 + + # Now we can ditch the reconstruction dict by reconstructing the volume, + # will take too much memory otherwise. + volume = torch.stack([_[1] for _ in reconstruction_output[last_filename]]) + if is_validation_process: + target = torch.stack([_[1] for _ in targets_output[last_filename]]) + curr_metrics = { + metric_name: metric_fn(target, volume) + for metric_name, metric_fn in volume_metrics.items() + } + val_volume_metrics[last_filename] = curr_metrics + # Log the center slice of the volume + if len(visualize_slices) < self.cfg.logging.tensorboard.num_images: # type: ignore + visualize_slices.append(volume[volume.shape[0] // 2]) + visualize_target.append(target[target.shape[0] // 2]) + + # Delete outputs from memory, and recreate dictionary. + # This is not needed when not in validation as we are actually interested + # in the iteration output. + del targets_output[last_filename] + del reconstruction_output[last_filename] + + if all_filenames: + log_prefix = f"{filenames_seen} of {num_for_this_process} volumes reconstructed:" + else: + log_prefix = f"{iter_idx + 1} of {len(data_loader)} slices reconstructed:" + + self.logger.info( + f"{log_prefix} {last_filename}" + f" (shape = {list(volume.shape)}) in {time.time() - time_start:.3f}s." 
+ ) + # restart timer + time_start = time.time() + last_filename = filename + + # Average loss dict + loss_dict = reduce_list_of_dicts(val_losses) + reduce_tensor_dict(loss_dict) + + communication.synchronize() + torch.cuda.empty_cache() + + all_gathered_metrics = merge_list_of_dicts(communication.all_gather(val_volume_metrics)) + if not is_validation_process: + return loss_dict, reconstruction_output + + return loss_dict, all_gathered_metrics, visualize_slices, visualize_target + + def process_output(self, data, scaling_factors=None, resolution=None): + # data is of shape (batch, complex=2, height, width) + if scaling_factors is not None: + data = data * scaling_factors.view(-1, *((1,) * (len(data.shape) - 1))).to(data.device) + + data = T.modulus_if_complex(data) + + if len(data.shape) == 3: # (batch, height, width) + data = data.unsqueeze(1) # Added channel dimension. + + if resolution is not None: + data = T.center_crop(data, resolution).contiguous() + + return data + + @staticmethod + def compute_resolution(key, reconstruction_size): + if key == "header": + # This will be of the form [tensor(x_0, x_1, ...), tensor(y_0, y_1,...), tensor(z_0, z_1, ...)] over + # batches. + resolution = [_.detach().cpu().numpy().tolist() for _ in reconstruction_size] + # The volume sampler should give validation indices belonging to the *same* volume, so it should be + # safe taking the first element, the matrix size are in x,y,z (we work in z,x,y). + resolution = [_[0] for _ in resolution][:-1] + elif key == "training": + resolution = key + elif not key: + resolution = None + else: + raise ValueError( + "Cropping should be either set to `header` to get the values from the header or " + "`training` to take the same value as training." + ) + return resolution + + def cropper(self, source, target, resolution): + """ + 2D source/target cropper + + Parameters: + ----------- + Source has shape (batch, height, width) + Target has shape (batch, height, width) + + """ + + if not resolution or all(_ == 0 for _ in resolution): + return source.unsqueeze(1), target.unsqueeze(1) # Added channel dimension. + + source_abs = T.center_crop(source, resolution).unsqueeze(1) # Added channel dimension. + target_abs = T.center_crop(target, resolution).unsqueeze(1) # Added channel dimension. + + return source_abs, target_abs + + def compute_model_per_coil(self, model_name, data): + """ + Computes model per coil. 
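+
+        The model self.models[model_name] is applied to each coil independently by
+        selecting along the coil axis, and the per-coil outputs are stacked back
+        along that axis, so input and output both have shape
+        (batch, coil, complex=2, height, width).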
+ """ + # data is of shape (batch, coil, complex=2, height, width) + output = [] + + for idx in range(data.size(self._coil_dim)): + subselected_data = data.select(self._coil_dim, idx) + output.append(self.models[model_name](subselected_data)) + output = torch.stack(output, dim=self._coil_dim) + + # output is of shape (batch, coil, complex=2, height, width) + return output diff --git a/direct/nn/varnet/__init__.py b/direct/nn/varnet/__init__.py new file mode 100644 index 00000000..941752c9 --- /dev/null +++ b/direct/nn/varnet/__init__.py @@ -0,0 +1,2 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors diff --git a/direct/nn/varnet/config.py b/direct/nn/varnet/config.py new file mode 100644 index 00000000..14cc5e75 --- /dev/null +++ b/direct/nn/varnet/config.py @@ -0,0 +1,13 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors +from dataclasses import dataclass + +from direct.config.defaults import ModelConfig + + +@dataclass +class EndToEndVarNetConfig(ModelConfig): + num_layers: int = 8 + regularizer_num_filters: int = 18 + regularizer_num_pull_layers: int = 4 + regularizer_dropout: float = 0.0 diff --git a/direct/nn/varnet/tests/__init__.py b/direct/nn/varnet/tests/__init__.py new file mode 100644 index 00000000..941752c9 --- /dev/null +++ b/direct/nn/varnet/tests/__init__.py @@ -0,0 +1,2 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors diff --git a/direct/nn/varnet/tests/test_varnet.py b/direct/nn/varnet/tests/test_varnet.py new file mode 100644 index 00000000..67cbd59f --- /dev/null +++ b/direct/nn/varnet/tests/test_varnet.py @@ -0,0 +1,43 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors + +import pytest +import torch + +from direct.data.transforms import fft2, ifft2 +from direct.nn.varnet.varnet import EndToEndVarNet + + +def create_input(shape): + + data = torch.rand(shape).float() + + return data + + +@pytest.mark.parametrize( + "shape", + [[4, 3, 32, 32], [4, 5, 40, 20]], +) +@pytest.mark.parametrize( + "num_layers", + [2, 3, 6], +) +@pytest.mark.parametrize( + "num_filters", + [2, 4], +) +@pytest.mark.parametrize( + "num_pull_layers", + [2, 4], +) +def test_varnet(shape, num_layers, num_filters, num_pull_layers): + model = EndToEndVarNet(fft2, ifft2, num_layers, num_filters, num_pull_layers, in_channels=2).cpu() + + kspace = create_input(shape + [2]).cpu() + mask = create_input([shape[0]] + [1] + shape[2:] + [1]).round().int().cpu() + sens = create_input(shape + [2]).cpu() + + out = model(kspace, mask, sens) + + assert list(out.shape) == shape + [2] diff --git a/direct/nn/varnet/varnet.py b/direct/nn/varnet/varnet.py new file mode 100644 index 00000000..00ade925 --- /dev/null +++ b/direct/nn/varnet/varnet.py @@ -0,0 +1,176 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors + +from typing import Callable + +import torch +import torch.nn as nn + +from direct.data.transforms import expand_operator, reduce_operator +from direct.nn.unet import UnetModel2d + + +class EndToEndVarNet(nn.Module): + """ + End-to-End Variational Network as in https://arxiv.org/abs/2004.06688. + """ + + def __init__( + self, + forward_operator: Callable, + backward_operator: Callable, + num_layers: int, + regularizer_num_filters: int = 18, + regularizer_num_pull_layers: int = 4, + regularizer_dropout: float = 0.0, + in_channels: int = 2, + **kwargs, + ): + """ + Parameters: + ----------- + forward_operator : Callable + Forward Operator. + backward_operator : Callable + Backward Operator. + num_layers : int + Number of cascades. 
+ regularizer_num_filters : int + Regularizer model number of filters. + regularizer_num_pull_layers : int + Regularizer model number of pulling layers. + regularizer_dropout : float + Regularizer model dropout probability. + + """ + super().__init__() + extra_keys = kwargs.keys() + for extra_key in extra_keys: + if extra_key not in [ + "model_name", + ]: + raise ValueError(f"{type(self).__name__} got key `{extra_key}` which is not supported.") + + self.layers_list = nn.ModuleList() + + for _ in range(num_layers): + self.layers_list.append( + EndToEndVarNetBlock( + forward_operator=forward_operator, + backward_operator=backward_operator, + regularizer_model=UnetModel2d( + in_channels=in_channels, + out_channels=in_channels, + num_filters=regularizer_num_filters, + num_pool_layers=regularizer_num_pull_layers, + dropout_probability=regularizer_dropout, + ), + ) + ) + + def forward( + self, masked_kspace: torch.Tensor, sampling_mask: torch.Tensor, sensitivity_map: torch.Tensor + ) -> torch.Tensor: + """ + Parameters + ---------- + masked_kspace : torch.Tensor + Masked k-space of shape (N, coil, height, width, complex=2). + sampling_mask : torch.Tensor + Sampling mask of shape (N, 1, height, width, 1). + sensitivity_map : torch.Tensor + Sensitivity map of shape (N, coil, height, width, complex=2). + + Returns + ------- + kspace_prediction : torch.Tensor + K-space prediction of shape (N, coil, height, width, complex=2). + """ + + kspace_prediction = masked_kspace.clone() + for layer in self.layers_list: + kspace_prediction = layer(kspace_prediction, masked_kspace, sampling_mask, sensitivity_map) + return kspace_prediction + + +class EndToEndVarNetBlock(nn.Module): + """ + End-to-End Variational Network block. + """ + + def __init__( + self, + forward_operator: Callable, + backward_operator: Callable, + regularizer_model: nn.Module, + ): + """ + + Parameters: + ----------- + forward_operator : Callable + Forward Operator. + backward_operator : Callable + Backward Operator. + regularizer_model : nn.Module + Regularizer model. + """ + super().__init__() + self.regularizer_model = regularizer_model + self.forward_operator = forward_operator + self.backward_operator = backward_operator + self.learning_rate = nn.Parameter(torch.tensor([1.0])) + self._coil_dim = 1 + self._complex_dim = -1 + self._spatial_dims = (2, 3) + + def forward( + self, + current_kspace: torch.Tensor, + masked_kspace: torch.Tensor, + sampling_mask: torch.Tensor, + sensitivity_map: torch.Tensor, + ) -> torch.Tensor: + """ + + Parameters + ---------- + current_kspace : torch.Tensor + Current k-space prediction of shape (N, coil, height, width, complex=2). + masked_kspace : torch.Tensor + Masked k-space of shape (N, coil, height, width, complex=2). + sampling_mask : torch.Tensor + Sampling mask of shape (N, 1, height, width, 1). + sensitivity_map : torch.Tensor + Sensitivity map of shape (N, coil, height, width, complex=2). + + Returns + ------- + torch.Tensor + Next k-space prediction of shape (N, coil, height, width, complex=2). 
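+
+        Notes
+        -----
+        Written out, the block performs one unrolled update in k-space,
+
+            k_{t+1} = k_t - eta * M * (k_t - y)
+                      + F( expand_S( R( reduce_S( F^{-1}(k_t) ) ) ) ),
+
+        where k_t is current_kspace, y the masked k-space, M the binary sampling mask,
+        eta the learnable step size (self.learning_rate), F / F^{-1} the forward and
+        backward operators, R the regularizer model, and expand_S / reduce_S the
+        sensitivity-map expand and reduce operators. The torch.split calls below
+        simply apply this chunk-wise over pairs of real/imaginary channels.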
+ """ + kspace_error = torch.where( + sampling_mask == 0, + torch.tensor([0.0], dtype=masked_kspace.dtype).to(masked_kspace.device), + current_kspace - masked_kspace, + ) + regularization_term = torch.cat( + [ + reduce_operator( + self.backward_operator(kspace, dim=self._spatial_dims), sensitivity_map, dim=self._coil_dim + ) + for kspace in torch.split(current_kspace, 2, self._complex_dim) + ], + dim=self._complex_dim, + ).permute(0, 3, 1, 2) + regularization_term = self.regularizer_model(regularization_term).permute(0, 2, 3, 1) + regularization_term = torch.cat( + [ + self.forward_operator( + expand_operator(image, sensitivity_map, dim=self._coil_dim), dim=self._spatial_dims + ) + for image in torch.split(regularization_term, 2, self._complex_dim) + ], + dim=self._complex_dim, + ) + return current_kspace - self.learning_rate * kspace_error + regularization_term diff --git a/direct/nn/varnet/varnet_engine.py b/direct/nn/varnet/varnet_engine.py new file mode 100644 index 00000000..2323f394 --- /dev/null +++ b/direct/nn/varnet/varnet_engine.py @@ -0,0 +1,473 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors + +import time +from collections import defaultdict +from os import PathLike +from typing import Callable, DefaultDict, Dict, List, Optional + +import numpy as np +import torch +from torch import nn +from torch.cuda.amp import autocast +from torch.nn import functional as F +from torch.utils.data import DataLoader + +import direct.data.transforms as T +from direct.config import BaseConfig +from direct.engine import DoIterationOutput, Engine +from direct.functionals import SSIMLoss +from direct.utils import ( + communication, + detach_dict, + dict_to_device, + merge_list_of_dicts, + multiply_function, + reduce_list_of_dicts, +) +from direct.utils.communication import reduce_tensor_dict + + +class EndToEndVarNetEngine(Engine): + """ + End-to-End Variational Network Engine. + """ + + def __init__( + self, + cfg: BaseConfig, + model: nn.Module, + device: int, + forward_operator: Optional[Callable] = None, + backward_operator: Optional[Callable] = None, + mixed_precision: bool = False, + **models: nn.Module, + ): + super().__init__( + cfg, + model, + device, + forward_operator=forward_operator, + backward_operator=backward_operator, + mixed_precision=mixed_precision, + **models, + ) + + self._complex_dim = -1 + self._coil_dim = 1 + self._spatial_dims = (2, 3) + + def _do_iteration( + self, + data: Dict[str, torch.Tensor], + loss_fns: Optional[Dict[str, Callable]] = None, + regularizer_fns: Optional[Dict[str, Callable]] = None, + ) -> DoIterationOutput: + + # loss_fns can be done, e.g. 
during validation + if loss_fns is None: + loss_fns = {} + + if regularizer_fns is None: + regularizer_fns = {} + + loss_dicts = [] + regularizer_dicts = [] + + data = dict_to_device(data, self.device) + + # sensitivity_map of shape (batch, coil, height, width, complex=2) + sensitivity_map = data["sensitivity_map"] + + if "sensitivity_model" in self.models: + + # Move channels to first axis + sensitivity_map = data["sensitivity_map"].permute( + (0, 1, 4, 2, 3) + ) # shape (batch, coil, complex=2, height, width) + + sensitivity_map = self.compute_model_per_coil("sensitivity_model", sensitivity_map).permute( + (0, 1, 3, 4, 2) + ) # has channel last: shape (batch, coil, height, width, complex=2) + + # The sensitivity map needs to be normalized such that + # So \sum_{i \in \text{coils}} S_i S_i^* = 1 + + sensitivity_map_norm = torch.sqrt( + ((sensitivity_map ** 2).sum(self._complex_dim)).sum(self._coil_dim) + ) # shape (batch, height, width) + sensitivity_map_norm = sensitivity_map_norm.unsqueeze(1).unsqueeze(-1) + data["sensitivity_map"] = T.safe_divide(sensitivity_map, sensitivity_map_norm) + + with autocast(enabled=self.mixed_precision): + + output_kspace = self.model( + masked_kspace=data["masked_kspace"], + sampling_mask=data["sampling_mask"], + sensitivity_map=data["sensitivity_map"], + ) + + output_image = T.root_sum_of_squares( + self.backward_operator(output_kspace, dim=self._spatial_dims), dim=self._coil_dim + ) # shape (batch, height, width) + + loss_dict = {k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in loss_fns.keys()} + regularizer_dict = { + k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in regularizer_fns.keys() + } + + for key, value in loss_dict.items(): + loss_dict[key] = value + loss_fns[key]( + output_image, + **data, + reduction="mean", + ) + + for key, value in regularizer_dict.items(): + regularizer_dict[key] = value + regularizer_fns[key]( + output_image, + **data, + ) + + loss = sum(loss_dict.values()) + sum(regularizer_dict.values()) + + if self.model.training: + self._scaler.scale(loss).backward() + + loss_dicts.append(detach_dict(loss_dict)) + regularizer_dicts.append( + detach_dict(regularizer_dict) + ) # Need to detach dict as this is only used for logging. + + # Add the loss dicts. + loss_dict = reduce_list_of_dicts(loss_dicts, mode="sum") + regularizer_dict = reduce_list_of_dicts(regularizer_dicts, mode="sum") + + return DoIterationOutput( + output_image=output_image, + sensitivity_map=data["sensitivity_map"], + data_dict={**loss_dict, **regularizer_dict}, + ) + + def build_loss(self, **kwargs) -> Dict: + # TODO: Cropper is a processing output tool. + def get_resolution(**data): + """Be careful that this will use the cropping size of the FIRST sample in the batch.""" + return self.compute_resolution(self.cfg.training.loss.crop, data.get("reconstruction_size", None)) + + def l1_loss(source, reduction="mean", **data): + """ + Calculate L1 loss given source and target. + + Parameters: + ----------- + Source: shape (batch, complex=2, height, width) + Data: Contains key "target" with value a tensor of shape (batch, height, width) + + """ + resolution = get_resolution(**data) + l1_loss = F.l1_loss(*self.cropper(source, data["target"], resolution), reduction=reduction) + + return l1_loss + + def l2_loss(source, reduction="mean", **data): + """ + Calculate L2 loss (MSE) given source and target. 
+ + Parameters: + ----------- + Source: shape (batch, complex=2, height, width) + Data: Contains key "target" with value a tensor of shape (batch, height, width) + + """ + resolution = get_resolution(**data) + l2_loss = F.mse_loss(*self.cropper(source, data["target"], resolution), reduction=reduction) + + return l2_loss + + def ssim_loss(source, reduction="mean", **data): + """ + Calculate SSIM loss given source and target. + + Parameters: + ----------- + Source: shape (batch, complex=2, height, width) + Data: Contains key "target" with value a tensor of shape (batch, height, width) + + """ + resolution = get_resolution(**data) + if reduction != "mean": + raise AssertionError( + f"SSIM loss can only be computed with reduction == 'mean'." f" Got reduction == {reduction}." + ) + + source_abs, target_abs = self.cropper(source, data["target"], resolution) + data_range = torch.tensor([target_abs.max()], device=target_abs.device) + + ssim_loss = SSIMLoss().to(source_abs.device).forward(source_abs, target_abs, data_range=data_range) + + return ssim_loss + + # Build losses + loss_dict = {} + for curr_loss in self.cfg.training.loss.losses: # type: ignore + loss_fn = curr_loss.function + if loss_fn == "l1_loss": + loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, l1_loss) + elif loss_fn == "l2_loss": + loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, l2_loss) + elif loss_fn == "ssim_loss": + loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, ssim_loss) + else: + raise ValueError(f"{loss_fn} not permissible.") + + return loss_dict + + @torch.no_grad() + def evaluate( + self, + data_loader: DataLoader, + loss_fns: Optional[Dict[str, Callable]], + regularizer_fns: Optional[Dict[str, Callable]] = None, + crop: Optional[str] = None, + is_validation_process: bool = True, + ): + """ + Validation process. Assumes that each batch only contains slices of the same volume *AND* that these + are sequentially ordered. + + Parameters + ---------- + data_loader : DataLoader + loss_fns : Dict[str, Callable], optional + regularizer_fns : Dict[str, Callable], optional + crop : str, optional + is_validation_process : bool + + Returns + ------- + loss_dict, all_gathered_metrics, visualize_slices, visualize_target + + """ + self.models_to_device() + self.models_validation_mode() + torch.cuda.empty_cache() + + # Variables required for evaluation. + volume_metrics = self.build_metrics(self.cfg.validation.metrics) # type: ignore + + # filenames can be in the volume_indices attribute of the dataset + num_for_this_process = None + all_filenames = None + if hasattr(data_loader.dataset, "volume_indices"): + all_filenames = list(data_loader.dataset.volume_indices.keys()) + num_for_this_process = len(list(data_loader.batch_sampler.sampler.volume_indices.keys())) + self.logger.info( + f"Reconstructing a total of {len(all_filenames)} volumes. " + f"This process has {num_for_this_process} volumes (world size: {communication.get_world_size()})." + ) + + filenames_seen = 0 + reconstruction_output: DefaultDict = defaultdict(list) + if is_validation_process: + targets_output: DefaultDict = defaultdict(list) + val_losses = [] + val_volume_metrics: Dict[PathLike, Dict] = defaultdict(dict) + last_filename = None + + # Container to for the slices which can be visualized in TensorBoard. 
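+        # Only the center slice of each reconstructed volume is stored (see below),
+        # capped at cfg.logging.tensorboard.num_images entries.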
+ visualize_slices: List[np.ndarray] = [] + visualize_target: List[np.ndarray] = [] + # visualizations = {} + + extra_visualization_keys = ( + self.cfg.logging.log_as_image if self.cfg.logging.log_as_image else [] # type: ignore + ) + + # Loop over dataset. This requires the use of direct.data.sampler.DistributedSequentialSampler as this sampler + # splits the data over the different processes, and outputs the slices linearly. The implicit assumption here is + # that the slices are outputted from the Dataset *sequentially* for each volume one by one, and each batch only + # contains data from one volume. + time_start = time.time() + + for iter_idx, data in enumerate(data_loader): + filenames = data.pop("filename") + if len(set(filenames)) != 1: + raise ValueError( + f"Expected a batch during validation to only contain filenames of one case. " + f"Got {set(filenames)}." + ) + + slice_nos = data.pop("slice_no") + scaling_factors = data["scaling_factor"] + + resolution = self.compute_resolution( + key=self.cfg.validation.crop, # type: ignore + reconstruction_size=data.get("reconstruction_size", None), + ) + + # Compute output and loss. + iteration_output = self._do_iteration(data, loss_fns, regularizer_fns=regularizer_fns) + output = iteration_output.output_image + loss_dict = iteration_output.data_dict + + loss_dict = detach_dict(loss_dict) + output = output.detach() + val_losses.append(loss_dict) + + # Output is complex-valued, and has to be cropped. This holds for both output and target. + # Output has shape (batch, complex, height, width) + output_abs = self.process_output( + output, + scaling_factors, + resolution=resolution, + ) + + if is_validation_process: + # Target has shape (batch, height, width) + target_abs = self.process_output( + data["target"].detach(), + scaling_factors, + resolution=resolution, + ) + for key in extra_visualization_keys: + curr_data = data[key].detach() + # Here we need to discover which keys are actually normalized or not + # this requires a solution to issue #23: https://github.com/directgroup/direct/issues/23 + + del output # Explicitly call delete to clear memory. + + # Aggregate volumes to be able to compute the metrics on complete volumes. + for idx, filename in enumerate(filenames): + if last_filename is None: + last_filename = filename # First iteration last_filename is not set. + + curr_slice = output_abs[idx].detach() + slice_no = int(slice_nos[idx].numpy()) + + reconstruction_output[filename].append((slice_no, curr_slice.cpu())) + + if is_validation_process: + targets_output[filename].append((slice_no, target_abs[idx].cpu())) + + is_last_element_of_last_batch = iter_idx + 1 == len(data_loader) and idx + 1 == len(data["target"]) + reconstruction_conditions = [filename != last_filename, is_last_element_of_last_batch] + for condition in reconstruction_conditions: + if condition: + filenames_seen += 1 + + # Now we can ditch the reconstruction dict by reconstructing the volume, + # will take too much memory otherwise. 
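+                        # The slices are stacked in the order they were appended; the stored
+                        # slice index is not used to re-sort, so correctness relies on the
+                        # sequential sampler described above.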
+ volume = torch.stack([_[1] for _ in reconstruction_output[last_filename]]) + if is_validation_process: + target = torch.stack([_[1] for _ in targets_output[last_filename]]) + curr_metrics = { + metric_name: metric_fn(target, volume) + for metric_name, metric_fn in volume_metrics.items() + } + val_volume_metrics[last_filename] = curr_metrics + # Log the center slice of the volume + if len(visualize_slices) < self.cfg.logging.tensorboard.num_images: # type: ignore + visualize_slices.append(volume[volume.shape[0] // 2]) + visualize_target.append(target[target.shape[0] // 2]) + + # Delete outputs from memory, and recreate dictionary. + # This is not needed when not in validation as we are actually interested + # in the iteration output. + del targets_output[last_filename] + del reconstruction_output[last_filename] + + if all_filenames: + log_prefix = f"{filenames_seen} of {num_for_this_process} volumes reconstructed:" + else: + log_prefix = f"{iter_idx + 1} of {len(data_loader)} slices reconstructed:" + + self.logger.info( + f"{log_prefix} {last_filename}" + f" (shape = {list(volume.shape)}) in {time.time() - time_start:.3f}s." + ) + # restart timer + time_start = time.time() + last_filename = filename + + # Average loss dict + loss_dict = reduce_list_of_dicts(val_losses) + reduce_tensor_dict(loss_dict) + + communication.synchronize() + torch.cuda.empty_cache() + + all_gathered_metrics = merge_list_of_dicts(communication.all_gather(val_volume_metrics)) + if not is_validation_process: + return loss_dict, reconstruction_output + + return loss_dict, all_gathered_metrics, visualize_slices, visualize_target + + def process_output(self, data, scaling_factors=None, resolution=None): + # data is of shape (batch, complex=2, height, width) + if scaling_factors is not None: + data = data * scaling_factors.view(-1, *((1,) * (len(data.shape) - 1))).to(data.device) + + data = T.modulus_if_complex(data) + + if len(data.shape) == 3: # (batch, height, width) + data = data.unsqueeze(1) # Added channel dimension. + + if resolution is not None: + data = T.center_crop(data, resolution).contiguous() + + return data + + @staticmethod + def compute_resolution(key, reconstruction_size): + if key == "header": + # This will be of the form [tensor(x_0, x_1, ...), tensor(y_0, y_1,...), tensor(z_0, z_1, ...)] over + # batches. + resolution = [_.detach().cpu().numpy().tolist() for _ in reconstruction_size] + # The volume sampler should give validation indices belonging to the *same* volume, so it should be + # safe taking the first element, the matrix size are in x,y,z (we work in z,x,y). + resolution = [_[0] for _ in resolution][:-1] + elif key == "training": + resolution = key + elif not key: + resolution = None + else: + raise ValueError( + "Cropping should be either set to `header` to get the values from the header or " + "`training` to take the same value as training." + ) + return resolution + + def cropper(self, source, target, resolution): + """ + 2D source/target cropper + + Parameters: + ----------- + Source has shape (batch, height, width) + Target has shape (batch, height, width) + + """ + + if not resolution or all(_ == 0 for _ in resolution): + return source.unsqueeze(1), target.unsqueeze(1) # Added channel dimension. + + source_abs = T.center_crop(source, resolution).unsqueeze(1) # Added channel dimension. + target_abs = T.center_crop(target, resolution).unsqueeze(1) # Added channel dimension. 
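+        # After cropping and unsqueezing, both tensors have shape (batch, 1, height, width),
+        # which is the 4D form the losses in build_loss operate on.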
+ + return source_abs, target_abs + + def compute_model_per_coil(self, model_name, data): + """ + Computes model per coil. + """ + # data is of shape (batch, coil, complex=2, height, width) + output = [] + + for idx in range(data.size(self._coil_dim)): + subselected_data = data.select(self._coil_dim, idx) + output.append(self.models[model_name](subselected_data)) + output = torch.stack(output, dim=self._coil_dim) + + # output is of shape (batch, coil, complex=2, height, width) + return output diff --git a/direct/nn/xpdnet/__init__.py b/direct/nn/xpdnet/__init__.py new file mode 100644 index 00000000..941752c9 --- /dev/null +++ b/direct/nn/xpdnet/__init__.py @@ -0,0 +1,2 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors diff --git a/direct/nn/xpdnet/config.py b/direct/nn/xpdnet/config.py new file mode 100644 index 00000000..0bf070a3 --- /dev/null +++ b/direct/nn/xpdnet/config.py @@ -0,0 +1,26 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors + +from dataclasses import dataclass + +from direct.config.defaults import ModelConfig + + +@dataclass +class XPDNetConfig(ModelConfig): + num_primal: int = 5 + num_dual: int = 1 + num_iter: int = 10 + use_primal_only: bool = True + kspace_model_architecture: str = "CONV" + dual_conv_hidden_channels: int = 16 + dual_conv_n_convs: int = 4 + dual_conv_batchnorm: bool = False + dual_didn_hidden_channels: int = 64 + dual_didn_num_dubs: int = 6 + dual_didn_num_convs_recon: int = 9 + mwcnn_hidden_channels: int = 16 + mwcnn_num_scales: int = 4 + mwcnn_bias: bool = True + mwcnn_batchnorm: bool = False + normalize: bool = False diff --git a/direct/nn/xpdnet/tests/__init__.py b/direct/nn/xpdnet/tests/__init__.py new file mode 100644 index 00000000..941752c9 --- /dev/null +++ b/direct/nn/xpdnet/tests/__init__.py @@ -0,0 +1,2 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors diff --git a/direct/nn/xpdnet/tests/test_xpdnet.py b/direct/nn/xpdnet/tests/test_xpdnet.py new file mode 100644 index 00000000..7fb1aaf7 --- /dev/null +++ b/direct/nn/xpdnet/tests/test_xpdnet.py @@ -0,0 +1,76 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors + +import pytest +import torch + +from direct.data.transforms import fft2, ifft2 +from direct.nn.xpdnet.xpdnet import XPDNet + + +def create_input(shape): + + data = torch.rand(shape).float() + + return data + + +@pytest.mark.parametrize( + "shape", + [ + [3, 3, 32, 32], + ], +) +@pytest.mark.parametrize( + "num_iter", + [2, 3], +) +@pytest.mark.parametrize( + "num_primal", + [2, 3], +) +@pytest.mark.parametrize( + "image_model_architecture", + ["MWCNN"], +) +@pytest.mark.parametrize( + "primal_only, kspace_model_architecture, num_dual", + [ + [True, None, 1], + [False, "CONV", 3], + [False, "DIDN", 2], + ], +) +@pytest.mark.parametrize( + "normalize", + [True, False], +) +def test_xpdnet( + shape, + num_iter, + num_primal, + num_dual, + image_model_architecture, + kspace_model_architecture, + primal_only, + normalize, +): + model = XPDNet( + fft2, + ifft2, + num_iter=num_iter, + num_primal=num_primal, + num_dual=num_dual, + image_model_architecture=image_model_architecture, + kspace_model_architecture=kspace_model_architecture, + use_primal_only=primal_only, + normalize=normalize, + ).cpu() + + kspace = create_input(shape + [2]).cpu() + sens = create_input(shape + [2]).cpu() + mask = create_input([shape[0]] + [1] + shape[2:] + [1]).round().int().cpu() + + out = model(kspace, mask, sens) + + assert list(out.shape) == [shape[0]] + shape[2:] + [2] diff --git a/direct/nn/xpdnet/xpdnet.py b/direct/nn/xpdnet/xpdnet.py new 
file mode 100644 index 00000000..a81987df --- /dev/null +++ b/direct/nn/xpdnet/xpdnet.py @@ -0,0 +1,127 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors + +from typing import Callable, Optional + +import torch.nn as nn + +from direct.nn.crossdomain.crossdomain import CrossDomainNetwork +from direct.nn.crossdomain.multicoil import MultiCoil +from direct.nn.conv.conv import Conv2d +from direct.nn.didn.didn import DIDN +from direct.nn.mwcnn.mwcnn import MWCNN + + +class XPDNet(CrossDomainNetwork): + """ + XPDNet as implemented in https://arxiv.org/abs/2010.07290. + """ + + def __init__( + self, + forward_operator: Callable, + backward_operator: Callable, + num_primal: int = 5, + num_dual: int = 1, + num_iter: int = 10, + use_primal_only: bool = True, + image_model_architecture: str = "MWCNN", + kspace_model_architecture: Optional[str] = None, + normalize: bool = False, + **kwargs, + ): + """ + + Parameters + ---------- + forward_operator : Callable + Forward Operator. + backward_operator : Callable + Backward Operator. + num_primal : int + Number of primal networks. + num_dual : int + Number of dual networks. + num_iter : int + Number of unrolled iterations. + use_primal_only : bool + If set to True no dual-kspace model is used. Default: True. + image_model_architecture : str + Primal-image model architecture. Currently only implemented for MWCNN. Default: 'MWCNN'. + kspace_model_architecture : str + Dual-kspace model architecture. Currently only implemented for CONV and DIDN. + normalize : bool + Normalize input. Default: False. + kwargs : dict + Keyword arguments for model architectures. + """ + if use_primal_only: + kspace_model_list = None + num_dual = 1 + elif kspace_model_architecture == "CONV": + kspace_model_list = nn.ModuleList( + [ + MultiCoil( + Conv2d( + 2 * (num_dual + num_primal + 1), + 2 * num_dual, + kwargs.get("dual_conv_hidden_channels", 16), + kwargs.get("dual_conv_n_convs", 4), + batchnorm=kwargs.get("dual_conv_batchnorm", False), + ) + ) + for _ in range(num_iter) + ] + ) + elif kspace_model_architecture == "DIDN": + kspace_model_list = nn.ModuleList( + [ + MultiCoil( + DIDN( + in_channels=2 * (num_dual + num_primal + 1), + out_channels=2 * num_dual, + hidden_channels=kwargs.get("dual_didn_hidden_channels", 16), + num_dubs=kwargs.get("dual_didn_num_dubs", 6), + num_convs_recon=kwargs.get("dual_didn_num_convs_recon", 9), + ) + ) + for _ in range(num_iter) + ] + ) + + else: + raise NotImplementedError( + f"XPDNet is currently implemented for kspace_model_architecture == 'CONV' or 'DIDN'." + f"Got kspace_model_architecture == {kspace_model_architecture}." + ) + if image_model_architecture == "MWCNN": + image_model_list = nn.ModuleList( + [ + nn.Sequential( + MWCNN( + input_channels=2 * (num_primal + num_dual), + first_conv_hidden_channels=kwargs.get("mwcnn_hidden_channels", 32), + num_scales=kwargs.get("mwcnn_num_scales", 4), + bias=kwargs.get("mwcnn_bias", False), + batchnorm=kwargs.get("mwcnn_batchnorm", False), + ), + nn.Conv2d(2 * (num_primal + num_dual), 2 * num_primal, kernel_size=3, padding=1), + ) + for _ in range(num_iter) + ] + ) + else: + raise NotImplementedError( + f"XPDNet is currently implemented only with image_model_architecture == 'MWCNN'." + f"Got {image_model_architecture}." 
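+        # The CrossDomainNetwork base class initialized below alternates k-space ("K")
+        # and image ("I") correction steps num_iter times, keeping num_primal image
+        # buffers and num_dual k-space buffers.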
+ ) + super().__init__( + forward_operator=forward_operator, + backward_operator=backward_operator, + image_model_list=image_model_list, + kspace_model_list=kspace_model_list, + domain_sequence="KI" * num_iter, + image_buffer_size=num_primal, + kspace_buffer_size=num_dual, + normalize_image=normalize, + ) diff --git a/direct/nn/xpdnet/xpdnet_engine.py b/direct/nn/xpdnet/xpdnet_engine.py new file mode 100644 index 00000000..5bae217b --- /dev/null +++ b/direct/nn/xpdnet/xpdnet_engine.py @@ -0,0 +1,472 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors + +import time +from collections import defaultdict +from os import PathLike +from typing import Callable, DefaultDict, Dict, List, Optional + +import numpy as np +import torch +from torch import nn +from torch.cuda.amp import autocast +from torch.nn import functional as F +from torch.utils.data import DataLoader + +import direct.data.transforms as T +from direct.config import BaseConfig +from direct.engine import DoIterationOutput, Engine +from direct.functionals import SSIMLoss +from direct.utils import ( + communication, + detach_dict, + dict_to_device, + merge_list_of_dicts, + multiply_function, + reduce_list_of_dicts, +) +from direct.utils.communication import reduce_tensor_dict + + +class XPDNetEngine(Engine): + """ + XPDNet Engine. + """ + + def __init__( + self, + cfg: BaseConfig, + model: nn.Module, + device: int, + forward_operator: Optional[Callable] = None, + backward_operator: Optional[Callable] = None, + mixed_precision: bool = False, + **models: nn.Module, + ): + super().__init__( + cfg, + model, + device, + forward_operator=forward_operator, + backward_operator=backward_operator, + mixed_precision=mixed_precision, + **models, + ) + + self._complex_dim = -1 + self._coil_dim = 1 + self._spatial_dims = (2, 3) + + def _do_iteration( + self, + data: Dict[str, torch.Tensor], + loss_fns: Optional[Dict[str, Callable]] = None, + regularizer_fns: Optional[Dict[str, Callable]] = None, + ) -> DoIterationOutput: + + # loss_fns can be done, e.g. 
during validation + if loss_fns is None: + loss_fns = {} + + if regularizer_fns is None: + regularizer_fns = {} + + loss_dicts = [] + regularizer_dicts = [] + + data = dict_to_device(data, self.device) + + # sensitivity_map of shape (batch, coil, height, width, complex=2) + sensitivity_map = data["sensitivity_map"] + + if "sensitivity_model" in self.models: + + # Move channels to first axis + sensitivity_map = data["sensitivity_map"].permute( + (0, 1, 4, 2, 3) + ) # shape (batch, coil, complex=2, height, width) + + sensitivity_map = self.compute_model_per_coil("sensitivity_model", sensitivity_map).permute( + (0, 1, 3, 4, 2) + ) # has channel last: shape (batch, coil, height, width, complex=2) + + # The sensitivity map needs to be normalized such that + # So \sum_{i \in \text{coils}} S_i S_i^* = 1 + + sensitivity_map_norm = torch.sqrt( + ((sensitivity_map ** 2).sum(self._complex_dim)).sum(self._coil_dim) + ) # shape (batch, height, width) + sensitivity_map_norm = sensitivity_map_norm.unsqueeze(1).unsqueeze(-1) + data["sensitivity_map"] = T.safe_divide(sensitivity_map, sensitivity_map_norm) + + with autocast(enabled=self.mixed_precision): + + output_image = self.model( + masked_kspace=data["masked_kspace"], + sampling_mask=data["sampling_mask"], + sensitivity_map=data["sensitivity_map"], + scaling_factor=data["scaling_factor"], + ) # shape (batch, height, width, complex=2) + + output_image = T.modulus(output_image) # shape (batch, height, width) + + loss_dict = {k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in loss_fns.keys()} + regularizer_dict = { + k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in regularizer_fns.keys() + } + + for key, value in loss_dict.items(): + loss_dict[key] = value + loss_fns[key]( + output_image, + **data, + reduction="mean", + ) + + for key, value in regularizer_dict.items(): + regularizer_dict[key] = value + regularizer_fns[key]( + output_image, + **data, + ) + + loss = sum(loss_dict.values()) + sum(regularizer_dict.values()) + + if self.model.training: + self._scaler.scale(loss).backward() + + loss_dicts.append(detach_dict(loss_dict)) + regularizer_dicts.append( + detach_dict(regularizer_dict) + ) # Need to detach dict as this is only used for logging. + + # Add the loss dicts. + loss_dict = reduce_list_of_dicts(loss_dicts, mode="sum") + regularizer_dict = reduce_list_of_dicts(regularizer_dicts, mode="sum") + + return DoIterationOutput( + output_image=output_image, + sensitivity_map=data["sensitivity_map"], + data_dict={**loss_dict, **regularizer_dict}, + ) + + def build_loss(self, **kwargs) -> Dict: + # TODO: Cropper is a processing output tool. + def get_resolution(**data): + """Be careful that this will use the cropping size of the FIRST sample in the batch.""" + return self.compute_resolution(self.cfg.training.loss.crop, data.get("reconstruction_size", None)) + + def l1_loss(source, reduction="mean", **data): + """ + Calculate L1 loss given source and target. + + Parameters: + ----------- + Source: shape (batch, complex=2, height, width) + Data: Contains key "target" with value a tensor of shape (batch, height, width) + + """ + resolution = get_resolution(**data) + l1_loss = F.l1_loss(*self.cropper(source, data["target"], resolution), reduction=reduction) + + return l1_loss + + def l2_loss(source, reduction="mean", **data): + """ + Calculate L2 loss (MSE) given source and target. 
+ + Parameters: + ----------- + Source: shape (batch, complex=2, height, width) + Data: Contains key "target" with value a tensor of shape (batch, height, width) + + """ + resolution = get_resolution(**data) + l2_loss = F.mse_loss(*self.cropper(source, data["target"], resolution), reduction=reduction) + + return l2_loss + + def ssim_loss(source, reduction="mean", **data): + """ + Calculate SSIM loss given source and target. + + Parameters: + ----------- + Source: shape (batch, complex=2, height, width) + Data: Contains key "target" with value a tensor of shape (batch, height, width) + + """ + resolution = get_resolution(**data) + if reduction != "mean": + raise AssertionError( + f"SSIM loss can only be computed with reduction == 'mean'." f" Got reduction == {reduction}." + ) + + source_abs, target_abs = self.cropper(source, data["target"], resolution) + data_range = torch.tensor([target_abs.max()], device=target_abs.device) + + ssim_loss = SSIMLoss().to(source_abs.device).forward(source_abs, target_abs, data_range=data_range) + + return ssim_loss + + # Build losses + loss_dict = {} + for curr_loss in self.cfg.training.loss.losses: # type: ignore + loss_fn = curr_loss.function + if loss_fn == "l1_loss": + loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, l1_loss) + elif loss_fn == "l2_loss": + loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, l2_loss) + elif loss_fn == "ssim_loss": + loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, ssim_loss) + else: + raise ValueError(f"{loss_fn} not permissible.") + + return loss_dict + + @torch.no_grad() + def evaluate( + self, + data_loader: DataLoader, + loss_fns: Optional[Dict[str, Callable]], + regularizer_fns: Optional[Dict[str, Callable]] = None, + crop: Optional[str] = None, + is_validation_process: bool = True, + ): + """ + Validation process. Assumes that each batch only contains slices of the same volume *AND* that these + are sequentially ordered. + + Parameters + ---------- + data_loader : DataLoader + loss_fns : Dict[str, Callable], optional + regularizer_fns : Dict[str, Callable], optional + crop : str, optional + is_validation_process : bool + + Returns + ------- + loss_dict, all_gathered_metrics, visualize_slices, visualize_target + + """ + self.models_to_device() + self.models_validation_mode() + torch.cuda.empty_cache() + + # Variables required for evaluation. + volume_metrics = self.build_metrics(self.cfg.validation.metrics) # type: ignore + + # filenames can be in the volume_indices attribute of the dataset + num_for_this_process = None + all_filenames = None + if hasattr(data_loader.dataset, "volume_indices"): + all_filenames = list(data_loader.dataset.volume_indices.keys()) + num_for_this_process = len(list(data_loader.batch_sampler.sampler.volume_indices.keys())) + self.logger.info( + f"Reconstructing a total of {len(all_filenames)} volumes. " + f"This process has {num_for_this_process} volumes (world size: {communication.get_world_size()})." + ) + + filenames_seen = 0 + reconstruction_output: DefaultDict = defaultdict(list) + if is_validation_process: + targets_output: DefaultDict = defaultdict(list) + val_losses = [] + val_volume_metrics: Dict[PathLike, Dict] = defaultdict(dict) + last_filename = None + + # Container to for the slices which can be visualized in TensorBoard. 
+ visualize_slices: List[np.ndarray] = [] + visualize_target: List[np.ndarray] = [] + # visualizations = {} + + extra_visualization_keys = ( + self.cfg.logging.log_as_image if self.cfg.logging.log_as_image else [] # type: ignore + ) + + # Loop over dataset. This requires the use of direct.data.sampler.DistributedSequentialSampler as this sampler + # splits the data over the different processes, and outputs the slices linearly. The implicit assumption here is + # that the slices are outputted from the Dataset *sequentially* for each volume one by one, and each batch only + # contains data from one volume. + time_start = time.time() + + for iter_idx, data in enumerate(data_loader): + filenames = data.pop("filename") + if len(set(filenames)) != 1: + raise ValueError( + f"Expected a batch during validation to only contain filenames of one case. " + f"Got {set(filenames)}." + ) + + slice_nos = data.pop("slice_no") + scaling_factors = data["scaling_factor"] + + resolution = self.compute_resolution( + key=self.cfg.validation.crop, # type: ignore + reconstruction_size=data.get("reconstruction_size", None), + ) + + # Compute output and loss. + iteration_output = self._do_iteration(data, loss_fns, regularizer_fns=regularizer_fns) + output = iteration_output.output_image + loss_dict = iteration_output.data_dict + + loss_dict = detach_dict(loss_dict) + output = output.detach() + val_losses.append(loss_dict) + + # Output is complex-valued, and has to be cropped. This holds for both output and target. + # Output has shape (batch, complex, height, width) + output_abs = self.process_output( + output, + scaling_factors, + resolution=resolution, + ) + + if is_validation_process: + # Target has shape (batch, height, width) + target_abs = self.process_output( + data["target"].detach(), + scaling_factors, + resolution=resolution, + ) + for key in extra_visualization_keys: + curr_data = data[key].detach() + # Here we need to discover which keys are actually normalized or not + # this requires a solution to issue #23: https://github.com/directgroup/direct/issues/23 + + del output # Explicitly call delete to clear memory. + + # Aggregate volumes to be able to compute the metrics on complete volumes. + for idx, filename in enumerate(filenames): + if last_filename is None: + last_filename = filename # First iteration last_filename is not set. + + curr_slice = output_abs[idx].detach() + slice_no = int(slice_nos[idx].numpy()) + + reconstruction_output[filename].append((slice_no, curr_slice.cpu())) + + if is_validation_process: + targets_output[filename].append((slice_no, target_abs[idx].cpu())) + + is_last_element_of_last_batch = iter_idx + 1 == len(data_loader) and idx + 1 == len(data["target"]) + reconstruction_conditions = [filename != last_filename, is_last_element_of_last_batch] + for condition in reconstruction_conditions: + if condition: + filenames_seen += 1 + + # Now we can ditch the reconstruction dict by reconstructing the volume, + # will take too much memory otherwise. 
+ volume = torch.stack([_[1] for _ in reconstruction_output[last_filename]]) + if is_validation_process: + target = torch.stack([_[1] for _ in targets_output[last_filename]]) + curr_metrics = { + metric_name: metric_fn(target, volume) + for metric_name, metric_fn in volume_metrics.items() + } + val_volume_metrics[last_filename] = curr_metrics + # Log the center slice of the volume + if len(visualize_slices) < self.cfg.logging.tensorboard.num_images: # type: ignore + visualize_slices.append(volume[volume.shape[0] // 2]) + visualize_target.append(target[target.shape[0] // 2]) + + # Delete outputs from memory, and recreate dictionary. + # This is not needed when not in validation as we are actually interested + # in the iteration output. + del targets_output[last_filename] + del reconstruction_output[last_filename] + + if all_filenames: + log_prefix = f"{filenames_seen} of {num_for_this_process} volumes reconstructed:" + else: + log_prefix = f"{iter_idx + 1} of {len(data_loader)} slices reconstructed:" + + self.logger.info( + f"{log_prefix} {last_filename}" + f" (shape = {list(volume.shape)}) in {time.time() - time_start:.3f}s." + ) + # restart timer + time_start = time.time() + last_filename = filename + + # Average loss dict + loss_dict = reduce_list_of_dicts(val_losses) + reduce_tensor_dict(loss_dict) + + communication.synchronize() + torch.cuda.empty_cache() + + all_gathered_metrics = merge_list_of_dicts(communication.all_gather(val_volume_metrics)) + if not is_validation_process: + return loss_dict, reconstruction_output + + return loss_dict, all_gathered_metrics, visualize_slices, visualize_target + + def process_output(self, data, scaling_factors=None, resolution=None): + # data is of shape (batch, complex=2, height, width) + if scaling_factors is not None: + data = data * scaling_factors.view(-1, *((1,) * (len(data.shape) - 1))).to(data.device) + + data = T.modulus_if_complex(data) + + if len(data.shape) == 3: # (batch, height, width) + data = data.unsqueeze(1) # Added channel dimension. + + if resolution is not None: + data = T.center_crop(data, resolution).contiguous() + + return data + + @staticmethod + def compute_resolution(key, reconstruction_size): + if key == "header": + # This will be of the form [tensor(x_0, x_1, ...), tensor(y_0, y_1,...), tensor(z_0, z_1, ...)] over + # batches. + resolution = [_.detach().cpu().numpy().tolist() for _ in reconstruction_size] + # The volume sampler should give validation indices belonging to the *same* volume, so it should be + # safe taking the first element, the matrix size are in x,y,z (we work in z,x,y). + resolution = [_[0] for _ in resolution][:-1] + elif key == "training": + resolution = key + elif not key: + resolution = None + else: + raise ValueError( + "Cropping should be either set to `header` to get the values from the header or " + "`training` to take the same value as training." + ) + return resolution + + def cropper(self, source, target, resolution): + """ + 2D source/target cropper + + Parameters: + ----------- + Source has shape (batch, height, width) + Target has shape (batch, height, width) + + """ + + if not resolution or all(_ == 0 for _ in resolution): + return source.unsqueeze(1), target.unsqueeze(1) # Added channel dimension. + + source_abs = T.center_crop(source, resolution).unsqueeze(1) # Added channel dimension. + target_abs = T.center_crop(target, resolution).unsqueeze(1) # Added channel dimension. 
+ + return source_abs, target_abs + + def compute_model_per_coil(self, model_name, data): + """ + Computes model per coil. + """ + # data is of shape (batch, coil, complex=2, height, width) + output = [] + + for idx in range(data.size(self._coil_dim)): + subselected_data = data.select(self._coil_dim, idx) + output.append(self.models[model_name](subselected_data)) + output = torch.stack(output, dim=self._coil_dim) + + # output is of shape (batch, coil, complex=2, height, width) + return output diff --git a/direct/utils/__init__.py b/direct/utils/__init__.py index 820739d5..f46254f4 100644 --- a/direct/utils/__init__.py +++ b/direct/utils/__init__.py @@ -18,6 +18,52 @@ logger = logging.getLogger(__name__) +def is_complex_data(data: torch.Tensor, complex_last: bool = True) -> bool: + """ + Returns True if data is a complex tensor, i.e. has a complex axis of dimension 2, and False otherwise. + + Parameters + ---------- + data : torch.Tensor + For 2D data the shape is assumed ([batch], [coil], height, width, [complex]) + or ([batch], [coil], [complex], height, width). + For 3D data the shape is assumed ([batch], [coil], slice, height, width, [complex]) + or ([batch], [coil], [complex], slice, height, width). + complex_last : bool + If true, will require complex axis to be at the last axis. + Returns + ------- + + """ + if 2 not in data.shape: + return False + if complex_last: + if data.size(-1) != 2: + return False + else: + if data.ndim == 6: + if data.size(2) != 2 and data.size(-1) != 2: # (B, C, 2, S, H, 2) or (B, C, S, H, W, 2) + return False + + elif data.ndim == 5: + # (B, 2, S, H, W) or (B, C, 2, H, W) or (B, S, H, W, 2) or (B, C, H, W, 2) + if data.size(1) != 2 and data.size(2) != 2 and data.size(-1) != 2: + return False + + elif data.ndim == 4: + if data.size(1) != 2 and data.size(-1) != 2: # (B, 2, H, W) or (B, H, W, 2) or (S, H, W, 2) + return False + + elif data.ndim == 3: + if data.size(-1) != 2: # (H, W, 2) + return False + + else: + raise ValueError(f"Not compatible number of dimensions for complex data. Got {data.ndim}.") + + return True + + def is_power_of_two(number: int) -> bool: """Check if input is a power of 2 @@ -404,9 +450,9 @@ def chunks(list_to_chunk, number_of_chunks): From https://stackoverflow.com/a/54802737 """ d, r = divmod(len(list_to_chunk), number_of_chunks) - for i in range(number_of_chunks): - si = (d + 1) * (i if i < r else r) + d * (0 if i < r else i - r) - yield list_to_chunk[si : si + (d + 1 if i < r else d)] + for idx in range(number_of_chunks): + si = (d + 1) * (idx if idx < r else r) + d * (0 if idx < r else idx - r) + yield list_to_chunk[si : si + (d + 1 if idx < r else d)] def remove_keys(input_dict, keys): diff --git a/direct/utils/asserts.py b/direct/utils/asserts.py index 2c98423e..763ec751 100644 --- a/direct/utils/asserts.py +++ b/direct/utils/asserts.py @@ -3,6 +3,8 @@ import inspect from typing import List +from direct.utils import is_complex_data + import torch @@ -45,7 +47,7 @@ def assert_same_shape(data_list: List[torch.Tensor]): def assert_complex(data: torch.Tensor, complex_last: bool = True) -> None: """ - Assert if a tensor is a complex named tensor. + Assert if a tensor is a complex tensor. Parameters ---------- @@ -61,29 +63,5 @@ def assert_complex(data: torch.Tensor, complex_last: bool = True) -> None: """ # TODO: This is because ifft and fft or torch expect the last dimension to represent the complex axis. - - if 2 not in data.shape: - raise ValueError(f"No complex dimension (2) found. 
Got shape {data.shape}.") - if complex_last: - if data.size(-1) != 2: - raise ValueError(f"Last dimension assumed to be 2 (complex valued). Got {data.size(-1)}.") - else: - if data.ndim == 6 or data.ndim == 3: - if data.size(1) != 2 and data.size(-1) != 2: - raise ValueError( - f"Complex dimension assumed to be 2 (complex valued), but not found in shape {data.shape}." - ) - - elif data.ndim == 5: - if data.size(1) != 2 and data.size(2) != 2 and data.size(-1) != 2: - raise ValueError( - f"Complex dimension assumed to be 2 (complex valued), but not found in shape {data.shape}." - ) - - elif data.ndim == 4: - if data.size(1) != 2 and data.size(-1) != 2: - raise ValueError( - f"Complex dimension assumed to be 2 (complex valued), but not found in shape {data.shape}." - ) - else: - raise ValueError(f"Data of shape {data.shape} is not complex.") + if not is_complex_data(data, complex_last): + raise ValueError(f"Complex dimension assumed to be 2 (complex valued), but not found in shape {data.shape}.") diff --git a/direct/utils/events.py b/direct/utils/events.py index af5d9220..6a0042f6 100644 --- a/direct/utils/events.py +++ b/direct/utils/events.py @@ -390,7 +390,8 @@ def step(self): correct iteration number. """ self._iter += 1 - self._latest_scalars = {} + # TODO: This clears validation metrics. + # self._latest_scalars = {} @property def vis_data(self): diff --git a/direct/utils/logging.py b/direct/utils/logging.py index 3b5425a5..c2b3b9f7 100644 --- a/direct/utils/logging.py +++ b/direct/utils/logging.py @@ -38,7 +38,7 @@ def setup( root = logging.getLogger() root.setLevel(log_level) - for name in logging.root.manager.loggerDict: # type: ignore + for name in logging.root.manager.loggerDict: # pylint: disable = E1101 # type: ignore if name.startswith("torch"): logging.getLogger(name).setLevel("WARNING") diff --git a/direct/utils/tests/__init__.py b/direct/utils/tests/__init__.py new file mode 100644 index 00000000..941752c9 --- /dev/null +++ b/direct/utils/tests/__init__.py @@ -0,0 +1,2 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors diff --git a/direct/utils/tests/test_utils.py b/direct/utils/tests/test_utils.py new file mode 100644 index 00000000..da0ba2d1 --- /dev/null +++ b/direct/utils/tests/test_utils.py @@ -0,0 +1,45 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors +"""Tests for the direct.utils module""" + +import numpy as np +import pytest +import torch + +from direct.utils import is_power_of_two, is_complex_data +from direct.data.transforms import tensor_to_complex_numpy + + +def create_input(shape): + data = np.random.randn(*shape).copy() + data = torch.from_numpy(data).float() + + return data + + +@pytest.mark.parametrize( + "shape", + [ + [3, 3, 2], + [5, 8, 4, 2], + [5, 2, 8, 4], + [3, 5, 8, 4, 2], + [3, 5, 2, 8, 4], + [3, 2, 5, 8, 4], + [3, 3, 5, 8, 4, 2], + [3, 3, 2, 5, 8, 4], + ], +) +def test_is_complex_data(shape): + data = create_input(shape) + + assert is_complex_data(data, False) + + +@pytest.mark.parametrize( + "num", + [1, 2, 4, 32, 128, 1024], +) +def test_is_power_of_two(num): + + assert is_power_of_two(num) diff --git a/projects/calgary_campinas/configs/base_jointicnet.yaml b/projects/calgary_campinas/configs/base_jointicnet.yaml new file mode 100644 index 00000000..a97f150d --- /dev/null +++ b/projects/calgary_campinas/configs/base_jointicnet.yaml @@ -0,0 +1,100 @@ +physics: + forward_operator: fft2(centered=False) + backward_operator: ifft2(centered=False) +training: + datasets: + # Two datasets, only difference is the shape, so the data can be 
collated for larger batches + - name: CalgaryCampinas + lists: + - ../lists/train/12x218x170_train.lst + transforms: + crop: null + estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + image_center_crop: false + masking: + name: CalgaryCampinas + accelerations: [5, 10] + crop_outer_slices: true + - name: CalgaryCampinas + lists: + - ../lists/train/12x218x180_train.lst + transforms: + crop: null + estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + image_center_crop: false + masking: + name: CalgaryCampinas + accelerations: [5, 10] + crop_outer_slices: true + batch_size: 4 # This is the batch size per GPU! + optimizer: Adam + lr: 0.0005 + weight_decay: 0.0 + lr_step_size: 50000 + lr_gamma: 0.2 + lr_warmup_iter: 1000 + num_iterations: 500000 + gradient_steps: 1 + gradient_clipping: 0.0 + gradient_debug: false + checkpointer: + checkpoint_steps: 500 + validation_steps: 500 + loss: + crop: null + losses: + - function: l1_loss + multiplier: 1.0 + - function: ssim_loss + multiplier: 1.0 +validation: + datasets: + # Twice the same dataset but a different acceleration factor + - name: CalgaryCampinas + transforms: + crop: null + estimate_sensitivity_maps: true + scaling_key: masked_kspace + masking: + name: CalgaryCampinas + accelerations: [5] + crop_outer_slices: true + text_description: 5x # Description for logging + - name: CalgaryCampinas + transforms: + crop: null + estimate_sensitivity_maps: true + scaling_key: masked_kspace + masking: + name: CalgaryCampinas + accelerations: [10] + crop_outer_slices: true + text_description: 10x # Description for logging + crop: null # This sets the cropping for the DoIterationOutput + metrics: # These are obtained from direct.functionals + - calgary_campinas_psnr + - calgary_campinas_ssim + - calgary_campinas_vif +model: + model_name: jointicnet.jointicnet.JointICNet + num_iter: 12 + use_norm_unet: True + image_unet_num_filters: 32 + kspace_unet_num_filters: 32 + sens_unet_num_filters: 8 + +logging: + tensorboard: + num_images: 4 +inference: + batch_size: 8 + dataset: + name: CalgaryCampinas + crop_outer_slices: true + text_description: inference + transforms: + crop: null + estimate_sensitivity_maps: true + scaling_key: masked_kspace diff --git a/projects/calgary_campinas/configs/base_kikinet.yaml b/projects/calgary_campinas/configs/base_kikinet.yaml new file mode 100644 index 00000000..a657dbaa --- /dev/null +++ b/projects/calgary_campinas/configs/base_kikinet.yaml @@ -0,0 +1,110 @@ +physics: + forward_operator: fft2(centered=False) + backward_operator: ifft2(centered=False) +training: + datasets: + # Two datasets, only difference is the shape, so the data can be collated for larger batches + - name: CalgaryCampinas + lists: + - ../lists/train/12x218x170_train.lst + transforms: + crop: null + estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + image_center_crop: false + masking: + name: CalgaryCampinas + accelerations: [5, 10] + crop_outer_slices: true + - name: CalgaryCampinas + lists: + - ../lists/train/12x218x180_train.lst + transforms: + crop: null + estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS + scaling_key: masked_kspace # Compute the image normalization based 
on the masked_kspace maximum + image_center_crop: false + masking: + name: CalgaryCampinas + accelerations: [5, 10] + crop_outer_slices: true + batch_size: 8 # This is the batch size per GPU! + optimizer: Adam + lr: 0.0002 + weight_decay: 0.0 + lr_step_size: 50000 + lr_gamma: 0.2 + lr_warmup_iter: 1000 + num_iterations: 500000 + gradient_steps: 1 + gradient_clipping: 0.0 + gradient_debug: false + checkpointer: + checkpoint_steps: 500 + validation_steps: 500 + loss: + crop: null + losses: + - function: l1_loss + multiplier: 1.0 + - function: ssim_loss + multiplier: 1.0 +validation: + datasets: + # Twice the same dataset but a different acceleration factor + - name: CalgaryCampinas + transforms: + crop: null + estimate_sensitivity_maps: true + scaling_key: masked_kspace + masking: + name: CalgaryCampinas + accelerations: [5] + crop_outer_slices: true + text_description: 5x # Description for logging + - name: CalgaryCampinas + transforms: + crop: null + estimate_sensitivity_maps: true + scaling_key: masked_kspace + masking: + name: CalgaryCampinas + accelerations: [10] + crop_outer_slices: true + text_description: 10x # Description for logging + crop: null # This sets the cropping for the DoIterationOutput + metrics: # These are obtained from direct.functionals + - calgary_campinas_psnr + - calgary_campinas_ssim + - calgary_campinas_vif +model: + model_name: kikinet.kikinet.KIKINet + num_iter: 2 + image_model_architecture: UNET + image_unet_num_filters: 16 + image_unet_num_pool_layers: 4 + kspace_model_architecture: UNET + kspace_unet_num_filters: 16 + kspace_unet_num_pool_layers: 4 + +additional_models: + sensitivity_model: + model_name: unet.unet_2d.UnetModel2d + in_channels: 2 + out_channels: 2 + num_filters: 8 + num_pool_layers: 4 + dropout_probability: 0.0 +logging: + tensorboard: + num_images: 4 +inference: + batch_size: 8 + dataset: + name: CalgaryCampinas + crop_outer_slices: true + text_description: inference + transforms: + crop: null + estimate_sensitivity_maps: true + scaling_key: masked_kspace diff --git a/projects/calgary_campinas/configs/base_multidomainnet.yaml b/projects/calgary_campinas/configs/base_multidomainnet.yaml new file mode 100644 index 00000000..ec90112d --- /dev/null +++ b/projects/calgary_campinas/configs/base_multidomainnet.yaml @@ -0,0 +1,105 @@ +physics: + forward_operator: fft2(centered=False) + backward_operator: ifft2(centered=False) +training: + datasets: + # Two datasets, only difference is the shape, so the data can be collated for larger batches + - name: CalgaryCampinas + lists: + - ../lists/train/12x218x170_train.lst + transforms: + crop: null + estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + image_center_crop: false + masking: + name: CalgaryCampinas + accelerations: [5, 10] + crop_outer_slices: true + - name: CalgaryCampinas + lists: + - ../lists/train/12x218x180_train.lst + transforms: + crop: null + estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + image_center_crop: false + masking: + name: CalgaryCampinas + accelerations: [5, 10] + crop_outer_slices: true + batch_size: 20 + optimizer: Adam + lr: 0.001 + weight_decay: 0.0 + lr_step_size: 50000 + lr_gamma: 0.2 + lr_warmup_iter: 1000 + num_iterations: 500000 + gradient_steps: 1 + gradient_clipping: 0.0 + gradient_debug: false + checkpointer: + 
checkpoint_steps: 500 + validation_steps: 500 + loss: + crop: null + losses: + - function: l1_loss + multiplier: 1.0 + - function: ssim_loss + multiplier: 1.0 +validation: + datasets: + # Twice the same dataset but a different acceleration factor + - name: CalgaryCampinas + transforms: + crop: null + estimate_sensitivity_maps: true + scaling_key: masked_kspace + masking: + name: CalgaryCampinas + accelerations: [5] + crop_outer_slices: true + text_description: 5x # Description for logging + - name: CalgaryCampinas + transforms: + crop: null + estimate_sensitivity_maps: true + scaling_key: masked_kspace + masking: + name: CalgaryCampinas + accelerations: [10] + crop_outer_slices: true + text_description: 10x # Description for logging + crop: null # This sets the cropping for the DoIterationOutput + metrics: # These are obtained from direct.functionals + - calgary_campinas_psnr + - calgary_campinas_ssim + - calgary_campinas_vif +model: + model_name: multidomainnet.multidomainnet.MultiDomainNet + num_filters: 16 + num_pool_layers: 4 + dropout_probability: 0.05 +additional_models: + sensitivity_model: + model_name: unet.unet_2d.UnetModel2d + in_channels: 2 + out_channels: 2 + num_filters: 16 + num_pool_layers: 4 + dropout_probability: 0.05 +logging: + tensorboard: + num_images: 4 +inference: + batch_size: 8 + dataset: + name: CalgaryCampinas + crop_outer_slices: true + text_description: inference + transforms: + crop: null + estimate_sensitivity_maps: true + scaling_key: masked_kspace diff --git a/projects/calgary_campinas/configs/base.yaml b/projects/calgary_campinas/configs/base_rim.yaml similarity index 100% rename from projects/calgary_campinas/configs/base.yaml rename to projects/calgary_campinas/configs/base_rim.yaml diff --git a/projects/calgary_campinas/configs/base_unet.yaml b/projects/calgary_campinas/configs/base_unet.yaml new file mode 100644 index 00000000..ce76b62b --- /dev/null +++ b/projects/calgary_campinas/configs/base_unet.yaml @@ -0,0 +1,106 @@ +physics: + forward_operator: fft2(centered=False) + backward_operator: ifft2(centered=False) +training: + datasets: + # Two datasets, only difference is the shape, so the data can be collated for larger batches + - name: CalgaryCampinas + lists: + - ../lists/train/12x218x170_train.lst + transforms: + crop: null + estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + image_center_crop: false + masking: + name: CalgaryCampinas + accelerations: [5, 10] + crop_outer_slices: true + - name: CalgaryCampinas + lists: + - ../lists/train/12x218x180_train.lst + transforms: + crop: null + estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + image_center_crop: false + masking: + name: CalgaryCampinas + accelerations: [5, 10] + crop_outer_slices: true + batch_size: 32 # This is the batch size per GPU! 
+ optimizer: Adam + lr: 0.0002 + weight_decay: 0.0 + lr_step_size: 50000 + lr_gamma: 0.2 + lr_warmup_iter: 3000 + num_iterations: 500000 + gradient_steps: 1 + gradient_clipping: 0.0 + gradient_debug: false + checkpointer: + checkpoint_steps: 500 + validation_steps: 2000 + loss: + crop: null + losses: + - function: l1_loss + multiplier: 1.0 + - function: ssim_loss + multiplier: 1.0 + - function: l2_loss + multiplier: 1.0 +validation: + datasets: + # Twice the same dataset but a different acceleration factor + - name: CalgaryCampinas + transforms: + crop: null + estimate_sensitivity_maps: true + scaling_key: masked_kspace + masking: + name: CalgaryCampinas + accelerations: [5] + crop_outer_slices: true + text_description: 5x # Description for logging + - name: CalgaryCampinas + transforms: + crop: null + estimate_sensitivity_maps: true + scaling_key: masked_kspace + masking: + name: CalgaryCampinas + accelerations: [10] + crop_outer_slices: true + text_description: 10x # Description for logging + crop: null # This sets the cropping for the DoIterationOutput + metrics: # These are obtained from direct.functionals + - calgary_campinas_psnr + - calgary_campinas_ssim + - calgary_campinas_vif +model: + model_name: unet.unet_2d.Unet2d + num_filters: 64 + image_initialization: sense +additional_models: + sensitivity_model: + model_name: unet.unet_2d.UnetModel2d + in_channels: 2 + out_channels: 2 + num_filters: 8 + num_pool_layers: 4 + dropout_probability: 0.0 +logging: + tensorboard: + num_images: 4 +inference: + batch_size: 8 + dataset: + name: CalgaryCampinas + crop_outer_slices: true + text_description: inference + transforms: + crop: null + estimate_sensitivity_maps: true + scaling_key: masked_kspace diff --git a/projects/calgary_campinas/configs/base_varnet.yaml b/projects/calgary_campinas/configs/base_varnet.yaml new file mode 100644 index 00000000..f325bd17 --- /dev/null +++ b/projects/calgary_campinas/configs/base_varnet.yaml @@ -0,0 +1,104 @@ +physics: + forward_operator: fft2(centered=False) + backward_operator: ifft2(centered=False) +training: + datasets: + # Two datasets, only difference is the shape, so the data can be collated for larger batches + - name: CalgaryCampinas + lists: + - ../lists/train/12x218x170_train.lst + transforms: + crop: null + estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + image_center_crop: false + masking: + name: CalgaryCampinas + accelerations: [5, 10] + crop_outer_slices: true + - name: CalgaryCampinas + lists: + - ../lists/train/12x218x180_train.lst + transforms: + crop: null + estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + image_center_crop: false + masking: + name: CalgaryCampinas + accelerations: [5, 10] + crop_outer_slices: true + batch_size: 16 # This is the batch size per GPU! 
+ optimizer: Adam + lr: 0.0005 + weight_decay: 0.0 + lr_step_size: 50000 + lr_gamma: 0.2 + lr_warmup_iter: 1500 + num_iterations: 500000 + gradient_steps: 1 + gradient_clipping: 0.0 + gradient_debug: false + checkpointer: + checkpoint_steps: 500 + validation_steps: 500 + loss: + crop: null + losses: + - function: ssim_loss + multiplier: 1.0 + - function: l1_loss + multiplier: 1.0 +validation: + datasets: + # Twice the same dataset but a different acceleration factor + - name: CalgaryCampinas + transforms: + crop: null + estimate_sensitivity_maps: true + scaling_key: masked_kspace + masking: + name: CalgaryCampinas + accelerations: [5] + crop_outer_slices: true + text_description: 5x # Description for logging + - name: CalgaryCampinas + transforms: + crop: null + estimate_sensitivity_maps: true + scaling_key: masked_kspace + masking: + name: CalgaryCampinas + accelerations: [10] + crop_outer_slices: true + text_description: 10x # Description for logging + crop: null # This sets the cropping for the DoIterationOutput + metrics: # These are obtained from direct.functionals + - calgary_campinas_psnr + - calgary_campinas_ssim + - calgary_campinas_vif +model: + model_name: varnet.varnet.EndToEndVarNet + num_layers: 12 + regularizer_num_filters: 32 +additional_models: + sensitivity_model: + model_name: unet.unet_2d.UnetModel2d + in_channels: 2 + out_channels: 2 + num_filters: 8 + num_pool_layers: 4 + dropout_probability: 0.0 +logging: + tensorboard: + num_images: 4 +inference: + batch_size: 8 + dataset: + name: CalgaryCampinas + crop_outer_slices: true + text_description: inference + transforms: + crop: null + estimate_sensitivity_maps: true + scaling_key: masked_kspace diff --git a/projects/calgary_campinas/configs/base_xpdnet.yaml b/projects/calgary_campinas/configs/base_xpdnet.yaml new file mode 100644 index 00000000..171d48e6 --- /dev/null +++ b/projects/calgary_campinas/configs/base_xpdnet.yaml @@ -0,0 +1,110 @@ +physics: + forward_operator: fft2(centered=False) + backward_operator: ifft2(centered=False) +training: + datasets: + # Two datasets, only difference is the shape, so the data can be collated for larger batches + - name: CalgaryCampinas + lists: + - ../lists/train/12x218x170_train.lst + transforms: + crop: null + estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + image_center_crop: false + masking: + name: CalgaryCampinas + accelerations: [5, 10] + crop_outer_slices: true + - name: CalgaryCampinas + lists: + - ../lists/train/12x218x180_train.lst + transforms: + crop: null + estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + image_center_crop: false + masking: + name: CalgaryCampinas + accelerations: [5, 10] + crop_outer_slices: true + batch_size: 8 # This is the batch size per GPU! 
+ optimizer: Adam + lr: 0.0002 + weight_decay: 0.0 + lr_step_size: 50000 + lr_gamma: 0.2 + lr_warmup_iter: 1000 + num_iterations: 500000 + gradient_steps: 1 + gradient_clipping: 0.0 + gradient_debug: false + checkpointer: + checkpoint_steps: 500 + validation_steps: 500 + loss: + crop: null + losses: + - function: l1_loss + multiplier: 1.0 + - function: ssim_loss + multiplier: 1.0 +validation: + datasets: + # Twice the same dataset but a different acceleration factor + - name: CalgaryCampinas + transforms: + crop: null + estimate_sensitivity_maps: true + scaling_key: masked_kspace + masking: + name: CalgaryCampinas + accelerations: [5] + crop_outer_slices: true + text_description: 5x # Description for logging + - name: CalgaryCampinas + transforms: + crop: null + estimate_sensitivity_maps: true + scaling_key: masked_kspace + masking: + name: CalgaryCampinas + accelerations: [10] + crop_outer_slices: true + text_description: 10x # Description for logging + crop: null # This sets the cropping for the DoIterationOutput + metrics: # These are obtained from direct.functionals + - calgary_campinas_psnr + - calgary_campinas_ssim + - calgary_campinas_vif +model: + model_name: xpdnet.xpdnet.XPDNet + num_primal: 5 + num_iter: 20 + mwcnn_hidden_channels: 32 + mwcnn_num_scales: 4 + mwcnn_bias: True + mwcnn_batchnorm: False + normalize: False + +additional_models: + sensitivity_model: + model_name: unet.unet_2d.UnetModel2d + in_channels: 2 + out_channels: 2 + num_filters: 8 + num_pool_layers: 4 + dropout_probability: 0.0 +logging: + tensorboard: + num_images: 4 +inference: + batch_size: 8 + dataset: + name: CalgaryCampinas + crop_outer_slices: true + text_description: inference + transforms: + crop: null + estimate_sensitivity_maps: true + scaling_key: masked_kspace diff --git a/tools/train_rim.py b/tools/train_model.py similarity index 100% rename from tools/train_rim.py rename to tools/train_model.py
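
A minimal sketch of the sensitivity-map normalization performed in `XPDNetEngine._do_iteration` above, which rescales the maps so that \sum_{i \in coils} S_i S_i^* = 1 at every pixel. Random tensors stand in for real data; `T.safe_divide` is the same helper the engine already calls:

    import torch

    import direct.data.transforms as T

    # Shape used throughout the engine: (batch, coil, height, width, complex=2).
    sensitivity_map = torch.randn(1, 12, 218, 170, 2)

    # Root-sum-of-squares over the complex and coil axes -> (batch, height, width).
    norm = torch.sqrt((sensitivity_map ** 2).sum(-1).sum(1))

    # Broadcast the norm back to (batch, 1, height, width, 1) and divide, guarding against zeros.
    normalized = T.safe_divide(sensitivity_map, norm.unsqueeze(1).unsqueeze(-1))

    # After normalization, \sum_i |S_i|^2 == 1 wherever the norm was non-zero.
    check = (normalized ** 2).sum(-1).sum(1)
    assert torch.allclose(check[norm > 0], torch.ones_like(check[norm > 0]))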
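
`XPDNetEngine.compute_model_per_coil` applies a 2D model coil by coil using `select`/`stack`. The sketch below illustrates that pattern with a stand-in `nn.Conv2d`; in the engine the model would be the configured `sensitivity_model` (e.g. `UnetModel2d` from the configs above):

    import torch
    from torch import nn

    coil_dim = 1
    # (batch, coil, complex=2, height, width): the layout fed to the sensitivity model.
    data = torch.randn(2, 12, 2, 218, 170)
    model = nn.Conv2d(2, 2, kernel_size=3, padding=1)  # stand-in for the per-coil model

    # Apply the model to each coil separately, then restore the coil axis.
    output = torch.stack(
        [model(data.select(coil_dim, idx)) for idx in range(data.size(coil_dim))],
        dim=coil_dim,
    )
    print(output.shape)  # torch.Size([2, 12, 2, 218, 170])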
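
The new `direct.utils.is_complex_data` helper now backs `assert_complex`. A short sketch of the behavior implemented in the hunks above, covering channel-last and channel-first complex layouts:

    import torch

    from direct.utils import is_complex_data
    from direct.utils.asserts import assert_complex

    # Channel-last complex tensor: (batch, height, width, complex=2).
    image = torch.randn(4, 218, 170, 2)
    assert is_complex_data(image)  # complex axis is last
    assert is_complex_data(image, complex_last=False)

    # Channel-first complex tensor: (batch, complex=2, height, width).
    image_cf = torch.randn(4, 2, 218, 170)
    assert not is_complex_data(image_cf)  # complex axis is not last
    assert is_complex_data(image_cf, complex_last=False)

    # assert_complex raises ValueError whenever is_complex_data returns False.
    assert_complex(image)  # passes
    assert_complex(image_cf, complex_last=False)  # also passes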