diff --git a/README.md b/README.md
index d72aa4fe..e5e26cf4 100644
--- a/README.md
+++ b/README.md
@@ -3,7 +3,7 @@
[![black](https://github.com/directgroup/direct/actions/workflows/black.yml/badge.svg)](https://github.com/directgroup/direct/actions/workflows/black.yml)
# DIRECT
-DIRECT is the Deep Image REConstruction Toolkit that implements state-of-the-art inverse problem solvers. It includes
+DIRECT is the Deep Image REConstruction Toolkit that implements state-of-the-art inverse problem solvers. It provides
inverse problem solvers such as the Learned Primal Dual algorithm and Recurrent Inference Machine, which were part of the winning solution in Facebook & NYUs FastMRI challenge in 2019 and the Calgary-Campinas MRI reconstruction challenge at MIDL 2020.
@@ -29,9 +29,9 @@ If you use DIRECT in your own research, or want to refer to baseline results pub
```BibTeX
@misc{DIRECTTOOLKIT,
- author = {Jonas Teuwen, Nikita Moriakov, Dimitrios Karkalousos, Matthan Caan, George Yiasemis},
- title = {DIRECT},
+ author = {Teuwen, Jonas and Yiasemis, George and Moriakov, Nikita and Karkalousos, Dimitrios and Caan, Matthan},
+ title = {DIRECT: Deep Image REConstruction Toolkit},
howpublished = {\url{https://github.com/directgroup/direct}},
- year = {2020}
+ year = {2021}
}
```
diff --git a/direct/data/h5_data.py b/direct/data/h5_data.py
index e6489ee1..68172b4d 100644
--- a/direct/data/h5_data.py
+++ b/direct/data/h5_data.py
@@ -124,7 +124,7 @@ def parse_filenames_data(self, filenames, extra_h5s=None, filter_slice=None):
if len(filenames) < 5 or idx % (len(filenames) // 5) == 0 or len(filenames) == (idx + 1):
self.logger.info(f"Parsing: {(idx + 1) / len(filenames) * 100:.2f}%.")
try:
- kspace = h5py.File(filename, "r")["kspace"]
+ kspace = np.array(h5py.File(filename, "r")["kspace"])
self.verify_extra_h5_integrity(filename, kspace.shape, extra_h5s=extra_h5s)
except OSError as exc:
diff --git a/direct/data/tests/test_generator.py b/direct/data/tests/test_fake.py
similarity index 100%
rename from direct/data/tests/test_generator.py
rename to direct/data/tests/test_fake.py
diff --git a/direct/data/tests/test_transforms.py b/direct/data/tests/test_transforms.py
index 3343cae5..bbf1c86b 100644
--- a/direct/data/tests/test_transforms.py
+++ b/direct/data/tests/test_transforms.py
@@ -107,6 +107,24 @@ def test_modulus(shape):
assert np.allclose(out_torch, out_numpy)
+@pytest.mark.parametrize(
+ "shape",
+ [
+ [3, 3],
+ [4, 6],
+ [10, 8, 4],
+ [3, 4, 3, 5],
+ ],
+)
+@pytest.mark.parametrize("complex", [True, False])
+def test_modulus_if_complex(shape, complex):
+ if complex:
+ shape += [
+ 2,
+ ]
+ test_modulus(shape)
+
+
@pytest.mark.parametrize(
"shape, dims",
[
@@ -166,7 +184,6 @@ def test_center_crop(shape, target_shape):
[[8, 4], [4, 4]],
],
)
-# @pytest.mark.parametrize("named", [True, False])
def test_complex_center_crop(shape, target_shape):
shape = shape + [2]
input = create_input(shape)
@@ -262,3 +279,53 @@ def test_ifftshift(shape):
out_torch = transforms.ifftshift(torch_tensor).numpy()
out_numpy = np.fft.ifftshift(data)
assert np.allclose(out_torch, out_numpy)
+
+
+@pytest.mark.parametrize(
+ "shape, dim",
+ [
+ [[3, 4, 5], 0],
+ [[3, 3, 4, 5], 1],
+ [[3, 6, 4, 5], 0],
+ [[3, 3, 6, 4, 5], 1],
+ ],
+)
+def test_expand_operator(shape, dim):
+ shape = shape + [
+ 2,
+ ]
+ data = create_input(shape) # noqa
+ shape = shape[:dim] + shape[dim + 1 :]
+ sens = create_input(shape) # noqa
+
+ out_torch = tensor_to_complex_numpy(transforms.expand_operator(data, sens, dim))
+
+ input_numpy = np.expand_dims(tensor_to_complex_numpy(data), dim)
+ input_sens_numpy = tensor_to_complex_numpy(sens)
+ out_numpy = input_sens_numpy * input_numpy
+
+ assert np.allclose(out_torch, out_numpy)
+
+
+@pytest.mark.parametrize(
+ "shape, dim",
+ [
+ [[3, 4, 5], 0],
+ [[3, 3, 4, 5], 1],
+ [[3, 6, 4, 5], 0],
+ [[3, 3, 6, 4, 5], 1],
+ ],
+)
+def test_reduce_operator(shape, dim):
+ shape = shape + [
+ 2,
+ ]
+ coil_data = create_input(shape) # noqa
+ sens = create_input(shape) # noqa
+ out_torch = tensor_to_complex_numpy(transforms.reduce_operator(coil_data, sens, dim))
+
+ input_numpy = tensor_to_complex_numpy(coil_data)
+ input_sens_numpy = tensor_to_complex_numpy(sens)
+ out_numpy = (input_sens_numpy.conj() * input_numpy).sum(dim)
+
+ assert np.allclose(out_torch, out_numpy)
diff --git a/direct/data/transforms.py b/direct/data/transforms.py
index e0730e08..ad72d7e6 100644
--- a/direct/data/transforms.py
+++ b/direct/data/transforms.py
@@ -10,10 +10,9 @@
import numpy as np
import torch
import torch.fft
-from packaging import version
from direct.data.bbox import crop_to_bbox
-from direct.utils import ensure_list, is_power_of_two
+from direct.utils import ensure_list, is_power_of_two, is_complex_data
from direct.utils.asserts import assert_complex, assert_same_shape
@@ -222,7 +221,6 @@ def safe_divide(input_tensor: torch.Tensor, other_tensor: torch.Tensor) -> torch
torch.tensor([0.0], dtype=input_tensor.dtype).to(input_tensor.device),
input_tensor / other_tensor,
)
-
return data
@@ -266,8 +264,7 @@ def align_as(input_tensor: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
input_shape = list(input_tensor.shape)
other_shape = torch.tensor(other.shape, dtype=int)
out_shape = torch.ones(len(other.shape), dtype=int)
- # TODO(gy): Fix to ensure complex_last when [2,..., 2] or [..., N,..., N,...] in other.shape,
- # "-input_shape.count(dim):" is a hack and might cause problems.
+
for dim in np.sort(np.unique(input_tensor.shape)):
ind = torch.where(other_shape == dim)[0][-input_shape.count(dim) :]
out_shape[ind] = dim
@@ -292,7 +289,6 @@ def modulus(data: torch.Tensor) -> torch.Tensor:
complex_axis = -1 if data.size(-1) == 2 else 1
return (data ** 2).sum(complex_axis).sqrt() # noqa
- # return torch.view_as_complex(data).abs()
def modulus_if_complex(data: torch.Tensor) -> torch.Tensor:
@@ -307,11 +303,9 @@ def modulus_if_complex(data: torch.Tensor) -> torch.Tensor:
-------
torch.Tensor
"""
- # TODO: This can be merged with modulus if the tensor is real.
- try:
+ if is_complex_data(data, complex_last=False):
return modulus(data)
- except ValueError:
- return data
+ return data
def roll(
@@ -436,7 +430,7 @@ def _complex_matrix_multiplication(input_tensor, other_tensor, mult_func):
Parameters
----------
- x : torch.Tensor
+ input_tensor : torch.Tensor
other_tensor : torch.Tensor
mult_func : Callable
Multiplication function e.g. torch.bmm or torch.mm
@@ -466,6 +460,7 @@ def complex_mm(input_tensor, other_tensor):
----------
input_tensor : torch.Tensor
other_tensor : torch.Tensor
+
Returns
-------
torch.Tensor
@@ -501,10 +496,6 @@ def conjugate(data: torch.Tensor) -> torch.Tensor:
-------
torch.Tensor
"""
- # assert_complex(data, complex_last=True)
- # data = torch.view_as_real(
- # torch.view_as_complex(data).conj()
- # )
assert_complex(data, complex_last=True)
data = data.clone() # Clone is required as the data in the next line is changed in-place.
data[..., 1] = data[..., 1] * -1.0
@@ -574,11 +565,12 @@ def tensor_to_complex_numpy(data: torch.Tensor) -> np.ndarray:
return data[..., 0] + 1j * data[..., 1]
-def root_sum_of_squares(data: torch.Tensor, dim: int = 0) -> torch.Tensor:
- r"""
+def root_sum_of_squares(data: torch.Tensor, dim: int = 0, complex_dim: int = -1) -> torch.Tensor:
+ """
Compute the root sum of squares (RSS) transform along a given dimension of the input tensor.
- $$x_{\textrm{rss}} = \sqrt{\sum_{i \in \textrm{coil}} |x_i|^2}$$
+ .. math::
+ x_{\textrm{rss}} = \sqrt{\sum_{i \in \textrm{coil}} |x_i|^2}
Parameters
----------
@@ -588,16 +580,16 @@ def root_sum_of_squares(data: torch.Tensor, dim: int = 0) -> torch.Tensor:
dim : int
Coil dimension. Default is 0 as the first dimension is always the coil dimension.
+ complex_dim : int
+        Complex channel dimension. Default: -1. Ignored if the data is not complex.
Returns
-------
torch.Tensor : RSS of the input tensor.
"""
- try:
- assert_complex(data, complex_last=True)
- complex_index = -1
- return torch.sqrt((data ** 2).sum(complex_index).sum(dim))
- except ValueError:
- return torch.sqrt((data ** 2).sum(dim))
+ if is_complex_data(data):
+ return torch.sqrt((data ** 2).sum(complex_dim).sum(dim))
+
+ return torch.sqrt((data ** 2).sum(dim))
def center_crop(data: torch.Tensor, shape: Tuple[int, int]) -> torch.Tensor:
@@ -760,5 +752,71 @@ def complex_random_crop(
if len(output) == 1:
return output[0]
-
return output
+
+
+def reduce_operator(
+ coil_data: torch.Tensor,
+ sensitivity_map: torch.Tensor,
+ dim: int = 0,
+) -> torch.Tensor:
+ """
+ Given zero-filled reconstructions from multiple coils :math: \{x_i\}_{i=1}^{N_c} and coil sensitivity maps
+ :math: \{S_i\}_{i=1}^{N_c} it returns
+ .. math::
+ R(x_1, .., x_{N_c}, S_1, .., S_{N_c}) = \sum_{i=1}^{N_c} {S_i}^{*} \times x_i.
+
+ From paper End-to-End Variational Networks for Accelerated MRI Reconstruction.
+
+ Parameters
+ ----------
+ coil_data : torch.Tensor
+ Zero-filled reconstructions from coils. Should be a complex tensor (with complex dim of size 2).
+ sensitivity_map: torch.Tensor
+ Coil sensitivity maps. Should be complex tensor (with complex dim of size 2).
+ dim: int
+ Coil dimension. Default: 0.
+
+ Returns
+ -------
+ torch.Tensor:
+ Combined individual coil images.
+ """
+
+ assert_complex(coil_data, complex_last=True)
+ assert_complex(sensitivity_map, complex_last=True)
+
+ return complex_multiplication(conjugate(sensitivity_map), coil_data).sum(dim)
+
+
+def expand_operator(
+ data: torch.Tensor,
+ sensitivity_map: torch.Tensor,
+ dim: int = 0,
+) -> torch.Tensor:
+ """
+ Given a reconstructed image x and coil sensitivity maps :math: \{S_i\}_{i=1}^{N_c}, it returns
+ .. math::
+ \Epsilon(x) = (S_1 \times x, .., S_{N_c} \times x) = (x_1, .., x_{N_c}).
+
+ From paper End-to-End Variational Networks for Accelerated MRI Reconstruction.
+
+ Parameters
+ ----------
+ data : torch.Tensor
+ Image data. Should be a complex tensor (with complex dim of size 2).
+ sensitivity_map: torch.Tensor
+ Coil sensitivity maps. Should be complex tensor (with complex dim of size 2).
+ dim: int
+ Coil dimension. Default: 0.
+
+ Returns
+ -------
+ torch.Tensor:
+ Zero-filled reconstructions from each coil.
+ """
+
+ assert_complex(data, complex_last=True)
+ assert_complex(sensitivity_map, complex_last=True)
+
+ return complex_multiplication(sensitivity_map, data.unsqueeze(dim))
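+
+
+if __name__ == "__main__":
+    # Minimal usage sketch of the two operators above (illustrative shapes only, not part of the
+    # library API): combine per-coil images into a single image with ``reduce_operator`` and
+    # project it back onto the coils with ``expand_operator``. The complex dimension is last.
+    coil_images = torch.rand(4, 10, 10, 2)  # (coil, height, width, complex=2)
+    sensitivities = torch.rand(4, 10, 10, 2)
+    combined = reduce_operator(coil_images, sensitivities, dim=0)  # (10, 10, 2)
+    re_expanded = expand_operator(combined, sensitivities, dim=0)  # (4, 10, 10, 2)
+    print(combined.shape, re_expanded.shape)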
diff --git a/direct/engine.py b/direct/engine.py
index 345c9a7f..533c5a5c 100644
--- a/direct/engine.py
+++ b/direct/engine.py
@@ -11,7 +11,7 @@
import warnings
from abc import ABC, abstractmethod
from collections import namedtuple
-from typing import Callable, Dict, List, Optional, TypedDict, Union
+from typing import Callable, Dict, List, Optional, Union
import numpy as np
import torch
@@ -278,13 +278,6 @@ def training_loop(
fail_counter = 0
for data, iter_idx in zip(data_loader, range(start_iter, total_iter)):
- # 2D data is batched and contains keys:
- # "filename_slice", "slice_no"
- # "sampling_mask" of shape: (batch, 1, height, width, 1)
- # "sensitivity_map" of shape: (batch, coil, height, width, complex=2)
- # "target" of shape: (batch, height, width)
- # "masked_kspace" of shape: (batch, coil, height, width, complex=2)
-
if iter_idx == 0:
self.log_first_training_example_and_model(data)
diff --git a/direct/environment.py b/direct/environment.py
index efebe766..1d3bffd0 100644
--- a/direct/environment.py
+++ b/direct/environment.py
@@ -1,8 +1,6 @@
# coding=utf-8
# Copyright (c) DIRECT Contributors
-# pylint: disable = E1101
-
import argparse
import logging
import os
@@ -17,6 +15,7 @@
from torch.utils import collect_env
import direct.utils.logging
+from direct.utils.logging import setup
from direct.config.defaults import DefaultConfig, InferenceConfig, TrainingConfig, ValidationConfig
from direct.utils import communication, count_parameters, str_to_class
@@ -74,7 +73,7 @@ def setup_logging(machine_rank, output_directory, run_name, cfg_filename, cfg, d
# Setup logging
log_file = output_directory / f"log_{machine_rank}_{communication.get_local_rank()}.txt"
- direct.utils.logging.setup(
+ setup(
use_stdout=communication.is_main_process() or debug,
filename=log_file,
log_level=("INFO" if not debug else "DEBUG"),
@@ -245,13 +244,13 @@ def setup_common_environment(
dataset_cfg_from_file = extract_names(cfg_from_file[key].datasets)
for idx, (dataset_name, dataset_config) in enumerate(dataset_cfg_from_file):
cfg_from_file_new[key].datasets[idx] = dataset_config
- cfg[key].datasets.append(load_dataset_config(dataset_name))
+ cfg[key].datasets.append(load_dataset_config(dataset_name)) # pylint: disable = E1136
else:
dataset_name, dataset_config = extract_names(cfg_from_file[key].dataset)
cfg_from_file_new[key].dataset = dataset_config
- cfg[key].dataset = load_dataset_config(dataset_name)
+ cfg[key].dataset = load_dataset_config(dataset_name) # pylint: disable = E1136
- cfg[key] = OmegaConf.merge(cfg[key], cfg_from_file_new[key])
+ cfg[key] = OmegaConf.merge(cfg[key], cfg_from_file_new[key]) # pylint: disable = E1136, E1137
# sys.exit()
# Make configuration read only.
# TODO(jt): Does not work when indexing config lists.
diff --git a/direct/launch.py b/direct/launch.py
index 17f32c42..eb28dc8a 100644
--- a/direct/launch.py
+++ b/direct/launch.py
@@ -139,10 +139,10 @@ def _distributed_worker(
if communication._LOCAL_PROCESS_GROUP is not None:
raise RuntimeError
num_machines = world_size // num_gpus_per_machine
- for i in range(num_machines):
- ranks_on_i = list(range(i * num_gpus_per_machine, (i + 1) * num_gpus_per_machine))
+ for idx in range(num_machines):
+ ranks_on_i = list(range(idx * num_gpus_per_machine, (idx + 1) * num_gpus_per_machine))
pg = dist.new_group(ranks_on_i)
- if i == machine_rank:
+ if idx == machine_rank:
communication._LOCAL_PROCESS_GROUP = pg
main_func(*args)
diff --git a/direct/nn/conv/__init__.py b/direct/nn/conv/__init__.py
new file mode 100644
index 00000000..941752c9
--- /dev/null
+++ b/direct/nn/conv/__init__.py
@@ -0,0 +1,2 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
diff --git a/direct/nn/conv/conv.py b/direct/nn/conv/conv.py
new file mode 100644
index 00000000..2da353e7
--- /dev/null
+++ b/direct/nn/conv/conv.py
@@ -0,0 +1,51 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
+
+import torch.nn as nn
+
+
+class Conv2d(nn.Module):
+ """
+ Implementation of a simple cascade of 2D convolutions. If batchnorm is set to True, batch normalization
+ layer is applied after each convolution.
+ """
+
+ def __init__(self, in_channels, out_channels, hidden_channels, n_convs=3, activation=nn.PReLU(), batchnorm=False):
+ """
+
+ Parameters
+ ----------
+ in_channels : int
+ Number of input channels.
+ out_channels : int
+ Number of output channels.
+ hidden_channels : int
+ Number of hidden channels.
+ n_convs : int
+ Number of convolutional layers.
+ activation : nn.Module
+ Activation function.
+ batchnorm : bool
+ If True a batch normalization layer is applied after every convolution.
+ """
+ super().__init__()
+
+ self.conv = []
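+        # For example, with n_convs=3 this builds:
+        #   Conv(in -> hidden) [+ BN] -> act -> Conv(hidden -> hidden) [+ BN] -> act -> Conv(hidden -> out) [+ BN]
+        # where the BatchNorm layers are only added when ``batchnorm`` is True.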
+ for idx in range(n_convs):
+ self.conv.append(
+ nn.Conv2d(
+ in_channels if idx == 0 else hidden_channels,
+ hidden_channels if idx != n_convs - 1 else out_channels,
+ kernel_size=3,
+ padding=1,
+ )
+ )
+ if batchnorm:
+ self.conv.append(nn.BatchNorm2d(hidden_channels if idx != n_convs - 1 else out_channels, eps=1e-4))
+ if idx != n_convs - 1:
+ self.conv.append(activation)
+ self.conv = nn.Sequential(*self.conv)
+
+ def forward(self, x):
+
+ return self.conv(x)
diff --git a/direct/nn/conv/tests/__init__.py b/direct/nn/conv/tests/__init__.py
new file mode 100644
index 00000000..941752c9
--- /dev/null
+++ b/direct/nn/conv/tests/__init__.py
@@ -0,0 +1,2 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
diff --git a/direct/nn/conv/tests/test_conv.py b/direct/nn/conv/tests/test_conv.py
new file mode 100644
index 00000000..072c8581
--- /dev/null
+++ b/direct/nn/conv/tests/test_conv.py
@@ -0,0 +1,52 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
+
+import pytest
+import torch
+import torch.nn as nn
+
+from direct.nn.conv.conv import Conv2d
+
+
+def create_input(shape):
+
+ data = torch.rand(shape).float()
+
+ return data
+
+
+@pytest.mark.parametrize(
+ "shape",
+ [
+ [3, 2, 32, 32],
+ [3, 2, 16, 16],
+ ],
+)
+@pytest.mark.parametrize(
+ "out_channels",
+ [3, 5],
+)
+@pytest.mark.parametrize(
+ "hidden_channels",
+ [16, 8],
+)
+@pytest.mark.parametrize(
+ "n_convs",
+ [2, 4],
+)
+@pytest.mark.parametrize(
+ "act",
+ [nn.ReLU(), nn.PReLU()],
+)
+@pytest.mark.parametrize(
+ "batchnorm",
+ [True, False],
+)
+def test_conv(shape, out_channels, hidden_channels, n_convs, act, batchnorm):
+ model = Conv2d(shape[1], out_channels, hidden_channels, n_convs, act, batchnorm)
+
+ data = create_input(shape).cpu()
+
+ out = model(data)
+
+ assert list(out.shape) == [shape[0]] + [out_channels] + shape[2:]
diff --git a/direct/nn/crossdomain/__init__.py b/direct/nn/crossdomain/__init__.py
new file mode 100644
index 00000000..941752c9
--- /dev/null
+++ b/direct/nn/crossdomain/__init__.py
@@ -0,0 +1,2 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
diff --git a/direct/nn/crossdomain/crossdomain.py b/direct/nn/crossdomain/crossdomain.py
new file mode 100644
index 00000000..18129e5a
--- /dev/null
+++ b/direct/nn/crossdomain/crossdomain.py
@@ -0,0 +1,194 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
+
+from typing import Callable, Optional
+
+import torch
+import torch.nn as nn
+
+import direct.data.transforms as T
+
+
+class CrossDomainNetwork(nn.Module):
+ """
+    This performs optimisation in both the k-space ("K") and image ("I") domains, according to domain_sequence.
+ """
+
+ def __init__(
+ self,
+ forward_operator: Callable,
+ backward_operator: Callable,
+ image_model_list: nn.Module,
+        kspace_model_list: Optional[nn.Module] = None,
+ domain_sequence: str = "KIKI",
+ image_buffer_size: int = 1,
+ kspace_buffer_size: int = 1,
+ normalize_image: bool = False,
+ **kwargs,
+ ):
+ """
+
+ Parameters
+ ----------
+ forward_operator : Callable
+ Forward Operator.
+ backward_operator : Callable
+ Backward Operator.
+ image_model_list : nn.Module
+ Image domain model list.
+ kspace_model_list : Optional[nn.Module]
+ K-space domain model list. If set to None, a correction step is applied. Default: None.
+ domain_sequence : str
+ Domain sequence containing only "K" (k-space domain) and/or "I" (image domain). Default: "KIKI".
+ image_buffer_size : int
+ Image buffer size. Default: 1.
+ kspace_buffer_size : int
+ K-space buffer size. Default: 1.
+ normalize_image : bool
+ If True, input is normalized. Default: False.
+ kwargs : dict
+ Keyword Arguments.
+ """
+ super().__init__()
+
+ self.forward_operator = forward_operator
+ self.backward_operator = backward_operator
+
+        domain_sequence = list(domain_sequence.strip())
+ if not set(domain_sequence).issubset({"K", "I"}):
+ raise ValueError(f"Invalid domain sequence. Got {domain_sequence}. Should only contain 'K' and 'I'.")
+
+ if kspace_model_list is not None:
+ if len(kspace_model_list) != domain_sequence.count("K"):
+ raise ValueError(f"K-space domain steps do not match k-space model list length.")
+
+ if len(image_model_list) != domain_sequence.count("I"):
+ raise ValueError(f"Image domain steps do not match image model list length.")
+
+ self.domain_sequence = domain_sequence
+
+ self.kspace_model_list = kspace_model_list
+ self.kspace_buffer_size = kspace_buffer_size
+
+ self.image_model_list = image_model_list
+ self.image_buffer_size = image_buffer_size
+
+ self.normalize_image = normalize_image
+
+ self._coil_dim = 1
+ self._complex_dim = -1
+ self._spatial_dims = (2, 3)
+
+ def kspace_correction(self, block_idx, image_buffer, kspace_buffer, sampling_mask, sensitivity_map, masked_kspace):
+
+ forward_buffer = [
+ self._forward_operator(
+ image.clone(),
+ sampling_mask,
+ sensitivity_map,
+ )
+ for image in torch.split(image_buffer, 2, self._complex_dim)
+ ]
+
+ forward_buffer = torch.cat(forward_buffer, self._complex_dim)
+ kspace_buffer = torch.cat([kspace_buffer, forward_buffer, masked_kspace], self._complex_dim)
+
+ if self.kspace_model_list is not None:
+ kspace_buffer = self.kspace_model_list[block_idx](kspace_buffer.permute(0, 1, 4, 2, 3)).permute(
+ 0, 1, 3, 4, 2
+ )
+ else:
+ kspace_buffer = kspace_buffer[..., :2] - kspace_buffer[..., 2:4]
+
+ return kspace_buffer
+
+ def image_correction(self, block_idx, image_buffer, kspace_buffer, sampling_mask, sensitivity_map):
+ backward_buffer = [
+ self._backward_operator(kspace.clone(), sampling_mask, sensitivity_map)
+ for kspace in torch.split(kspace_buffer, 2, self._complex_dim)
+ ]
+ backward_buffer = torch.cat(backward_buffer, self._complex_dim)
+
+ image_buffer = torch.cat([image_buffer, backward_buffer], self._complex_dim).permute(0, 3, 1, 2)
+ image_buffer = self.image_model_list[block_idx](image_buffer).permute(0, 2, 3, 1)
+
+ return image_buffer
+
+ def _forward_operator(self, image, sampling_mask, sensitivity_map):
+ forward = torch.where(
+ sampling_mask == 0,
+ torch.tensor([0.0], dtype=image.dtype).to(image.device),
+ self.forward_operator(T.expand_operator(image, sensitivity_map, self._coil_dim), dim=self._spatial_dims),
+ )
+ return forward
+
+ def _backward_operator(self, kspace, sampling_mask, sensitivity_map):
+ backward = T.reduce_operator(
+ self.backward_operator(
+ torch.where(
+ sampling_mask == 0,
+ torch.tensor([0.0], dtype=kspace.dtype).to(kspace.device),
+ kspace,
+ ),
+ self._spatial_dims,
+ ),
+ sensitivity_map,
+ self._coil_dim,
+ )
+ return backward
+
+ def forward(
+ self,
+ masked_kspace: torch.Tensor,
+ sampling_mask: torch.Tensor,
+ sensitivity_map: torch.Tensor,
+ scaling_factor: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ """
+
+ Parameters
+ ----------
+ masked_kspace : torch.Tensor
+ Masked k-space of shape (N, coil, height, width, complex=2).
+ sampling_mask : torch.Tensor
+ Sampling mask of shape (N, 1, height, width, 1).
+ sensitivity_map : torch.Tensor
+ Sensitivity map of shape (N, coil, height, width, complex=2).
+ scaling_factor : Optional[torch.Tensor]
+ Scaling factor of shape (N,). If None, no scaling is applied. Default: None.
+
+ Returns
+ -------
+ out_image : torch.Tensor
+ Output image of shape (N, height, width, complex=2).
+ """
+ input_image = self._backward_operator(masked_kspace, sampling_mask, sensitivity_map)
+
+ if self.normalize_image and scaling_factor is not None:
+ input_image = input_image / scaling_factor ** 2
+ masked_kspace = masked_kspace / scaling_factor ** 2
+
+ image_buffer = torch.cat([input_image] * self.image_buffer_size, self._complex_dim).to(masked_kspace.device)
+
+ kspace_buffer = torch.cat([masked_kspace] * self.kspace_buffer_size, self._complex_dim).to(
+ masked_kspace.device
+ )
+
+ kspace_block_idx, image_block_idx = 0, 0
+ for block_domain in self.domain_sequence:
+ if block_domain == "K":
+ kspace_buffer = self.kspace_correction(
+ kspace_block_idx, image_buffer, kspace_buffer, sampling_mask, sensitivity_map, masked_kspace
+ )
+ kspace_block_idx += 1
+ else:
+ image_buffer = self.image_correction(
+ image_block_idx, image_buffer, kspace_buffer, sampling_mask, sensitivity_map
+ )
+ image_block_idx += 1
+
+ if self.normalize_image and scaling_factor is not None:
+ image_buffer = image_buffer * scaling_factor ** 2
+
+ out_image = image_buffer[..., :2]
+ return out_image
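+
+
+if __name__ == "__main__":
+    # Minimal usage sketch (illustrative only; it assumes direct.data.transforms.fft2/ifft2 as the
+    # forward/backward operators and the simple Conv2d cascade as image models). With buffer sizes
+    # of 1 and no k-space models, each image model sees 4 input channels and must return 2.
+    from direct.data.transforms import fft2, ifft2
+    from direct.nn.conv.conv import Conv2d
+
+    image_models = nn.ModuleList([Conv2d(4, 2, hidden_channels=8, n_convs=2) for _ in range(2)])
+    network = CrossDomainNetwork(fft2, ifft2, image_models, kspace_model_list=None, domain_sequence="KIKI")
+
+    masked_kspace = torch.rand(1, 3, 16, 16, 2)  # (N, coil, height, width, complex=2)
+    sampling_mask = torch.rand(1, 1, 16, 16, 1).round()
+    sensitivity_map = torch.rand(1, 3, 16, 16, 2)
+    print(network(masked_kspace, sampling_mask, sensitivity_map).shape)  # (1, 16, 16, 2)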
diff --git a/direct/nn/crossdomain/multicoil.py b/direct/nn/crossdomain/multicoil.py
new file mode 100644
index 00000000..f8e0fb37
--- /dev/null
+++ b/direct/nn/crossdomain/multicoil.py
@@ -0,0 +1,66 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
+
+import torch
+import torch.nn as nn
+
+
+class MultiCoil(nn.Module):
+ """
+    Makes the forward pass of multi-coil data of shape (N, N_coils, H, W, C) through a model. If coil_to_batch is set
+    to True, the coil dimension is folded into the batch dimension before the forward pass. Otherwise, each coil is
+    passed to the model individually.
+ """
+
+ def __init__(self, model: nn.Module, coil_dim: int = 1, coil_to_batch: bool = False):
+ """
+
+ Parameters
+ ----------
+ model : nn.Module
+            Any nn.Module that takes 4D data (N, H, W, C) as input. Typically a convolution-like model.
+ coil_dim : int
+ Coil dimension. Default: 1.
+ coil_to_batch : bool
+            If True, the batch and coil dimensions are merged before the forward pass and split again afterwards.
+            Otherwise, the input is forwarded to the model one coil at a time.
+ """
+ super().__init__()
+
+ self.model = model
+ self.coil_to_batch = coil_to_batch
+ self._coil_dim = coil_dim
+
+ def _compute_model_per_coil(self, data: torch.Tensor) -> torch.Tensor:
+ output = []
+
+ for idx in range(data.size(self._coil_dim)):
+ subselected_data = data.select(self._coil_dim, idx)
+ output.append(self.model(subselected_data))
+ output = torch.stack(output, dim=self._coil_dim)
+ return output
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+
+ Parameters
+ ----------
+ x : torch.Tensor
+ Multi-coil input of shape (N, coil, height, width, in_channels).
+
+ Returns
+ -------
+ out : torch.Tensor
+ Multi-coil output of shape (N, coil, height, width, out_channels).
+ """
+ if self.coil_to_batch:
+ x = x.clone()
+ batch, coil, height, width, channels = x.size()
+
+ x = x.reshape(batch * coil, height, width, channels).permute(0, 3, 1, 2).contiguous()
+ x = self.model(x).permute(0, 2, 3, 1)
+ x = x.reshape(batch, coil, height, width, -1)
+ else:
+ x = self._compute_model_per_coil(x).contiguous()
+
+ return x
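+
+
+if __name__ == "__main__":
+    # Minimal usage sketch (illustrative shapes only): fold the coil axis into the batch axis so a
+    # plain channel-first 2D convolution can run on multi-coil, channel-last data.
+    wrapped = MultiCoil(nn.Conv2d(2, 2, kernel_size=3, padding=1), coil_to_batch=True)
+    data = torch.rand(1, 4, 16, 16, 2)  # (N, coil, height, width, complex=2)
+    print(wrapped(data).shape)  # torch.Size([1, 4, 16, 16, 2])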
diff --git a/direct/nn/didn/__init__.py b/direct/nn/didn/__init__.py
new file mode 100644
index 00000000..941752c9
--- /dev/null
+++ b/direct/nn/didn/__init__.py
@@ -0,0 +1,2 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
diff --git a/direct/nn/didn/didn.py b/direct/nn/didn/didn.py
new file mode 100644
index 00000000..91d5947e
--- /dev/null
+++ b/direct/nn/didn/didn.py
@@ -0,0 +1,263 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class Subpixel(nn.Module):
+ """
+    Subpixel convolution layer for up-scaling low-resolution features to a higher resolution, as implemented
+ in https://ieeexplore.ieee.org/document/9025411.
+ """
+
+ def __init__(self, in_channels, out_channels, upscale_factor, kernel_size, padding=0):
+ super().__init__()
+ self.conv = nn.Conv2d(
+ in_channels, out_channels * upscale_factor ** 2, kernel_size=kernel_size, padding=padding
+ )
+ self.pixelshuffle = nn.PixelShuffle(upscale_factor)
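+        # Shape flow: (N, C_in, H, W) -> conv -> (N, C_out * upscale_factor**2, H, W)
+        #             -> PixelShuffle -> (N, C_out, upscale_factor * H, upscale_factor * W).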
+
+ def forward(self, x):
+ return self.pixelshuffle(self.conv(x))
+
+
+class ReconBlock(nn.Module):
+ """
+ Reconstruction Block of DIDN model as implemented in https://ieeexplore.ieee.org/document/9025411.
+ """
+
+ def __init__(self, in_channels, num_convs):
+ super().__init__()
+ self.convs = nn.ModuleList(
+ [
+ nn.Sequential(
+ *[
+ nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=3, padding=1),
+ nn.PReLU(),
+ ]
+ )
+ for _ in range(num_convs - 1)
+ ]
+ )
+ self.convs.append(nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=3, padding=1))
+ self.num_convs = num_convs
+
+ def forward(self, input):
+
+ output = input.clone()
+ for idx in range(self.num_convs):
+ output = self.convs[idx](output)
+
+ return input + output
+
+
+class DUB(nn.Module):
+ """
+ Down-up block (DUB) for DIDN model as implemented in https://ieeexplore.ieee.org/document/9025411.
+ """
+
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ ):
+ super().__init__()
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+
+ # Scale 1
+ self.conv1_1 = nn.Sequential(*[nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1), nn.PReLU()] * 2)
+ self.down1 = nn.Conv2d(in_channels, in_channels * 2, kernel_size=3, stride=2, padding=1)
+ # Scale 2
+ self.conv2_1 = nn.Sequential(
+ *[nn.Conv2d(in_channels * 2, in_channels * 2, kernel_size=3, padding=1), nn.PReLU()]
+ )
+ self.down2 = nn.Conv2d(in_channels * 2, in_channels * 4, kernel_size=3, stride=2, padding=1)
+ # Scale 3
+ self.conv3_1 = nn.Sequential(
+ *[
+ nn.Conv2d(in_channels * 4, in_channels * 4, kernel_size=3, padding=1),
+ nn.PReLU(),
+ ]
+ )
+ self.up1 = nn.Sequential(
+ *[
+ Subpixel(in_channels * 4, in_channels * 2, 2, 1, 0)
+ ]
+ )
+ # Scale 2
+ self.conv_agg_1 = nn.Conv2d(in_channels * 4, in_channels * 2, kernel_size=1)
+ self.conv2_2 = nn.Sequential(
+ *[
+ nn.Conv2d(in_channels * 2, in_channels * 2, kernel_size=3, padding=1),
+ nn.PReLU(),
+ ]
+ )
+ self.up2 = nn.Sequential(
+ *[
+ Subpixel(in_channels * 2, in_channels, 2, 1, 0)
+ ]
+ )
+ # Scale 1
+ self.conv_agg_2 = nn.Conv2d(in_channels * 2, in_channels, kernel_size=1)
+ self.conv1_2 = nn.Sequential(*[nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1), nn.PReLU()] * 2)
+ self.conv_out = nn.Sequential(*[nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1), nn.PReLU()])
+
+ @staticmethod
+ def pad(x):
+ padding = [0, 0, 0, 0]
+
+ if x.shape[-2] % 2 != 0:
+            padding[3] = 1  # Pad height (bottom).
+ if x.shape[-1] % 2 != 0:
+            padding[1] = 1  # Pad width (right).
+ if sum(padding) != 0:
+ x = F.pad(x, padding, "reflect")
+ return x
+
+ @staticmethod
+ def crop_to_shape(x, shape):
+ h, w = x.shape[-2:]
+
+ if h > shape[0]:
+ x = x[:, :, : shape[0], :]
+ if w > shape[1]:
+ x = x[:, :, :, : shape[1]]
+ return x
+
+ def forward(self, x):
+ x1 = self.pad(x.clone())
+ x1 = x1 + self.conv1_1(x1)
+ x2 = self.down1(x1)
+ x2 = x2 + self.conv2_1(x2)
+ out = self.down2(x2)
+ out = out + self.conv3_1(out)
+ out = self.up1(out)
+ out = torch.cat([x2, self.crop_to_shape(out, x2.shape[-2:])], dim=1)
+ out = self.conv_agg_1(out)
+ out = out + self.conv2_2(out)
+ out = self.up2(out)
+ out = torch.cat([x1, self.crop_to_shape(out, x1.shape[-2:])], dim=1)
+ out = self.conv_agg_2(out)
+ out = out + self.conv1_2(out)
+ out = x + self.crop_to_shape(self.conv_out(out), x.shape[-2:])
+ return out
+
+
+class DIDN(nn.Module):
+ """
+ Deep Iterative Down-up convolutional Neural network (DIDN) implementation as in
+ https://ieeexplore.ieee.org/document/9025411.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ hidden_channels: int = 128,
+ num_dubs: int = 6,
+ num_convs_recon: int = 9,
+ skip_connection: bool = False,
+ ):
+ """
+
+ Parameters
+ ----------
+ in_channels : int
+ Number of input channels.
+ out_channels : int
+ Number of output channels.
+ hidden_channels : int
+ Number of hidden channels. First convolution out_channels. Default: 128.
+ num_dubs : int
+ Number of DUB networks. Default: 6.
+ num_convs_recon : int
+ Number of ReconBlock convolutions. Default: 9.
+ skip_connection : bool
+ Use skip connection. Default: False.
+ """
+ super().__init__()
+ self.conv_in = nn.Sequential(
+ *[nn.Conv2d(in_channels=in_channels, out_channels=hidden_channels, kernel_size=3, padding=1), nn.PReLU()]
+ )
+ self.down = nn.Conv2d(
+ in_channels=hidden_channels,
+ out_channels=hidden_channels,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ )
+ self.dubs = nn.ModuleList(
+ [DUB(in_channels=hidden_channels, out_channels=hidden_channels) for _ in range(num_dubs)]
+ )
+ self.recon_block = ReconBlock(in_channels=hidden_channels, num_convs=num_convs_recon)
+ self.recon_agg = nn.Conv2d(in_channels=hidden_channels * num_dubs, out_channels=hidden_channels, kernel_size=1)
+ self.conv = nn.Sequential(
+ *[
+ nn.Conv2d(
+ in_channels=hidden_channels,
+ out_channels=hidden_channels,
+ kernel_size=3,
+ padding=1,
+ ),
+ nn.PReLU(),
+ ]
+ )
+ self.up2 = Subpixel(hidden_channels, hidden_channels, 2, 1)
+ self.conv_out = nn.Conv2d(
+ in_channels=hidden_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ padding=1,
+ )
+ self.num_dubs = num_dubs
+ self.skip_connection = (in_channels == out_channels) and skip_connection
+
+ @staticmethod
+ def crop_to_shape(x, shape):
+ h, w = x.shape[-2:]
+
+ if h > shape[0]:
+ x = x[:, :, : shape[0], :]
+ if w > shape[1]:
+ x = x[:, :, :, : shape[1]]
+ return x
+
+ def forward(self, x, channel_dim=1):
+ """
+
+ Parameters
+ ----------
+ x : torch.Tensor
+ Input tensor.
+ channel_dim : int
+ Channel dimension. Default: 1.
+
+ Returns
+ -------
+ out : torch.Tensor
+ Output tensor.
+ """
+ out = self.conv_in(x)
+ out = self.down(out)
+
+ dub_outs = []
+ for dub in self.dubs:
+ out = dub(out)
+ dub_outs.append(out)
+
+ out = [self.recon_block(dub_out) for dub_out in dub_outs]
+ out = self.recon_agg(torch.cat(out, dim=channel_dim))
+ out = self.conv(out)
+ out = self.up2(out)
+ out = self.conv_out(out)
+ out = self.crop_to_shape(out, x.shape[-2:])
+
+ if self.skip_connection:
+ out = x + out
+ return out
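+
+
+if __name__ == "__main__":
+    # Minimal usage sketch (illustrative channel counts and sizes only, smaller than the defaults).
+    model = DIDN(in_channels=2, out_channels=2, hidden_channels=16, num_dubs=2, num_convs_recon=3)
+    image = torch.rand(1, 2, 32, 32)  # (N, C, H, W)
+    print(model(image).shape)  # torch.Size([1, 2, 32, 32])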
diff --git a/direct/nn/didn/tests/__init__.py b/direct/nn/didn/tests/__init__.py
new file mode 100644
index 00000000..941752c9
--- /dev/null
+++ b/direct/nn/didn/tests/__init__.py
@@ -0,0 +1,2 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
diff --git a/direct/nn/didn/tests/test_didn.py b/direct/nn/didn/tests/test_didn.py
new file mode 100644
index 00000000..80c374d7
--- /dev/null
+++ b/direct/nn/didn/tests/test_didn.py
@@ -0,0 +1,51 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
+
+import pytest
+import torch
+
+from direct.nn.didn.didn import DIDN
+
+
+def create_input(shape):
+
+ data = torch.rand(shape).float()
+
+ return data
+
+
+@pytest.mark.parametrize(
+ "shape",
+ [
+ [3, 2, 32, 32],
+ [3, 2, 16, 16],
+ ],
+)
+@pytest.mark.parametrize(
+ "out_channels",
+ [3, 5],
+)
+@pytest.mark.parametrize(
+ "hidden_channels",
+ [16, 8],
+)
+@pytest.mark.parametrize(
+ "n_dubs",
+ [3, 4],
+)
+@pytest.mark.parametrize(
+ "num_convs_recon",
+ [3, 4],
+)
+@pytest.mark.parametrize(
+ "skip",
+ [True, False],
+)
+def test_didn(shape, out_channels, hidden_channels, n_dubs, num_convs_recon, skip):
+ model = DIDN(shape[1], out_channels, hidden_channels, n_dubs, num_convs_recon, skip)
+
+ data = create_input(shape).cpu()
+
+ out = model(data)
+
+ assert list(out.shape) == [shape[0]] + [out_channels] + shape[2:]
diff --git a/direct/nn/jointicnet/__init__.py b/direct/nn/jointicnet/__init__.py
new file mode 100644
index 00000000..941752c9
--- /dev/null
+++ b/direct/nn/jointicnet/__init__.py
@@ -0,0 +1,2 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
diff --git a/direct/nn/jointicnet/config.py b/direct/nn/jointicnet/config.py
new file mode 100644
index 00000000..ad9184f8
--- /dev/null
+++ b/direct/nn/jointicnet/config.py
@@ -0,0 +1,21 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
+
+from dataclasses import dataclass
+
+from direct.config.defaults import ModelConfig
+
+
+@dataclass
+class JointICNetConfig(ModelConfig):
+ num_iter: int = 10
+ use_norm_unet: bool = False
+ image_unet_num_filters: int = 8
+ image_unet_num_pool_layers: int = 4
+ image_unet_dropout: float = 0.0
+ kspace_unet_num_filters: int = 8
+ kspace_unet_num_pool_layers: int = 4
+ kspace_unet_dropout: float = 0.0
+ sens_unet_num_filters: int = 8
+ sens_unet_num_pool_layers: int = 4
+ sens_unet_dropout: float = 0.0
diff --git a/direct/nn/jointicnet/jointicnet.py b/direct/nn/jointicnet/jointicnet.py
new file mode 100644
index 00000000..cd52f3bf
--- /dev/null
+++ b/direct/nn/jointicnet/jointicnet.py
@@ -0,0 +1,214 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
+
+from typing import Callable
+
+import torch
+import torch.nn as nn
+
+import direct.data.transforms as T
+from direct.nn.unet.unet_2d import UnetModel2d, NormUnetModel2d
+
+
+class JointICNet(nn.Module):
+ """
+ Joint-ICNet implementation as in "Joint Deep Model-based MR Image and Coil Sensitivity Reconstruction Network
+ (Joint-ICNet) for Fast MRI" submitted to the fastmri challenge.
+
+ """
+
+ def __init__(
+ self,
+ forward_operator: Callable,
+ backward_operator: Callable,
+ num_iter: int = 10,
+ use_norm_unet: bool = False,
+ **kwargs,
+ ):
+ """
+
+ Parameters
+ ----------
+ forward_operator : Callable
+ Forward Transform.
+ backward_operator : Callable
+ Backward Transform.
+ num_iter : int
+ Number of unrolled iterations. Default: 10.
+ use_norm_unet : bool
+ If True, a Normalized U-Net is used. Default: False.
+ kwargs: dict
+ Image, k-space and sensitivity-map U-Net models keyword-arguments.
+ """
+ super().__init__()
+
+ self.forward_operator = forward_operator
+ self.backward_operator = backward_operator
+ self.num_iter = num_iter
+
+ unet_architecture = NormUnetModel2d if use_norm_unet else UnetModel2d
+
+ self.image_model = unet_architecture(
+ in_channels=2,
+ out_channels=2,
+ num_filters=kwargs.get("image_unet_num_filters", 8),
+ num_pool_layers=kwargs.get("image_unet_num_pool_layers", 4),
+ dropout_probability=kwargs.get("image_unet_dropout", 0.0),
+ )
+ self.kspace_model = unet_architecture(
+ in_channels=2,
+ out_channels=2,
+ num_filters=kwargs.get("kspace_unet_num_filters", 8),
+ num_pool_layers=kwargs.get("kspace_unet_num_pool_layers", 4),
+ dropout_probability=kwargs.get("kspace_unet_dropout", 0.0),
+ )
+ self.sens_model = unet_architecture(
+ in_channels=2,
+ out_channels=2,
+ num_filters=kwargs.get("sens_unet_num_filters", 8),
+ num_pool_layers=kwargs.get("sens_unet_num_pool_layers", 4),
+ dropout_probability=kwargs.get("sens_unet_dropout", 0.0),
+ )
+ self.conv_out = nn.Conv2d(in_channels=2, out_channels=2, kernel_size=1)
+
+ self.reg_param_I = nn.Parameter(torch.ones(num_iter))
+ self.reg_param_F = nn.Parameter(torch.ones(num_iter))
+ self.reg_param_C = nn.Parameter(torch.ones(num_iter))
+
+ self.lr_image = nn.Parameter(torch.ones(num_iter))
+ self.lr_sens = nn.Parameter(torch.ones(num_iter))
+
+ self._coil_dim = 1
+ self._complex_dim = -1
+ self._spatial_dims = (2, 3)
+
+ def _image_model(self, image):
+ image = image.permute(0, 3, 1, 2)
+ return self.image_model(image).permute(0, 2, 3, 1).contiguous()
+
+ def _kspace_model(self, kspace):
+ kspace = kspace.permute(0, 3, 1, 2)
+ return self.kspace_model(kspace).permute(0, 2, 3, 1).contiguous()
+
+ def _sens_model(self, sensitivity_map):
+ return (
+ self._compute_model_per_coil(self.sens_model, sensitivity_map.permute(0, 1, 4, 2, 3))
+ .permute(0, 1, 3, 4, 2)
+ .contiguous()
+ )
+
+ def _compute_model_per_coil(self, model, data):
+ output = []
+ for idx in range(data.size(self._coil_dim)):
+ subselected_data = data.select(self._coil_dim, idx)
+ output.append(model(subselected_data))
+ output = torch.stack(output, dim=self._coil_dim)
+ return output
+
+ def _forward_operator(self, image, sampling_mask, sensitivity_map):
+ forward = torch.where(
+ sampling_mask == 0,
+ torch.tensor([0.0], dtype=image.dtype).to(image.device),
+ self.forward_operator(T.expand_operator(image, sensitivity_map, self._coil_dim), dim=self._spatial_dims),
+ )
+ return forward
+
+ def _backward_operator(self, kspace, sampling_mask, sensitivity_map):
+ backward = T.reduce_operator(
+ self.backward_operator(
+ torch.where(
+ sampling_mask == 0,
+ torch.tensor([0.0], dtype=kspace.dtype).to(kspace.device),
+ kspace,
+ ),
+ self._spatial_dims,
+ ),
+ sensitivity_map,
+ self._coil_dim,
+ )
+ return backward
+
+ def forward(
+ self,
+ masked_kspace: torch.Tensor,
+ sampling_mask: torch.Tensor,
+ sensitivity_map: torch.Tensor,
+ ) -> torch.Tensor:
+ """
+
+ Parameters
+ ----------
+ masked_kspace : torch.Tensor
+ Masked k-space of shape (N, coil, height, width, complex=2).
+ sampling_mask : torch.Tensor
+ Sampling mask of shape (N, 1, height, width, 1).
+ sensitivity_map : torch.Tensor
+ Sensitivity map of shape (N, coil, height, width, complex=2).
+
+ Returns
+ -------
+ out_image : torch.Tensor
+ Output image of shape (N, height, width, complex=2).
+ """
+
+ input_image = self._backward_operator(masked_kspace, sampling_mask, sensitivity_map)
+ input_image = input_image / T.modulus(input_image).unsqueeze(self._coil_dim).amax(dim=self._spatial_dims).view(
+ -1, 1, 1, 1
+ )
+
+ for iter in range(self.num_iter):
+ step_sensitivity_map = (
+ 2
+ * self.lr_sens[iter]
+ * (
+ T.complex_multiplication(
+ self.backward_operator(
+ torch.where(
+ sampling_mask == 0,
+ torch.tensor([0.0], dtype=masked_kspace.dtype).to(masked_kspace.device),
+ self._forward_operator(input_image, sampling_mask, sensitivity_map) - masked_kspace,
+ ),
+ self._spatial_dims,
+ ),
+ T.conjugate(input_image).unsqueeze(self._coil_dim),
+ )
+ + self.reg_param_C[iter]
+ * (
+ sensitivity_map
+ - self._sens_model(self.backward_operator(masked_kspace, dim=self._spatial_dims))
+ )
+ )
+ )
+ sensitivity_map = sensitivity_map - step_sensitivity_map
+ sensitivity_map_norm = torch.sqrt(((sensitivity_map ** 2).sum(self._complex_dim)).sum(self._coil_dim))
+ sensitivity_map_norm = sensitivity_map_norm.unsqueeze(self._complex_dim).unsqueeze(self._coil_dim)
+ sensitivity_map = T.safe_divide(sensitivity_map, sensitivity_map_norm)
+ input_kspace = self.forward_operator(input_image, dim=tuple([d - 1 for d in self._spatial_dims]))
+
+ step_image = (
+ 2
+ * self.lr_image[iter]
+ * (
+ self._backward_operator(
+ self._forward_operator(input_image, sampling_mask, sensitivity_map) - masked_kspace,
+ sampling_mask,
+ sensitivity_map,
+ )
+ + self.reg_param_I[iter] * (input_image - self._image_model(input_image))
+ + self.reg_param_F[iter]
+ * (
+ input_image
+ - self.backward_operator(
+ self._kspace_model(input_kspace), dim=tuple([d - 1 for d in self._spatial_dims])
+ )
+ )
+ )
+ )
+
+ input_image = input_image - step_image
+ input_image = input_image / T.modulus(input_image).unsqueeze(self._coil_dim).amax(
+ dim=self._spatial_dims
+ ).view(-1, 1, 1, 1)
+
+ out_image = self.conv_out(input_image.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
+ return out_image
diff --git a/direct/nn/jointicnet/jointicnet_engine.py b/direct/nn/jointicnet/jointicnet_engine.py
new file mode 100644
index 00000000..ae939983
--- /dev/null
+++ b/direct/nn/jointicnet/jointicnet_engine.py
@@ -0,0 +1,445 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
+
+import time
+from collections import defaultdict
+from os import PathLike
+from typing import Callable, DefaultDict, Dict, List, Optional
+
+import numpy as np
+import torch
+from torch import nn
+from torch.cuda.amp import autocast
+from torch.nn import functional as F
+from torch.utils.data import DataLoader
+
+import direct.data.transforms as T
+from direct.config import BaseConfig
+from direct.engine import DoIterationOutput, Engine
+from direct.functionals import SSIMLoss
+from direct.utils import (
+ communication,
+ detach_dict,
+ dict_to_device,
+ merge_list_of_dicts,
+ multiply_function,
+ reduce_list_of_dicts,
+)
+from direct.utils.communication import reduce_tensor_dict
+
+
+class JointICNetEngine(Engine):
+ """
+ Joint-ICNet Engine.
+ """
+
+ def __init__(
+ self,
+ cfg: BaseConfig,
+ model: nn.Module,
+ device: int,
+ forward_operator: Optional[Callable] = None,
+ backward_operator: Optional[Callable] = None,
+ mixed_precision: bool = False,
+ **models: nn.Module,
+ ):
+ super().__init__(
+ cfg,
+ model,
+ device,
+ forward_operator=forward_operator,
+ backward_operator=backward_operator,
+ mixed_precision=mixed_precision,
+ **models,
+ )
+
+ self._coil_dim = 1
+ self._complex_dim = -1
+ self._spatial_dims = (2, 3)
+
+ def _do_iteration(
+ self,
+ data: Dict[str, torch.Tensor],
+ loss_fns: Optional[Dict[str, Callable]] = None,
+ regularizer_fns: Optional[Dict[str, Callable]] = None,
+ ) -> DoIterationOutput:
+
+        # loss_fns can be None, e.g. during validation.
+ if loss_fns is None:
+ loss_fns = {}
+
+ if regularizer_fns is None:
+ regularizer_fns = {}
+
+ loss_dicts = []
+ regularizer_dicts = []
+
+ data = dict_to_device(data, self.device)
+
+ # sensitivity_map of shape (batch, coil, height, width, complex=2)
+ sensitivity_map = data["sensitivity_map"]
+
+ # The sensitivity map needs to be normalized such that
+        # \sum_{i \in \text{coils}} S_i S_i^* = 1.
+
+ sensitivity_map_norm = torch.sqrt(
+ ((sensitivity_map ** 2).sum(self._complex_dim)).sum(self._coil_dim)
+ ) # shape (batch, height, width)
+ sensitivity_map_norm = sensitivity_map_norm.unsqueeze(1).unsqueeze(-1)
+ data["sensitivity_map"] = T.safe_divide(sensitivity_map, sensitivity_map_norm)
+
+ with autocast(enabled=self.mixed_precision):
+
+ output_image = self.model(
+ masked_kspace=data["masked_kspace"],
+ sampling_mask=data["sampling_mask"],
+ sensitivity_map=data["sensitivity_map"],
+ ) # shape (batch, height, width, complex=2)
+
+ output_image = T.modulus(output_image) # shape (batch, height, width)
+
+ loss_dict = {k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in loss_fns.keys()}
+ regularizer_dict = {
+ k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in regularizer_fns.keys()
+ }
+
+ for key, value in loss_dict.items():
+ loss_dict[key] = value + loss_fns[key](
+ output_image,
+ **data,
+ reduction="mean",
+ )
+
+ for key, value in regularizer_dict.items():
+ regularizer_dict[key] = value + regularizer_fns[key](
+ output_image,
+ **data,
+ )
+
+ loss = sum(loss_dict.values()) + sum(regularizer_dict.values())
+
+ if self.model.training:
+ self._scaler.scale(loss).backward()
+
+ loss_dicts.append(detach_dict(loss_dict))
+ regularizer_dicts.append(
+ detach_dict(regularizer_dict)
+ ) # Need to detach dict as this is only used for logging.
+
+ # Add the loss dicts.
+ loss_dict = reduce_list_of_dicts(loss_dicts, mode="sum")
+ regularizer_dict = reduce_list_of_dicts(regularizer_dicts, mode="sum")
+
+ return DoIterationOutput(
+ output_image=output_image,
+ sensitivity_map=data["sensitivity_map"],
+ data_dict={**loss_dict, **regularizer_dict},
+ )
+
+ def build_loss(self, **kwargs) -> Dict:
+ # TODO: Cropper is a processing output tool.
+ def get_resolution(**data):
+ """Be careful that this will use the cropping size of the FIRST sample in the batch."""
+ return self.compute_resolution(self.cfg.training.loss.crop, data.get("reconstruction_size", None))
+
+ def l1_loss(source, reduction="mean", **data):
+ """
+ Calculate L1 loss given source and target.
+
+ Parameters:
+ -----------
+ Source: shape (batch, complex=2, height, width)
+ Data: Contains key "target" with value a tensor of shape (batch, height, width)
+
+ """
+ resolution = get_resolution(**data)
+ l1_loss = F.l1_loss(*self.cropper(source, data["target"], resolution), reduction=reduction)
+
+ return l1_loss
+
+ def l2_loss(source, reduction="mean", **data):
+ """
+ Calculate L2 loss (MSE) given source and target.
+
+ Parameters:
+ -----------
+ Source: shape (batch, complex=2, height, width)
+ Data: Contains key "target" with value a tensor of shape (batch, height, width)
+
+ """
+ resolution = get_resolution(**data)
+ l2_loss = F.mse_loss(*self.cropper(source, data["target"], resolution), reduction=reduction)
+
+ return l2_loss
+
+ def ssim_loss(source, reduction="mean", **data):
+ """
+ Calculate SSIM loss given source and target.
+
+ Parameters:
+ -----------
+ Source: shape (batch, complex=2, height, width)
+ Data: Contains key "target" with value a tensor of shape (batch, height, width)
+
+ """
+ resolution = get_resolution(**data)
+ if reduction != "mean":
+ raise AssertionError(
+ f"SSIM loss can only be computed with reduction == 'mean'." f" Got reduction == {reduction}."
+ )
+
+ source_abs, target_abs = self.cropper(source, data["target"], resolution)
+ data_range = torch.tensor([target_abs.max()], device=target_abs.device)
+
+ ssim_loss = SSIMLoss().to(source_abs.device).forward(source_abs, target_abs, data_range=data_range)
+
+ return ssim_loss
+
+ # Build losses
+ loss_dict = {}
+ for curr_loss in self.cfg.training.loss.losses: # type: ignore
+ loss_fn = curr_loss.function
+ if loss_fn == "l1_loss":
+ loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, l1_loss)
+ elif loss_fn == "l2_loss":
+ loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, l2_loss)
+ elif loss_fn == "ssim_loss":
+ loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, ssim_loss)
+ else:
+ raise ValueError(f"{loss_fn} not permissible.")
+
+ return loss_dict
+
+ @torch.no_grad()
+ def evaluate(
+ self,
+ data_loader: DataLoader,
+ loss_fns: Optional[Dict[str, Callable]],
+ regularizer_fns: Optional[Dict[str, Callable]] = None,
+ crop: Optional[str] = None,
+ is_validation_process: bool = True,
+ ):
+ """
+ Validation process. Assumes that each batch only contains slices of the same volume *AND* that these
+ are sequentially ordered.
+
+ Parameters
+ ----------
+ data_loader : DataLoader
+ loss_fns : Dict[str, Callable], optional
+ regularizer_fns : Dict[str, Callable], optional
+ crop : str, optional
+ is_validation_process : bool
+
+ Returns
+ -------
+ loss_dict, all_gathered_metrics, visualize_slices, visualize_target
+
+ """
+ self.models_to_device()
+ self.models_validation_mode()
+ torch.cuda.empty_cache()
+
+ # Variables required for evaluation.
+ volume_metrics = self.build_metrics(self.cfg.validation.metrics) # type: ignore
+
+ # filenames can be in the volume_indices attribute of the dataset
+ num_for_this_process = None
+ all_filenames = None
+ if hasattr(data_loader.dataset, "volume_indices"):
+ all_filenames = list(data_loader.dataset.volume_indices.keys())
+ num_for_this_process = len(list(data_loader.batch_sampler.sampler.volume_indices.keys()))
+ self.logger.info(
+ f"Reconstructing a total of {len(all_filenames)} volumes. "
+ f"This process has {num_for_this_process} volumes (world size: {communication.get_world_size()})."
+ )
+
+ filenames_seen = 0
+ reconstruction_output: DefaultDict = defaultdict(list)
+ if is_validation_process:
+ targets_output: DefaultDict = defaultdict(list)
+ val_losses = []
+ val_volume_metrics: Dict[PathLike, Dict] = defaultdict(dict)
+ last_filename = None
+
+        # Container for the slices which can be visualized in TensorBoard.
+ visualize_slices: List[np.ndarray] = []
+ visualize_target: List[np.ndarray] = []
+ # visualizations = {}
+
+ extra_visualization_keys = (
+ self.cfg.logging.log_as_image if self.cfg.logging.log_as_image else [] # type: ignore
+ )
+
+ # Loop over dataset. This requires the use of direct.data.sampler.DistributedSequentialSampler as this sampler
+ # splits the data over the different processes, and outputs the slices linearly. The implicit assumption here is
+ # that the slices are outputted from the Dataset *sequentially* for each volume one by one, and each batch only
+ # contains data from one volume.
+ time_start = time.time()
+
+ for iter_idx, data in enumerate(data_loader):
+ filenames = data.pop("filename")
+ if len(set(filenames)) != 1:
+ raise ValueError(
+ f"Expected a batch during validation to only contain filenames of one case. "
+ f"Got {set(filenames)}."
+ )
+
+ slice_nos = data.pop("slice_no")
+ scaling_factors = data["scaling_factor"]
+
+ resolution = self.compute_resolution(
+ key=self.cfg.validation.crop, # type: ignore
+ reconstruction_size=data.get("reconstruction_size", None),
+ )
+
+ # Compute output and loss.
+ iteration_output = self._do_iteration(data, loss_fns, regularizer_fns=regularizer_fns)
+ output = iteration_output.output_image
+ loss_dict = iteration_output.data_dict
+
+ loss_dict = detach_dict(loss_dict)
+ output = output.detach()
+ val_losses.append(loss_dict)
+
+ # Output is complex-valued, and has to be cropped. This holds for both output and target.
+ # Output has shape (batch, complex, height, width)
+ output_abs = self.process_output(
+ output,
+ scaling_factors,
+ resolution=resolution,
+ )
+
+ if is_validation_process:
+ # Target has shape (batch, height, width)
+ target_abs = self.process_output(
+ data["target"].detach(),
+ scaling_factors,
+ resolution=resolution,
+ )
+ for key in extra_visualization_keys:
+ curr_data = data[key].detach()
+ # Here we need to discover which keys are actually normalized or not
+ # this requires a solution to issue #23: https://github.com/directgroup/direct/issues/23
+
+ del output # Explicitly call delete to clear memory.
+
+ # Aggregate volumes to be able to compute the metrics on complete volumes.
+ for idx, filename in enumerate(filenames):
+ if last_filename is None:
+ last_filename = filename # First iteration last_filename is not set.
+
+ curr_slice = output_abs[idx].detach()
+ slice_no = int(slice_nos[idx].numpy())
+
+ reconstruction_output[filename].append((slice_no, curr_slice.cpu()))
+
+ if is_validation_process:
+ targets_output[filename].append((slice_no, target_abs[idx].cpu()))
+
+ is_last_element_of_last_batch = iter_idx + 1 == len(data_loader) and idx + 1 == len(data["target"])
+ reconstruction_conditions = [filename != last_filename, is_last_element_of_last_batch]
+ for condition in reconstruction_conditions:
+ if condition:
+ filenames_seen += 1
+
+                    # Now we can ditch the reconstruction dict by reconstructing the volume,
+                    # as it would take too much memory otherwise.
+ volume = torch.stack([_[1] for _ in reconstruction_output[last_filename]])
+ if is_validation_process:
+ target = torch.stack([_[1] for _ in targets_output[last_filename]])
+ curr_metrics = {
+ metric_name: metric_fn(target, volume)
+ for metric_name, metric_fn in volume_metrics.items()
+ }
+ val_volume_metrics[last_filename] = curr_metrics
+ # Log the center slice of the volume
+ if len(visualize_slices) < self.cfg.logging.tensorboard.num_images: # type: ignore
+ visualize_slices.append(volume[volume.shape[0] // 2])
+ visualize_target.append(target[target.shape[0] // 2])
+
+ # Delete outputs from memory, and recreate dictionary.
+ # This is not needed when not in validation as we are actually interested
+ # in the iteration output.
+ del targets_output[last_filename]
+ del reconstruction_output[last_filename]
+
+ if all_filenames:
+ log_prefix = f"{filenames_seen} of {num_for_this_process} volumes reconstructed:"
+ else:
+ log_prefix = f"{iter_idx + 1} of {len(data_loader)} slices reconstructed:"
+
+ self.logger.info(
+ f"{log_prefix} {last_filename}"
+ f" (shape = {list(volume.shape)}) in {time.time() - time_start:.3f}s."
+ )
+ # restart timer
+ time_start = time.time()
+ last_filename = filename
+
+ # Average loss dict
+ loss_dict = reduce_list_of_dicts(val_losses)
+ reduce_tensor_dict(loss_dict)
+
+ communication.synchronize()
+ torch.cuda.empty_cache()
+
+ all_gathered_metrics = merge_list_of_dicts(communication.all_gather(val_volume_metrics))
+ if not is_validation_process:
+ return loss_dict, reconstruction_output
+
+ return loss_dict, all_gathered_metrics, visualize_slices, visualize_target
+
+ def process_output(self, data, scaling_factors=None, resolution=None):
+ # data is of shape (batch, complex=2, height, width)
+ if scaling_factors is not None:
+ data = data * scaling_factors.view(-1, *((1,) * (len(data.shape) - 1))).to(data.device)
+
+ data = T.modulus_if_complex(data)
+
+ if len(data.shape) == 3: # (batch, height, width)
+ data = data.unsqueeze(1) # Added channel dimension.
+
+ if resolution is not None:
+ data = T.center_crop(data, resolution).contiguous()
+
+ return data
+
+ @staticmethod
+ def compute_resolution(key, reconstruction_size):
+ if key == "header":
+ # This will be of the form [tensor(x_0, x_1, ...), tensor(y_0, y_1,...), tensor(z_0, z_1, ...)] over
+ # batches.
+ resolution = [_.detach().cpu().numpy().tolist() for _ in reconstruction_size]
+ # The volume sampler should give validation indices belonging to the *same* volume, so it should be
+                # safe taking the first element, the matrix size is in x,y,z (we work in z,x,y).
+ resolution = [_[0] for _ in resolution][:-1]
+ elif key == "training":
+ resolution = key
+ elif not key:
+ resolution = None
+ else:
+ raise ValueError(
+ "Cropping should be either set to `header` to get the values from the header or "
+ "`training` to take the same value as training."
+ )
+ return resolution
+
+ def cropper(self, source, target, resolution):
+ """
+ 2D source/target cropper
+
+ Parameters:
+ -----------
+ Source has shape (batch, height, width)
+ Target has shape (batch, height, width)
+
+ """
+
+ if not resolution or all(_ == 0 for _ in resolution):
+ return source.unsqueeze(1), target.unsqueeze(1) # Added channel dimension.
+
+ source_abs = T.center_crop(source, resolution).unsqueeze(1) # Added channel dimension.
+ target_abs = T.center_crop(target, resolution).unsqueeze(1) # Added channel dimension.
+
+ return source_abs, target_abs
diff --git a/direct/nn/jointicnet/tests/__init__.py b/direct/nn/jointicnet/tests/__init__.py
new file mode 100644
index 00000000..941752c9
--- /dev/null
+++ b/direct/nn/jointicnet/tests/__init__.py
@@ -0,0 +1,2 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
diff --git a/direct/nn/jointicnet/tests/test_jointicnet.py b/direct/nn/jointicnet/tests/test_jointicnet.py
new file mode 100644
index 00000000..f68fb3cf
--- /dev/null
+++ b/direct/nn/jointicnet/tests/test_jointicnet.py
@@ -0,0 +1,50 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
+
+import pytest
+import torch
+
+from direct.data.transforms import fft2, ifft2
+from direct.nn.jointicnet.jointicnet import JointICNet
+
+
+def create_input(shape):
+
+ data = torch.rand(shape).float()
+
+ return data
+
+
+@pytest.mark.parametrize(
+ "shape",
+ [
+ [3, 3, 16, 16],
+ [2, 5, 16, 32],
+ ],
+)
+@pytest.mark.parametrize(
+ "num_iter",
+ [2, 4],
+)
+@pytest.mark.parametrize(
+ "use_norm_unet",
+ [True, False],
+)
+def test_jointicnet(shape, num_iter, use_norm_unet):
+ model = JointICNet(
+ fft2,
+ ifft2,
+ num_iter,
+ use_norm_unet,
+ image_unet_num_pool_layers=2,
+ kspace_unet_num_pool_layers=2,
+ sens_unet_num_pool_layers=2,
+ ).cpu()
+
+ kspace = create_input(shape + [2]).cpu()
+ mask = create_input([shape[0]] + [1] + shape[2:] + [1]).round().int().cpu()
+ sens = create_input(shape + [2]).cpu()
+
+ out = model(kspace, mask, sens)
+
+ assert list(out.shape) == [shape[0]] + shape[2:] + [2]
diff --git a/direct/nn/kikinet/__init__.py b/direct/nn/kikinet/__init__.py
new file mode 100644
index 00000000..941752c9
--- /dev/null
+++ b/direct/nn/kikinet/__init__.py
@@ -0,0 +1,2 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
diff --git a/direct/nn/kikinet/config.py b/direct/nn/kikinet/config.py
new file mode 100644
index 00000000..4f54fb67
--- /dev/null
+++ b/direct/nn/kikinet/config.py
@@ -0,0 +1,30 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
+
+from dataclasses import dataclass
+
+from direct.config.defaults import ModelConfig
+
+
+@dataclass
+class KIKINetConfig(ModelConfig):
+ num_iter: int = 10
+ image_model_architecture: str = "MWCNN"
+ kspace_model_architecture: str = "UNET"
+ image_mwcnn_hidden_channels: int = 16
+ image_mwcnn_num_scales: int = 4
+ image_mwcnn_bias: bool = True
+ image_mwcnn_batchnorm: bool = False
+ image_unet_num_filters: int = 8
+ image_unet_num_pool_layers: int = 4
+ image_unet_dropout_probability: float = 0.0
+ kspace_conv_hidden_channels: int = 16
+ kspace_conv_n_convs: int = 4
+ kspace_conv_batchnorm: bool = False
+ kspace_didn_hidden_channels: int = 64
+ kspace_didn_num_dubs: int = 6
+ kspace_didn_num_convs_recon: int = 9
+ kspace_unet_num_filters: int = 8
+ kspace_unet_num_pool_layers: int = 4
+ kspace_unet_dropout_probability: float = 0.0
+ normalize: bool = False
diff --git a/direct/nn/kikinet/kikinet.py b/direct/nn/kikinet/kikinet.py
new file mode 100644
index 00000000..712d74ce
--- /dev/null
+++ b/direct/nn/kikinet/kikinet.py
@@ -0,0 +1,181 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
+
+from typing import Callable, Optional
+
+import torch
+import torch.nn as nn
+
+import direct.data.transforms as T
+from direct.nn.conv.conv import Conv2d
+from direct.nn.didn.didn import DIDN
+from direct.nn.mwcnn.mwcnn import MWCNN
+from direct.nn.crossdomain.multicoil import MultiCoil
+from direct.nn.unet.unet_2d import UnetModel2d, NormUnetModel2d
+
+
+class KIKINet(nn.Module):
+ """
+    Based on the KIKI-net implementation from "KIKI-net: cross-domain convolutional neural networks for
+    reconstructing undersampled magnetic resonance images" by Taejoon Eo et al. Modified to work with
+    multicoil k-space data.
+ """
+
+ def __init__(
+ self,
+ forward_operator: Callable,
+ backward_operator: Callable,
+ image_model_architecture: str = "MWCNN",
+ kspace_model_architecture: str = "DIDN",
+ num_iter: int = 2,
+ normalize: bool = False,
+ **kwargs,
+ ):
+ """
+
+ Parameters
+ ----------
+ forward_operator : Callable
+ Forward Operator.
+ backward_operator : Callable
+ Backward Operator.
+ image_model_architecture : str
+ Image model architecture. Currently only implemented for MWCNN and (NORM)UNET. Default: 'MWCNN'.
+ kspace_model_architecture : str
+ Kspace model architecture. Currently only implemented for CONV and DIDN and (NORM)UNET. Default: 'DIDN'.
+ num_iter : int
+ Number of unrolled iterations.
+ normalize : bool
+            If True, the input is normalized based on the input scaling_factor. Default: False.
+ kwargs : dict
+ Keyword arguments for model architectures.
+ """
+ super().__init__()
+
+ if image_model_architecture == "MWCNN":
+ image_model = MWCNN(
+ input_channels=2,
+ first_conv_hidden_channels=kwargs.get("image_mwcnn_hidden_channels", 32),
+ num_scales=kwargs.get("image_mwcnn_num_scales", 4),
+ bias=kwargs.get("image_mwcnn_bias", False),
+ batchnorm=kwargs.get("image_mwcnn_batchnorm", False),
+ )
+ elif image_model_architecture in ["UNET", "NORMUNET"]:
+ unet = UnetModel2d if image_model_architecture == "UNET" else NormUnetModel2d
+ image_model = unet(
+ in_channels=2,
+ out_channels=2,
+ num_filters=kwargs.get("image_unet_num_filters", 8),
+ num_pool_layers=kwargs.get("image_unet_num_pool_layers", 4),
+ dropout_probability=kwargs.get("image_unet_dropout_probability", 0.0),
+ )
+ else:
+ raise NotImplementedError(
+ f"XPDNet is currently implemented only with image_model_architecture == 'MWCNN', 'UNET' or 'NORMUNET."
+ f"Got {image_model_architecture}."
+ )
+
+ if kspace_model_architecture == "CONV":
+ kspace_model = Conv2d(
+ in_channels=2,
+ out_channels=2,
+ hidden_channels=kwargs.get("kspace_conv_hidden_channels", 16),
+ n_convs=kwargs.get("kspace_conv_n_convs", 4),
+ batchnorm=kwargs.get("kspace_conv_batchnorm", False),
+ )
+ elif kspace_model_architecture == "DIDN":
+ kspace_model = DIDN(
+ in_channels=2,
+ out_channels=2,
+ hidden_channels=kwargs.get("kspace_didn_hidden_channels", 16),
+ num_dubs=kwargs.get("kspace_didn_num_dubs", 6),
+ num_convs_recon=kwargs.get("kspace_didn_num_convs_recon", 9),
+ )
+ elif kspace_model_architecture in ["UNET", "NORMUNET"]:
+ unet = UnetModel2d if kspace_model_architecture == "UNET" else NormUnetModel2d
+ kspace_model = unet(
+ in_channels=2,
+ out_channels=2,
+ num_filters=kwargs.get("kspace_unet_num_filters", 8),
+ num_pool_layers=kwargs.get("kspace_unet_num_pool_layers", 4),
+ dropout_probability=kwargs.get("kspace_unet_dropout_probability", 0.0),
+ )
+ else:
+ raise NotImplementedError(
+ f"XPDNet is currently implemented for kspace_model_architecture == 'CONV', 'DIDN',"
+ f" 'UNET' or 'NORMUNET'. Got kspace_model_architecture == {kspace_model_architecture}."
+ )
+
+ self._coil_dim = 1
+ self._complex_dim = -1
+ self._spatial_dims = (2, 3)
+
+ self.image_model_list = nn.ModuleList([image_model] * num_iter)
+ self.kspace_model_list = nn.ModuleList([MultiCoil(kspace_model, self._coil_dim)] * num_iter)
+
+ self.forward_operator = forward_operator
+ self.backward_operator = backward_operator
+ self.num_iter = num_iter
+ self.normalize = normalize
+
+ def forward(
+ self,
+ masked_kspace: torch.Tensor,
+ sampling_mask: torch.Tensor,
+ sensitivity_map: torch.Tensor,
+ scaling_factor: Optional[torch.Tensor] = None,
+ ):
+ """
+
+ Parameters
+ ----------
+ masked_kspace : torch.Tensor
+ Masked k-space of shape (N, coil, height, width, complex=2).
+ sampling_mask : torch.Tensor
+ Sampling mask of shape (N, 1, height, width, 1).
+ sensitivity_map : torch.Tensor
+ Sensitivity map of shape (N, coil, height, width, complex=2).
+ scaling_factor : Optional[torch.Tensor]
+ Scaling factor of shape (N,). If None, no scaling is applied. Default: None.
+
+ Returns
+ -------
+ out_image : torch.Tensor
+ Output image of shape (N, height, width, complex=2).
+ """
+
+ kspace = masked_kspace.clone()
+ if self.normalize and scaling_factor is not None:
+ kspace = kspace / (scaling_factor ** 2).view(-1, 1, 1, 1, 1)
+
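+        # Each unrolled iteration alternates domains: refine k-space with the k-space model,
+        # backproject (masked backward operator + coil combination) to the image domain, refine
+        # the image with the image model, and re-project to k-space for the next iteration
+        # (the re-projection is skipped after the final iteration).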
+ for idx in range(self.num_iter):
+ kspace = self.kspace_model_list[idx](kspace.permute(0, 1, 4, 2, 3)).permute(0, 1, 3, 4, 2)
+
+ image = T.reduce_operator(
+ self.backward_operator(
+ torch.where(
+ sampling_mask == 0,
+ torch.tensor([0.0], dtype=kspace.dtype).to(kspace.device),
+ kspace,
+ ),
+ self._spatial_dims,
+ ),
+ sensitivity_map,
+ self._coil_dim,
+ )
+
+ image = self.image_model_list[idx](image.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
+
+ if idx < self.num_iter - 1:
+ kspace = torch.where(
+ sampling_mask == 0,
+ torch.tensor([0.0], dtype=image.dtype).to(image.device),
+ self.forward_operator(
+ T.expand_operator(image, sensitivity_map, self._coil_dim), dim=self._spatial_dims
+ ),
+ )
+
+ if self.normalize and scaling_factor is not None:
+ image = image * (scaling_factor ** 2).view(-1, 1, 1, 1)
+
+ return image
diff --git a/direct/nn/kikinet/kikinet_engine.py b/direct/nn/kikinet/kikinet_engine.py
new file mode 100644
index 00000000..976524f4
--- /dev/null
+++ b/direct/nn/kikinet/kikinet_engine.py
@@ -0,0 +1,470 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
+
+import time
+from collections import defaultdict
+from os import PathLike
+from typing import Callable, DefaultDict, Dict, List, Optional
+
+import numpy as np
+import torch
+from torch import nn
+from torch.cuda.amp import autocast
+from torch.nn import functional as F
+from torch.utils.data import DataLoader
+
+import direct.data.transforms as T
+from direct.config import BaseConfig
+from direct.engine import DoIterationOutput, Engine
+from direct.functionals import SSIMLoss
+from direct.utils import (
+ communication,
+ detach_dict,
+ dict_to_device,
+ merge_list_of_dicts,
+ multiply_function,
+ reduce_list_of_dicts,
+)
+from direct.utils.communication import reduce_tensor_dict
+
+
+class KIKINetEngine(Engine):
+ """
+    KIKINet Engine.
+ """
+
+ def __init__(
+ self,
+ cfg: BaseConfig,
+ model: nn.Module,
+ device: int,
+ forward_operator: Optional[Callable] = None,
+ backward_operator: Optional[Callable] = None,
+ mixed_precision: bool = False,
+ **models: nn.Module,
+ ):
+ super().__init__(
+ cfg,
+ model,
+ device,
+ forward_operator=forward_operator,
+ backward_operator=backward_operator,
+ mixed_precision=mixed_precision,
+ **models,
+ )
+
+ self._complex_dim = -1
+ self._coil_dim = 1
+ self._spatial_dims = (2, 3)
+
+ def _do_iteration(
+ self,
+ data: Dict[str, torch.Tensor],
+ loss_fns: Optional[Dict[str, Callable]] = None,
+ regularizer_fns: Optional[Dict[str, Callable]] = None,
+ ) -> DoIterationOutput:
+
+        # loss_fns can be None, e.g. during validation.
+ if loss_fns is None:
+ loss_fns = {}
+
+ if regularizer_fns is None:
+ regularizer_fns = {}
+
+ loss_dicts = []
+ regularizer_dicts = []
+
+ data = dict_to_device(data, self.device)
+
+ # sensitivity_map of shape (batch, coil, height, width, complex=2)
+ sensitivity_map = data["sensitivity_map"]
+
+ if "sensitivity_model" in self.models:
+
+ # Move channels to first axis
+ sensitivity_map = data["sensitivity_map"].permute((0, 1, 4, 2, 3))
+
+ sensitivity_map = self.compute_model_per_coil("sensitivity_model", sensitivity_map).permute(
+ (0, 1, 3, 4, 2)
+ )
+
+            # The sensitivity map needs to be normalized such that
+            # \sum_{i \in \text{coils}} S_i S_i^* = 1.
+
+ sensitivity_map_norm = torch.sqrt(
+ ((sensitivity_map ** 2).sum(self._complex_dim)).sum(self._coil_dim)
+ ) # shape (batch, height, width)
+ sensitivity_map_norm = sensitivity_map_norm.unsqueeze(1).unsqueeze(-1)
+ data["sensitivity_map"] = T.safe_divide(sensitivity_map, sensitivity_map_norm)
+
+ with autocast(enabled=self.mixed_precision):
+
+ output_image = self.model(
+ masked_kspace=data["masked_kspace"],
+ sampling_mask=data["sampling_mask"],
+ sensitivity_map=data["sensitivity_map"],
+ scaling_factor=data["scaling_factor"],
+ ) # shape (batch, height, width, complex=2)
+
+ output_image = T.modulus(output_image) # shape (batch, height, width)
+
+ loss_dict = {k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in loss_fns.keys()}
+ regularizer_dict = {
+ k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in regularizer_fns.keys()
+ }
+
+ for key, value in loss_dict.items():
+ loss_dict[key] = value + loss_fns[key](
+ output_image,
+ **data,
+ reduction="mean",
+ )
+
+ for key, value in regularizer_dict.items():
+ regularizer_dict[key] = value + regularizer_fns[key](
+ output_image,
+ **data,
+ )
+
+ loss = sum(loss_dict.values()) + sum(regularizer_dict.values())
+
+ if self.model.training:
+ self._scaler.scale(loss).backward()
+
+ loss_dicts.append(detach_dict(loss_dict))
+ regularizer_dicts.append(
+ detach_dict(regularizer_dict)
+ ) # Need to detach dict as this is only used for logging.
+
+ # Add the loss dicts.
+ loss_dict = reduce_list_of_dicts(loss_dicts, mode="sum")
+ regularizer_dict = reduce_list_of_dicts(regularizer_dicts, mode="sum")
+
+ return DoIterationOutput(
+ output_image=output_image,
+ sensitivity_map=data["sensitivity_map"],
+ data_dict={**loss_dict, **regularizer_dict},
+ )
+
+ def build_loss(self, **kwargs) -> Dict:
+ # TODO: Cropper is a processing output tool.
+ def get_resolution(**data):
+ """Be careful that this will use the cropping size of the FIRST sample in the batch."""
+ return self.compute_resolution(self.cfg.training.loss.crop, data.get("reconstruction_size", None))
+
+ def l1_loss(source, reduction="mean", **data):
+ """
+ Calculate L1 loss given source and target.
+
+ Parameters:
+ -----------
+ Source: shape (batch, complex=2, height, width)
+ Data: Contains key "target" with value a tensor of shape (batch, height, width)
+
+ """
+ resolution = get_resolution(**data)
+ l1_loss = F.l1_loss(*self.cropper(source, data["target"], resolution), reduction=reduction)
+
+ return l1_loss
+
+ def l2_loss(source, reduction="mean", **data):
+ """
+ Calculate L2 loss (MSE) given source and target.
+
+ Parameters:
+ -----------
+ Source: shape (batch, complex=2, height, width)
+ Data: Contains key "target" with value a tensor of shape (batch, height, width)
+
+ """
+ resolution = get_resolution(**data)
+ l2_loss = F.mse_loss(*self.cropper(source, data["target"], resolution), reduction=reduction)
+
+ return l2_loss
+
+ def ssim_loss(source, reduction="mean", **data):
+ """
+ Calculate SSIM loss given source and target.
+
+ Parameters:
+ -----------
+ Source: shape (batch, complex=2, height, width)
+ Data: Contains key "target" with value a tensor of shape (batch, height, width)
+
+ """
+ resolution = get_resolution(**data)
+ if reduction != "mean":
+ raise AssertionError(
+ f"SSIM loss can only be computed with reduction == 'mean'." f" Got reduction == {reduction}."
+ )
+
+ source_abs, target_abs = self.cropper(source, data["target"], resolution)
+ data_range = torch.tensor([target_abs.max()], device=target_abs.device)
+
+ ssim_loss = SSIMLoss().to(source_abs.device).forward(source_abs, target_abs, data_range=data_range)
+
+ return ssim_loss
+
+ # Build losses
+ loss_dict = {}
+ for curr_loss in self.cfg.training.loss.losses: # type: ignore
+ loss_fn = curr_loss.function
+ if loss_fn == "l1_loss":
+ loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, l1_loss)
+ elif loss_fn == "l2_loss":
+ loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, l2_loss)
+ elif loss_fn == "ssim_loss":
+ loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, ssim_loss)
+ else:
+ raise ValueError(f"{loss_fn} not permissible.")
+
+ return loss_dict
+
+ @torch.no_grad()
+ def evaluate(
+ self,
+ data_loader: DataLoader,
+ loss_fns: Optional[Dict[str, Callable]],
+ regularizer_fns: Optional[Dict[str, Callable]] = None,
+ crop: Optional[str] = None,
+ is_validation_process: bool = True,
+ ):
+ """
+ Validation process. Assumes that each batch only contains slices of the same volume *AND* that these
+ are sequentially ordered.
+
+ Parameters
+ ----------
+ data_loader : DataLoader
+ loss_fns : Dict[str, Callable], optional
+ regularizer_fns : Dict[str, Callable], optional
+ crop : str, optional
+ is_validation_process : bool
+
+ Returns
+ -------
+ loss_dict, all_gathered_metrics, visualize_slices, visualize_target
+
+ """
+ self.models_to_device()
+ self.models_validation_mode()
+ torch.cuda.empty_cache()
+
+ # Variables required for evaluation.
+ volume_metrics = self.build_metrics(self.cfg.validation.metrics) # type: ignore
+
+ # filenames can be in the volume_indices attribute of the dataset
+ num_for_this_process = None
+ all_filenames = None
+ if hasattr(data_loader.dataset, "volume_indices"):
+ all_filenames = list(data_loader.dataset.volume_indices.keys())
+ num_for_this_process = len(list(data_loader.batch_sampler.sampler.volume_indices.keys()))
+ self.logger.info(
+ f"Reconstructing a total of {len(all_filenames)} volumes. "
+ f"This process has {num_for_this_process} volumes (world size: {communication.get_world_size()})."
+ )
+
+ filenames_seen = 0
+ reconstruction_output: DefaultDict = defaultdict(list)
+ if is_validation_process:
+ targets_output: DefaultDict = defaultdict(list)
+ val_losses = []
+ val_volume_metrics: Dict[PathLike, Dict] = defaultdict(dict)
+ last_filename = None
+
+        # Container for the slices which can be visualized in TensorBoard.
+ visualize_slices: List[np.ndarray] = []
+ visualize_target: List[np.ndarray] = []
+
+ extra_visualization_keys = (
+ self.cfg.logging.log_as_image if self.cfg.logging.log_as_image else [] # type: ignore
+ )
+
+ # Loop over dataset. This requires the use of direct.data.sampler.DistributedSequentialSampler as this sampler
+ # splits the data over the different processes, and outputs the slices linearly. The implicit assumption here is
+ # that the slices are outputted from the Dataset *sequentially* for each volume one by one, and each batch only
+ # contains data from one volume.
+ time_start = time.time()
+
+ for iter_idx, data in enumerate(data_loader):
+ filenames = data.pop("filename")
+ if len(set(filenames)) != 1:
+ raise ValueError(
+ f"Expected a batch during validation to only contain filenames of one case. "
+ f"Got {set(filenames)}."
+ )
+
+ slice_nos = data.pop("slice_no")
+ scaling_factors = data["scaling_factor"]
+
+ resolution = self.compute_resolution(
+ key=self.cfg.validation.crop, # type: ignore
+ reconstruction_size=data.get("reconstruction_size", None),
+ )
+
+ # Compute output and loss.
+ iteration_output = self._do_iteration(data, loss_fns, regularizer_fns=regularizer_fns)
+ output = iteration_output.output_image
+ loss_dict = iteration_output.data_dict
+
+ loss_dict = detach_dict(loss_dict)
+ output = output.detach()
+ val_losses.append(loss_dict)
+
+            # Output is real-valued (the modulus is taken in _do_iteration) and still has to be cropped.
+            # This holds for both output and target. Output has shape (batch, height, width).
+ output_abs = self.process_output(
+ output,
+ scaling_factors,
+ resolution=resolution,
+ )
+
+ if is_validation_process:
+ # Target has shape (batch, height, width)
+ target_abs = self.process_output(
+ data["target"].detach(),
+ scaling_factors,
+ resolution=resolution,
+ )
+ for key in extra_visualization_keys:
+ curr_data = data[key].detach()
+ # Here we need to discover which keys are actually normalized or not
+ # this requires a solution to issue #23: https://github.com/directgroup/direct/issues/23
+
+ del output # Explicitly call delete to clear memory.
+
+ # Aggregate volumes to be able to compute the metrics on complete volumes.
+ for idx, filename in enumerate(filenames):
+ if last_filename is None:
+ last_filename = filename # First iteration last_filename is not set.
+
+ curr_slice = output_abs[idx].detach()
+ slice_no = int(slice_nos[idx].numpy())
+
+ reconstruction_output[filename].append((slice_no, curr_slice.cpu()))
+
+ if is_validation_process:
+ targets_output[filename].append((slice_no, target_abs[idx].cpu()))
+
+ is_last_element_of_last_batch = iter_idx + 1 == len(data_loader) and idx + 1 == len(data["target"])
+ reconstruction_conditions = [filename != last_filename, is_last_element_of_last_batch]
+ for condition in reconstruction_conditions:
+ if condition:
+ filenames_seen += 1
+
+ # Now we can ditch the reconstruction dict by reconstructing the volume,
+ # will take too much memory otherwise.
+ volume = torch.stack([_[1] for _ in reconstruction_output[last_filename]])
+ if is_validation_process:
+ target = torch.stack([_[1] for _ in targets_output[last_filename]])
+ curr_metrics = {
+ metric_name: metric_fn(target, volume)
+ for metric_name, metric_fn in volume_metrics.items()
+ }
+ val_volume_metrics[last_filename] = curr_metrics
+ # Log the center slice of the volume
+ if len(visualize_slices) < self.cfg.logging.tensorboard.num_images: # type: ignore
+ visualize_slices.append(volume[volume.shape[0] // 2])
+ visualize_target.append(target[target.shape[0] // 2])
+
+ # Delete outputs from memory, and recreate dictionary.
+ # This is not needed when not in validation as we are actually interested
+ # in the iteration output.
+ del targets_output[last_filename]
+ del reconstruction_output[last_filename]
+
+ if all_filenames:
+ log_prefix = f"{filenames_seen} of {num_for_this_process} volumes reconstructed:"
+ else:
+ log_prefix = f"{iter_idx + 1} of {len(data_loader)} slices reconstructed:"
+
+ self.logger.info(
+ f"{log_prefix} {last_filename}"
+ f" (shape = {list(volume.shape)}) in {time.time() - time_start:.3f}s."
+ )
+ # restart timer
+ time_start = time.time()
+ last_filename = filename
+
+ # Average loss dict
+ loss_dict = reduce_list_of_dicts(val_losses)
+ reduce_tensor_dict(loss_dict)
+
+ communication.synchronize()
+ torch.cuda.empty_cache()
+
+ all_gathered_metrics = merge_list_of_dicts(communication.all_gather(val_volume_metrics))
+ if not is_validation_process:
+ return loss_dict, reconstruction_output
+
+ return loss_dict, all_gathered_metrics, visualize_slices, visualize_target
+
+ def process_output(self, data, scaling_factors=None, resolution=None):
+ # data is of shape (batch, complex=2, height, width)
+ if scaling_factors is not None:
+ data = data * scaling_factors.view(-1, *((1,) * (len(data.shape) - 1))).to(data.device)
+
+ data = T.modulus_if_complex(data)
+
+ if len(data.shape) == 3: # (batch, height, width)
+ data = data.unsqueeze(1) # Added channel dimension.
+
+ if resolution is not None:
+ data = T.center_crop(data, resolution).contiguous()
+
+ return data
+
+ @staticmethod
+ def compute_resolution(key, reconstruction_size):
+ if key == "header":
+ # This will be of the form [tensor(x_0, x_1, ...), tensor(y_0, y_1,...), tensor(z_0, z_1, ...)] over
+ # batches.
+ resolution = [_.detach().cpu().numpy().tolist() for _ in reconstruction_size]
+ # The volume sampler should give validation indices belonging to the *same* volume, so it should be
+ # safe taking the first element, the matrix size are in x,y,z (we work in z,x,y).
+ resolution = [_[0] for _ in resolution][:-1]
+ elif key == "training":
+ resolution = key
+ elif not key:
+ resolution = None
+ else:
+ raise ValueError(
+ "Cropping should be either set to `header` to get the values from the header or "
+ "`training` to take the same value as training."
+ )
+ return resolution
+
+ def cropper(self, source, target, resolution):
+ """
+        2D source/target cropper.
+
+        Parameters
+        ----------
+        source : torch.Tensor
+            Has shape (batch, height, width).
+        target : torch.Tensor
+            Has shape (batch, height, width).
+        resolution : list of int, optional
+            Resolution to center-crop to. If empty or all zeros, no cropping is applied.
+
+ """
+
+ if not resolution or all(_ == 0 for _ in resolution):
+ return source.unsqueeze(1), target.unsqueeze(1) # Added channel dimension.
+
+ source_abs = T.center_crop(source, resolution).unsqueeze(1) # Added channel dimension.
+ target_abs = T.center_crop(target, resolution).unsqueeze(1) # Added channel dimension.
+
+ return source_abs, target_abs
+
+ def compute_model_per_coil(self, model_name, data):
+ """
+ Computes model per coil.
+ """
+ # data is of shape (batch, coil, complex=2, height, width)
+ output = []
+
+ for idx in range(data.size(self._coil_dim)):
+ subselected_data = data.select(self._coil_dim, idx)
+ output.append(self.models[model_name](subselected_data))
+ output = torch.stack(output, dim=self._coil_dim)
+
+ # output is of shape (batch, coil, complex=2, height, width)
+ return output
diff --git a/direct/nn/kikinet/tests/__init__.py b/direct/nn/kikinet/tests/__init__.py
new file mode 100644
index 00000000..941752c9
--- /dev/null
+++ b/direct/nn/kikinet/tests/__init__.py
@@ -0,0 +1,2 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
diff --git a/direct/nn/kikinet/tests/test_kikinet.py b/direct/nn/kikinet/tests/test_kikinet.py
new file mode 100644
index 00000000..9a7e70e9
--- /dev/null
+++ b/direct/nn/kikinet/tests/test_kikinet.py
@@ -0,0 +1,56 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
+
+import pytest
+import torch
+
+from direct.data.transforms import fft2, ifft2
+from direct.nn.kikinet.kikinet import KIKINet
+
+
+def create_input(shape):
+
+ data = torch.rand(shape).float()
+
+ return data
+
+
+@pytest.mark.parametrize(
+ "shape",
+ [
+ [3, 3, 32, 32],
+ ],
+)
+@pytest.mark.parametrize(
+ "num_iter",
+ [1, 3],
+)
+@pytest.mark.parametrize(
+ "image_model_architecture",
+ ["MWCNN", "UNET", "NORMUNET"],
+)
+@pytest.mark.parametrize(
+ "kspace_model_architecture",
+ ["CONV", "DIDN", "UNET", "NORMUNET"],
+)
+@pytest.mark.parametrize(
+ "normalize",
+ [True, False],
+)
+def test_kikinet(shape, num_iter, image_model_architecture, kspace_model_architecture, normalize):
+ model = KIKINet(
+ fft2,
+ ifft2,
+ num_iter=num_iter,
+ image_model_architecture=image_model_architecture,
+ kspace_model_architecture=kspace_model_architecture,
+ normalize=normalize,
+ ).cpu()
+
+ kspace = create_input(shape + [2]).cpu()
+ mask = create_input([shape[0]] + [1] + shape[2:] + [1]).round().int().cpu()
+ sens = create_input(shape + [2]).cpu()
+
+ out = model(kspace, mask, sens)
+
+ assert list(out.shape) == [shape[0]] + shape[2:] + [2]
diff --git a/direct/nn/lpd/config.py b/direct/nn/lpd/config.py
new file mode 100644
index 00000000..aab98eb0
--- /dev/null
+++ b/direct/nn/lpd/config.py
@@ -0,0 +1,31 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
+
+from dataclasses import dataclass
+
+from direct.config.defaults import ModelConfig
+
+
+@dataclass
+class LPDNetConfig(ModelConfig):
+ num_iter: int = 25
+ num_primal: int = 5
+ num_dual: int = 5
+ primal_model_architecture: str = "MWCNN"
+ dual_model_architecture: str = "DIDN"
+ primal_mwcnn_hidden_channels: int = 16
+ primal_mwcnn_num_scales: int = 4
+ primal_mwcnn_bias: bool = True
+ primal_mwcnn_batchnorm: bool = False
+ primal_unet_num_filters: int = 8
+ primal_unet_num_pool_layers: int = 4
+ primal_unet_dropout_probability: float = 0.0
+ dual_conv_hidden_channels: int = 16
+ dual_conv_n_convs: int = 4
+ dual_conv_batchnorm: bool = False
+ dual_didn_hidden_channels: int = 64
+ dual_didn_num_dubs: int = 6
+ dual_didn_num_convs_recon: int = 9
+ dual_unet_num_filters: int = 8
+ dual_unet_num_pool_layers: int = 4
+ dual_unet_dropout_probability: float = 0.0
diff --git a/direct/nn/lpd/lpd.py b/direct/nn/lpd/lpd.py
index 941752c9..3826408e 100644
--- a/direct/nn/lpd/lpd.py
+++ b/direct/nn/lpd/lpd.py
@@ -1,2 +1,263 @@
# coding=utf-8
# Copyright (c) DIRECT Contributors
+
+from typing import Callable
+
+import direct.data.transforms as T
+from direct.nn.conv.conv import Conv2d
+from direct.nn.didn.didn import DIDN
+from direct.nn.mwcnn.mwcnn import MWCNN
+from direct.nn.unet.unet_2d import UnetModel2d, NormUnetModel2d
+
+import torch
+import torch.nn as nn
+
+
+class DualNet(nn.Module):
+ """
+ Dual Network for Learned Primal Dual Network.
+ """
+
+ def __init__(self, num_dual, **kwargs):
+ super().__init__()
+
+ if kwargs.get("dual_architectue") is None:
+ n_hidden = kwargs.get("n_hidden")
+ if n_hidden is None:
+ raise ValueError("Missing argument n_hidden.")
+ self.dual_block = nn.Sequential(
+ *[
+ nn.Conv2d(2 * (num_dual + 2), n_hidden, kernel_size=3, padding=1),
+ nn.PReLU(),
+ nn.Conv2d(n_hidden, n_hidden, kernel_size=3, padding=1),
+ nn.PReLU(),
+ nn.Conv2d(n_hidden, 2 * num_dual, kernel_size=3, padding=1),
+ ]
+ )
+ else:
+            self.dual_block = kwargs.get("dual_architecture")
+
+ @staticmethod
+ def compute_model_per_coil(model, data):
+ """
+ Computes model per coil.
+ """
+ output = []
+ for idx in range(data.size(1)):
+ subselected_data = data.select(1, idx)
+ output.append(model(subselected_data))
+ output = torch.stack(output, dim=1)
+ return output
+
+ def forward(self, h, forward_f, g):
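+        # h: current dual (k-space) buffer, forward_f: forward projection of the current primal
+        # estimate, g: measured k-space. They are concatenated along the complex (last) dim and
+        # the dual block is applied per coil.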
+ inp = torch.cat([h, forward_f, g], dim=-1).permute(0, 1, 4, 2, 3)
+ return self.compute_model_per_coil(self.dual_block, inp).permute(0, 1, 3, 4, 2)
+
+
+class PrimalNet(nn.Module):
+ """
+ Primal Network for Learned Primal Dual Network.
+ """
+
+ def __init__(self, num_primal, **kwargs):
+ super().__init__()
+
+ if kwargs.get("primal_architectue") is None:
+ n_hidden = kwargs.get("n_hidden")
+ if n_hidden is None:
+ raise ValueError("Missing argument n_hidden.")
+ self.primal_block = nn.Sequential(
+ *[
+ nn.Conv2d(2 * (num_primal + 1), n_hidden, kernel_size=3, padding=1),
+ nn.PReLU(),
+ nn.Conv2d(n_hidden, n_hidden, kernel_size=3, padding=1),
+ nn.PReLU(),
+ nn.Conv2d(n_hidden, 2 * num_primal, kernel_size=3, padding=1),
+ ]
+ )
+ else:
+            self.primal_block = kwargs.get("primal_architecture")
+
+ def forward(self, f, backward_h):
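+        # f: current primal (image) buffer, backward_h: backprojection of the dual update to
+        # image space. They are concatenated along the complex (last) dim before the primal block.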
+ inp = torch.cat([f, backward_h], dim=-1).permute(0, 3, 1, 2)
+ return self.primal_block(inp).permute(0, 2, 3, 1)
+
+
+class LPDNet(nn.Module):
+ """
+ Learned Primal Dual network implementation as in https://arxiv.org/abs/1707.06474.
+ """
+
+ def __init__(
+ self,
+ forward_operator: Callable,
+ backward_operator: Callable,
+ num_iter: int,
+ num_primal: int,
+ num_dual: int,
+ primal_model_architecture: str = "MWCNN",
+ dual_model_architecture: str = "DIDN",
+ **kwargs,
+ ):
+ """
+
+ Parameters
+ ----------
+ forward_operator : Callable
+ Forward Operator.
+ backward_operator : Callable
+ Backward Operator.
+ num_iter : int
+ Number of unrolled iterations.
+ num_primal : int
+ Number of primal networks.
+ num_dual : int
+ Number of dual networks.
+ primal_model_architecture : str
+ Primal model architecture. Currently only implemented for MWCNN and (NORM)UNET. Default: 'MWCNN'.
+ dual_model_architecture : str
+ Dual model architecture. Currently only implemented for CONV and DIDN and (NORM)UNET. Default: 'DIDN'.
+ kwargs : dict
+ Keyword arguments for model architectures.
+ """
+ super().__init__()
+
+ self.forward_operator = forward_operator
+ self.backward_operator = backward_operator
+ self.num_iter = num_iter
+ self.num_primal = num_primal
+ self.num_dual = num_dual
+
+ if primal_model_architecture == "MWCNN":
+ primal_model = nn.Sequential(
+ *[
+ MWCNN(
+ input_channels=2 * (num_primal + 1),
+ first_conv_hidden_channels=kwargs.get("primal_mwcnn_hidden_channels", 32),
+ num_scales=kwargs.get("primal_mwcnn_num_scales", 4),
+ bias=kwargs.get("primal_mwcnn_bias", False),
+ batchnorm=kwargs.get("primal_mwcnn_batchnorm", False),
+ ),
+ nn.Conv2d(2 * (num_primal + 1), 2 * num_primal, kernel_size=1),
+ ]
+ )
+ elif primal_model_architecture in ["UNET", "NORMUNET"]:
+ unet = UnetModel2d if primal_model_architecture == "UNET" else NormUnetModel2d
+ primal_model = unet(
+ in_channels=2 * (num_primal + 1),
+ out_channels=2 * num_primal,
+ num_filters=kwargs.get("primal_unet_num_filters", 8),
+ num_pool_layers=kwargs.get("primal_unet_num_pool_layers", 4),
+ dropout_probability=kwargs.get("primal_unet_dropout_probability", 0.0),
+ )
+ else:
+ raise NotImplementedError(
+ f"XPDNet is currently implemented only with primal_model_architecture == 'MWCNN', 'UNET' or 'NORMUNET."
+ f"Got {primal_model_architecture}."
+ )
+
+ if dual_model_architecture == "CONV":
+ dual_model = Conv2d(
+ in_channels=2 * (num_dual + 2),
+ out_channels=2 * num_dual,
+ hidden_channels=kwargs.get("dual_conv_hidden_channels", 16),
+ n_convs=kwargs.get("dual_conv_n_convs", 4),
+ batchnorm=kwargs.get("dual_conv_batchnorm", False),
+ )
+ elif dual_model_architecture == "DIDN":
+ dual_model = DIDN(
+ in_channels=2 * (num_dual + 2),
+ out_channels=2 * num_dual,
+ hidden_channels=kwargs.get("dual_didn_hidden_channels", 16),
+ num_dubs=kwargs.get("dual_didn_num_dubs", 6),
+ num_convs_recon=kwargs.get("dual_didn_num_convs_recon", 9),
+ )
+ elif dual_model_architecture in ["UNET", "NORMUNET"]:
+ unet = UnetModel2d if dual_model_architecture == "UNET" else NormUnetModel2d
+ dual_model = unet(
+ in_channels=2 * (num_dual + 2),
+ out_channels=2 * num_dual,
+ num_filters=kwargs.get("dual_unet_num_filters", 8),
+ num_pool_layers=kwargs.get("dual_unet_num_pool_layers", 4),
+ dropout_probability=kwargs.get("dual_unet_dropout_probability", 0.0),
+ )
+ else:
+ raise NotImplementedError(
+ f"XPDNet is currently implemented for dual_model_architecture == 'CONV', 'DIDN',"
+ f" 'UNET' or 'NORMUNET'. Got dual_model_architecture == {dual_model_architecture}."
+ )
+
+ self._coil_dim = 1
+ self._complex_dim = -1
+ self._spatial_dims = (2, 3)
+
+ self.primal_net = nn.ModuleList(
+            [PrimalNet(num_primal, primal_architecture=primal_model) for _ in range(num_iter)]
+ )
+        self.dual_net = nn.ModuleList([DualNet(num_dual, dual_architecture=dual_model) for _ in range(num_iter)])
+
+ def _forward_operator(self, image, sampling_mask, sensitivity_map):
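+        # Expand the coil-combined image with the sensitivity maps, apply the forward operator,
+        # and zero out the non-sampled k-space locations.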
+ forward = torch.where(
+ sampling_mask == 0,
+ torch.tensor([0.0], dtype=image.dtype).to(image.device),
+ self.forward_operator(T.expand_operator(image, sensitivity_map, self._coil_dim), dim=self._spatial_dims),
+ )
+ return forward
+
+ def _backward_operator(self, kspace, sampling_mask, sensitivity_map):
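+        # Zero out the non-sampled k-space locations, apply the backward operator, and
+        # coil-combine with the sensitivity maps.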
+ backward = T.reduce_operator(
+ self.backward_operator(
+ torch.where(
+ sampling_mask == 0,
+ torch.tensor([0.0], dtype=kspace.dtype).to(kspace.device),
+ kspace,
+ ),
+ self._spatial_dims,
+ ),
+ sensitivity_map,
+ self._coil_dim,
+ )
+ return backward
+
+ def forward(
+ self,
+ masked_kspace: torch.Tensor,
+ sensitivity_map: torch.Tensor,
+ sampling_mask: torch.Tensor,
+ ) -> torch.Tensor:
+ """
+
+ Parameters
+ ----------
+ masked_kspace : torch.Tensor
+ Masked k-space of shape (N, coil, height, width, complex=2).
+ sensitivity_map : torch.Tensor
+ Sensitivity map of shape (N, coil, height, width, complex=2).
+ sampling_mask : torch.Tensor
+ Sampling mask of shape (N, 1, height, width, 1).
+
+ Returns
+ -------
+ output : torch.Tensor
+ Output image of shape (N, height, width, complex=2).
+ """
+ input_image = self._backward_operator(masked_kspace, sampling_mask, sensitivity_map)
+ dual_buffer = torch.cat([masked_kspace] * self.num_dual, self._complex_dim).to(masked_kspace.device)
+ primal_buffer = torch.cat([input_image] * self.num_primal, self._complex_dim).to(masked_kspace.device)
+
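+        # Each unrolled iteration first updates the dual buffer from the forward projection of the
+        # second primal channel and the measured k-space, then updates the primal buffer from the
+        # backprojection of the first dual channel, as in the learned primal-dual scheme.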
+        for curr_iter in range(self.num_iter):
+
+ # Dual
+ f_2 = primal_buffer[..., 2:4].clone()
+            dual_buffer = self.dual_net[curr_iter](
+ dual_buffer, self._forward_operator(f_2, sampling_mask, sensitivity_map), masked_kspace
+ )
+
+ # Primal
+ h_1 = dual_buffer[..., 0:2].clone()
+            primal_buffer = self.primal_net[curr_iter](
+ primal_buffer, self._backward_operator(h_1, sampling_mask, sensitivity_map)
+ )
+
+ output = primal_buffer[..., 0:2]
+ return output
diff --git a/direct/nn/lpd/lpd_engine.py b/direct/nn/lpd/lpd_engine.py
new file mode 100644
index 00000000..50c72874
--- /dev/null
+++ b/direct/nn/lpd/lpd_engine.py
@@ -0,0 +1,469 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
+
+import time
+from collections import defaultdict
+from os import PathLike
+from typing import Callable, DefaultDict, Dict, List, Optional
+
+import numpy as np
+import torch
+from torch import nn
+from torch.cuda.amp import autocast
+from torch.nn import functional as F
+from torch.utils.data import DataLoader
+
+import direct.data.transforms as T
+from direct.config import BaseConfig
+from direct.engine import DoIterationOutput, Engine
+from direct.functionals import SSIMLoss
+from direct.utils import (
+ communication,
+ detach_dict,
+ dict_to_device,
+ merge_list_of_dicts,
+ multiply_function,
+ reduce_list_of_dicts,
+)
+from direct.utils.communication import reduce_tensor_dict
+
+
+class LPDNetEngine(Engine):
+ """
+ LPDNet Engine.
+ """
+
+ def __init__(
+ self,
+ cfg: BaseConfig,
+ model: nn.Module,
+ device: int,
+ forward_operator: Optional[Callable] = None,
+ backward_operator: Optional[Callable] = None,
+ mixed_precision: bool = False,
+ **models: nn.Module,
+ ):
+ super().__init__(
+ cfg,
+ model,
+ device,
+ forward_operator=forward_operator,
+ backward_operator=backward_operator,
+ mixed_precision=mixed_precision,
+ **models,
+ )
+
+ self._complex_dim = -1
+ self._coil_dim = 1
+ self._spatial_dims = (2, 3)
+
+ def _do_iteration(
+ self,
+ data: Dict[str, torch.Tensor],
+ loss_fns: Optional[Dict[str, Callable]] = None,
+ regularizer_fns: Optional[Dict[str, Callable]] = None,
+ ) -> DoIterationOutput:
+
+        # loss_fns can be None, e.g. during validation.
+ if loss_fns is None:
+ loss_fns = {}
+
+ if regularizer_fns is None:
+ regularizer_fns = {}
+
+ loss_dicts = []
+ regularizer_dicts = []
+
+ data = dict_to_device(data, self.device)
+
+ # sensitivity_map of shape (batch, coil, height, width, complex=2)
+ sensitivity_map = data["sensitivity_map"]
+
+ if "sensitivity_model" in self.models:
+
+ # Move channels to first axis
+ sensitivity_map = data["sensitivity_map"].permute((0, 1, 4, 2, 3))
+
+ sensitivity_map = self.compute_model_per_coil("sensitivity_model", sensitivity_map).permute(
+ (0, 1, 3, 4, 2)
+ )
+
+            # The sensitivity map needs to be normalized such that
+            # \sum_{i \in \text{coils}} S_i S_i^* = 1.
+
+ sensitivity_map_norm = torch.sqrt(
+ ((sensitivity_map ** 2).sum(self._complex_dim)).sum(self._coil_dim)
+ ) # shape (batch, height, width)
+ sensitivity_map_norm = sensitivity_map_norm.unsqueeze(1).unsqueeze(-1)
+ data["sensitivity_map"] = T.safe_divide(sensitivity_map, sensitivity_map_norm)
+
+ with autocast(enabled=self.mixed_precision):
+
+ output_image = self.model(
+ masked_kspace=data["masked_kspace"],
+ sampling_mask=data["sampling_mask"],
+ sensitivity_map=data["sensitivity_map"],
+ ) # shape (batch, height, width, complex=2)
+
+ output_image = T.modulus(output_image) # shape (batch, height, width)
+
+ loss_dict = {k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in loss_fns.keys()}
+ regularizer_dict = {
+ k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in regularizer_fns.keys()
+ }
+
+ for key, value in loss_dict.items():
+ loss_dict[key] = value + loss_fns[key](
+ output_image,
+ **data,
+ reduction="mean",
+ )
+
+ for key, value in regularizer_dict.items():
+ regularizer_dict[key] = value + regularizer_fns[key](
+ output_image,
+ **data,
+ )
+
+ loss = sum(loss_dict.values()) + sum(regularizer_dict.values())
+
+ if self.model.training:
+ self._scaler.scale(loss).backward()
+
+ loss_dicts.append(detach_dict(loss_dict))
+ regularizer_dicts.append(
+ detach_dict(regularizer_dict)
+ ) # Need to detach dict as this is only used for logging.
+
+ # Add the loss dicts.
+ loss_dict = reduce_list_of_dicts(loss_dicts, mode="sum")
+ regularizer_dict = reduce_list_of_dicts(regularizer_dicts, mode="sum")
+
+ return DoIterationOutput(
+ output_image=output_image,
+ sensitivity_map=data["sensitivity_map"],
+ data_dict={**loss_dict, **regularizer_dict},
+ )
+
+ def build_loss(self, **kwargs) -> Dict:
+ # TODO: Cropper is a processing output tool.
+ def get_resolution(**data):
+ """Be careful that this will use the cropping size of the FIRST sample in the batch."""
+ return self.compute_resolution(self.cfg.training.loss.crop, data.get("reconstruction_size", None))
+
+ def l1_loss(source, reduction="mean", **data):
+ """
+ Calculate L1 loss given source and target.
+
+ Parameters:
+ -----------
+ Source: shape (batch, complex=2, height, width)
+ Data: Contains key "target" with value a tensor of shape (batch, height, width)
+
+ """
+ resolution = get_resolution(**data)
+ l1_loss = F.l1_loss(*self.cropper(source, data["target"], resolution), reduction=reduction)
+
+ return l1_loss
+
+ def l2_loss(source, reduction="mean", **data):
+ """
+ Calculate L2 loss (MSE) given source and target.
+
+ Parameters:
+ -----------
+ Source: shape (batch, complex=2, height, width)
+ Data: Contains key "target" with value a tensor of shape (batch, height, width)
+
+ """
+ resolution = get_resolution(**data)
+ l2_loss = F.mse_loss(*self.cropper(source, data["target"], resolution), reduction=reduction)
+
+ return l2_loss
+
+ def ssim_loss(source, reduction="mean", **data):
+ """
+ Calculate SSIM loss given source and target.
+
+ Parameters:
+ -----------
+ Source: shape (batch, complex=2, height, width)
+ Data: Contains key "target" with value a tensor of shape (batch, height, width)
+
+ """
+ resolution = get_resolution(**data)
+ if reduction != "mean":
+ raise AssertionError(
+ f"SSIM loss can only be computed with reduction == 'mean'." f" Got reduction == {reduction}."
+ )
+
+ source_abs, target_abs = self.cropper(source, data["target"], resolution)
+ data_range = torch.tensor([target_abs.max()], device=target_abs.device)
+
+ ssim_loss = SSIMLoss().to(source_abs.device).forward(source_abs, target_abs, data_range=data_range)
+
+ return ssim_loss
+
+ # Build losses
+ loss_dict = {}
+ for curr_loss in self.cfg.training.loss.losses: # type: ignore
+ loss_fn = curr_loss.function
+ if loss_fn == "l1_loss":
+ loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, l1_loss)
+ elif loss_fn == "l2_loss":
+ loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, l2_loss)
+ elif loss_fn == "ssim_loss":
+ loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, ssim_loss)
+ else:
+ raise ValueError(f"{loss_fn} not permissible.")
+
+ return loss_dict
+
+ @torch.no_grad()
+ def evaluate(
+ self,
+ data_loader: DataLoader,
+ loss_fns: Optional[Dict[str, Callable]],
+ regularizer_fns: Optional[Dict[str, Callable]] = None,
+ crop: Optional[str] = None,
+ is_validation_process: bool = True,
+ ):
+ """
+ Validation process. Assumes that each batch only contains slices of the same volume *AND* that these
+ are sequentially ordered.
+
+ Parameters
+ ----------
+ data_loader : DataLoader
+ loss_fns : Dict[str, Callable], optional
+ regularizer_fns : Dict[str, Callable], optional
+ crop : str, optional
+ is_validation_process : bool
+
+ Returns
+ -------
+ loss_dict, all_gathered_metrics, visualize_slices, visualize_target
+
+ """
+ self.models_to_device()
+ self.models_validation_mode()
+ torch.cuda.empty_cache()
+
+ # Variables required for evaluation.
+ volume_metrics = self.build_metrics(self.cfg.validation.metrics) # type: ignore
+
+ # filenames can be in the volume_indices attribute of the dataset
+ num_for_this_process = None
+ all_filenames = None
+ if hasattr(data_loader.dataset, "volume_indices"):
+ all_filenames = list(data_loader.dataset.volume_indices.keys())
+ num_for_this_process = len(list(data_loader.batch_sampler.sampler.volume_indices.keys()))
+ self.logger.info(
+ f"Reconstructing a total of {len(all_filenames)} volumes. "
+ f"This process has {num_for_this_process} volumes (world size: {communication.get_world_size()})."
+ )
+
+ filenames_seen = 0
+ reconstruction_output: DefaultDict = defaultdict(list)
+ if is_validation_process:
+ targets_output: DefaultDict = defaultdict(list)
+ val_losses = []
+ val_volume_metrics: Dict[PathLike, Dict] = defaultdict(dict)
+ last_filename = None
+
+        # Container for the slices which can be visualized in TensorBoard.
+ visualize_slices: List[np.ndarray] = []
+ visualize_target: List[np.ndarray] = []
+
+ extra_visualization_keys = (
+ self.cfg.logging.log_as_image if self.cfg.logging.log_as_image else [] # type: ignore
+ )
+
+ # Loop over dataset. This requires the use of direct.data.sampler.DistributedSequentialSampler as this sampler
+ # splits the data over the different processes, and outputs the slices linearly. The implicit assumption here is
+ # that the slices are outputted from the Dataset *sequentially* for each volume one by one, and each batch only
+ # contains data from one volume.
+ time_start = time.time()
+
+ for iter_idx, data in enumerate(data_loader):
+ filenames = data.pop("filename")
+ if len(set(filenames)) != 1:
+ raise ValueError(
+ f"Expected a batch during validation to only contain filenames of one case. "
+ f"Got {set(filenames)}."
+ )
+
+ slice_nos = data.pop("slice_no")
+ scaling_factors = data["scaling_factor"]
+
+ resolution = self.compute_resolution(
+ key=self.cfg.validation.crop, # type: ignore
+ reconstruction_size=data.get("reconstruction_size", None),
+ )
+
+ # Compute output and loss.
+ iteration_output = self._do_iteration(data, loss_fns, regularizer_fns=regularizer_fns)
+ output = iteration_output.output_image
+ loss_dict = iteration_output.data_dict
+
+ loss_dict = detach_dict(loss_dict)
+ output = output.detach()
+ val_losses.append(loss_dict)
+
+            # Output is real-valued (the modulus is taken in _do_iteration) and still has to be cropped.
+            # This holds for both output and target. Output has shape (batch, height, width).
+ output_abs = self.process_output(
+ output,
+ scaling_factors,
+ resolution=resolution,
+ )
+
+ if is_validation_process:
+ # Target has shape (batch, height, width)
+ target_abs = self.process_output(
+ data["target"].detach(),
+ scaling_factors,
+ resolution=resolution,
+ )
+ for key in extra_visualization_keys:
+ curr_data = data[key].detach()
+ # Here we need to discover which keys are actually normalized or not
+ # this requires a solution to issue #23: https://github.com/directgroup/direct/issues/23
+
+ del output # Explicitly call delete to clear memory.
+
+ # Aggregate volumes to be able to compute the metrics on complete volumes.
+ for idx, filename in enumerate(filenames):
+ if last_filename is None:
+ last_filename = filename # First iteration last_filename is not set.
+
+ curr_slice = output_abs[idx].detach()
+ slice_no = int(slice_nos[idx].numpy())
+
+ reconstruction_output[filename].append((slice_no, curr_slice.cpu()))
+
+ if is_validation_process:
+ targets_output[filename].append((slice_no, target_abs[idx].cpu()))
+
+ is_last_element_of_last_batch = iter_idx + 1 == len(data_loader) and idx + 1 == len(data["target"])
+ reconstruction_conditions = [filename != last_filename, is_last_element_of_last_batch]
+ for condition in reconstruction_conditions:
+ if condition:
+ filenames_seen += 1
+
+ # Now we can ditch the reconstruction dict by reconstructing the volume,
+ # will take too much memory otherwise.
+ volume = torch.stack([_[1] for _ in reconstruction_output[last_filename]])
+ if is_validation_process:
+ target = torch.stack([_[1] for _ in targets_output[last_filename]])
+ curr_metrics = {
+ metric_name: metric_fn(target, volume)
+ for metric_name, metric_fn in volume_metrics.items()
+ }
+ val_volume_metrics[last_filename] = curr_metrics
+ # Log the center slice of the volume
+ if len(visualize_slices) < self.cfg.logging.tensorboard.num_images: # type: ignore
+ visualize_slices.append(volume[volume.shape[0] // 2])
+ visualize_target.append(target[target.shape[0] // 2])
+
+ # Delete outputs from memory, and recreate dictionary.
+ # This is not needed when not in validation as we are actually interested
+ # in the iteration output.
+ del targets_output[last_filename]
+ del reconstruction_output[last_filename]
+
+ if all_filenames:
+ log_prefix = f"{filenames_seen} of {num_for_this_process} volumes reconstructed:"
+ else:
+ log_prefix = f"{iter_idx + 1} of {len(data_loader)} slices reconstructed:"
+
+ self.logger.info(
+ f"{log_prefix} {last_filename}"
+ f" (shape = {list(volume.shape)}) in {time.time() - time_start:.3f}s."
+ )
+ # restart timer
+ time_start = time.time()
+ last_filename = filename
+
+ # Average loss dict
+ loss_dict = reduce_list_of_dicts(val_losses)
+ reduce_tensor_dict(loss_dict)
+
+ communication.synchronize()
+ torch.cuda.empty_cache()
+
+ all_gathered_metrics = merge_list_of_dicts(communication.all_gather(val_volume_metrics))
+ if not is_validation_process:
+ return loss_dict, reconstruction_output
+
+ return loss_dict, all_gathered_metrics, visualize_slices, visualize_target
+
+ def process_output(self, data, scaling_factors=None, resolution=None):
+ # data is of shape (batch, complex=2, height, width)
+ if scaling_factors is not None:
+ data = data * scaling_factors.view(-1, *((1,) * (len(data.shape) - 1))).to(data.device)
+
+ data = T.modulus_if_complex(data)
+
+ if len(data.shape) == 3: # (batch, height, width)
+ data = data.unsqueeze(1) # Added channel dimension.
+
+ if resolution is not None:
+ data = T.center_crop(data, resolution).contiguous()
+
+ return data
+
+ @staticmethod
+ def compute_resolution(key, reconstruction_size):
+ if key == "header":
+ # This will be of the form [tensor(x_0, x_1, ...), tensor(y_0, y_1,...), tensor(z_0, z_1, ...)] over
+ # batches.
+ resolution = [_.detach().cpu().numpy().tolist() for _ in reconstruction_size]
+ # The volume sampler should give validation indices belonging to the *same* volume, so it should be
+ # safe taking the first element, the matrix size are in x,y,z (we work in z,x,y).
+ resolution = [_[0] for _ in resolution][:-1]
+ elif key == "training":
+ resolution = key
+ elif not key:
+ resolution = None
+ else:
+ raise ValueError(
+ "Cropping should be either set to `header` to get the values from the header or "
+ "`training` to take the same value as training."
+ )
+ return resolution
+
+ def cropper(self, source, target, resolution):
+ """
+        2D source/target cropper.
+
+        Parameters
+        ----------
+        source : torch.Tensor
+            Has shape (batch, height, width).
+        target : torch.Tensor
+            Has shape (batch, height, width).
+        resolution : list of int, optional
+            Resolution to center-crop to. If empty or all zeros, no cropping is applied.
+
+ """
+
+ if not resolution or all(_ == 0 for _ in resolution):
+ return source.unsqueeze(1), target.unsqueeze(1) # Added channel dimension.
+
+ source_abs = T.center_crop(source, resolution).unsqueeze(1) # Added channel dimension.
+ target_abs = T.center_crop(target, resolution).unsqueeze(1) # Added channel dimension.
+
+ return source_abs, target_abs
+
+ def compute_model_per_coil(self, model_name, data):
+ """
+ Computes model per coil.
+ """
+ # data is of shape (batch, coil, complex=2, height, width)
+ output = []
+
+ for idx in range(data.size(self._coil_dim)):
+ subselected_data = data.select(self._coil_dim, idx)
+ output.append(self.models[model_name](subselected_data))
+ output = torch.stack(output, dim=self._coil_dim)
+
+ # output is of shape (batch, coil, complex=2, height, width)
+ return output
diff --git a/direct/nn/lpd/tests/__init__.py b/direct/nn/lpd/tests/__init__.py
new file mode 100644
index 00000000..941752c9
--- /dev/null
+++ b/direct/nn/lpd/tests/__init__.py
@@ -0,0 +1,2 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
diff --git a/direct/nn/lpd/tests/test_lpd.py b/direct/nn/lpd/tests/test_lpd.py
new file mode 100644
index 00000000..b86849f2
--- /dev/null
+++ b/direct/nn/lpd/tests/test_lpd.py
@@ -0,0 +1,68 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
+
+import pytest
+import torch
+
+from direct.data.transforms import fft2, ifft2
+from direct.nn.lpd.lpd import LPDNet
+
+
+def create_input(shape):
+
+ data = torch.rand(shape).float()
+
+ return data
+
+
+@pytest.mark.parametrize(
+ "shape",
+ [
+ [3, 3, 32, 32],
+ ],
+)
+@pytest.mark.parametrize(
+ "num_iter",
+ [2, 3],
+)
+@pytest.mark.parametrize(
+ "num_primal",
+ [2, 3],
+)
+@pytest.mark.parametrize(
+ "num_dual",
+ [3],
+)
+@pytest.mark.parametrize(
+ "primal_model_architecture",
+ ["MWCNN", "UNET", "NORMUNET"],
+)
+@pytest.mark.parametrize(
+ "dual_model_architecture",
+ ["CONV", "DIDN", "UNET", "NORMUNET"],
+)
+def test_lpd(
+ shape,
+ num_iter,
+ num_primal,
+ num_dual,
+ primal_model_architecture,
+ dual_model_architecture,
+):
+ model = LPDNet(
+ fft2,
+ ifft2,
+ num_iter=num_iter,
+ num_primal=num_primal,
+ num_dual=num_dual,
+ primal_model_architecture=primal_model_architecture,
+ dual_model_architecture=dual_model_architecture,
+ ).cpu()
+
+ kspace = create_input(shape + [2]).cpu()
+ sens = create_input(shape + [2]).cpu()
+ mask = create_input([shape[0]] + [1] + shape[2:] + [1]).round().int().cpu()
+
+ out = model(kspace, sens, mask)
+
+ assert list(out.shape) == [shape[0]] + shape[2:] + [2]
diff --git a/direct/nn/mobilenet/mobilenet.py b/direct/nn/mobilenet/mobilenet.py
index f7cb5d90..44930e31 100644
--- a/direct/nn/mobilenet/mobilenet.py
+++ b/direct/nn/mobilenet/mobilenet.py
@@ -18,10 +18,6 @@ def _make_divisible(v, divisor, min_value=None):
It ensures that all layers have a channel number that is divisible by 8
It can be seen here:
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
- :param v:
- :param divisor:
- :param min_value:
- :return:
"""
if min_value is None:
min_value = divisor
@@ -37,7 +33,7 @@ def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, nor
padding = (kernel_size - 1) // 2
if norm_layer is None:
norm_layer = nn.BatchNorm2d
- super(ConvBNReLU, self).__init__(
+ super().__init__(
nn.Conv2d(
in_planes,
out_planes,
@@ -54,7 +50,8 @@ def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, nor
class InvertedResidual(nn.Module):
def __init__(self, inp, oup, stride, expand_ratio, norm_layer=None):
- super(InvertedResidual, self).__init__()
+ super().__init__()
+
self.stride = stride
if stride not in [1, 2]:
raise AssertionError
@@ -123,7 +120,8 @@ def __init__(
norm_layer : str
Module specifying the normalization layer to use.
"""
- super(MobileNetV2, self).__init__()
+
+ super().__init__()
if block is None:
block = InvertedResidual
@@ -163,8 +161,8 @@ def __init__(
# building inverted residual blocks
for t, c, n, s in inverted_residual_setting:
output_channel = _make_divisible(c * width_mult, round_nearest)
- for i in range(n):
- stride = s if i == 0 else 1
+ for idx in range(n):
+ stride = s if idx == 0 else 1
features.append(
block(
input_channel,
diff --git a/direct/nn/multidomainnet/__init__.py b/direct/nn/multidomainnet/__init__.py
new file mode 100644
index 00000000..941752c9
--- /dev/null
+++ b/direct/nn/multidomainnet/__init__.py
@@ -0,0 +1,2 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
diff --git a/direct/nn/multidomainnet/config.py b/direct/nn/multidomainnet/config.py
new file mode 100644
index 00000000..c46ab526
--- /dev/null
+++ b/direct/nn/multidomainnet/config.py
@@ -0,0 +1,14 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
+from dataclasses import dataclass
+from typing import Optional, Tuple
+
+from direct.config.defaults import ModelConfig
+
+
+@dataclass
+class MultiDomainNetConfig(ModelConfig):
+ standardization: bool = True
+ num_filters: int = 16
+ num_pool_layers: int = 4
+ dropout_probability: float = 0.0
diff --git a/direct/nn/multidomainnet/multidomain.py b/direct/nn/multidomainnet/multidomain.py
new file mode 100644
index 00000000..51505285
--- /dev/null
+++ b/direct/nn/multidomainnet/multidomain.py
@@ -0,0 +1,309 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
+
+from typing import Callable
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class MultiDomainConv2d(nn.Module):
+ def __init__(
+ self,
+ forward_operator,
+ backward_operator,
+ in_channels,
+ out_channels,
+ **kwargs,
+ ):
+ super().__init__()
+
+ self.image_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels // 2, **kwargs)
+ self.kspace_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels // 2, **kwargs)
+ self.forward_operator = forward_operator
+ self.backward_operator = backward_operator
+ self._channels_dim = 1
+ self._spatial_dims = (1, 2)
+
+ def forward(self, image):
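+        # Split the channel stack into 2-channel complex pairs, transform each to k-space and
+        # convolve there, transform the result back, convolve the input in the image domain as
+        # well, and concatenate the image- and k-space-branch features along the channel dim.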
+ kspace = [
+ self.forward_operator(
+ im,
+ dim=self._spatial_dims,
+ )
+ for im in torch.split(image.permute(0, 2, 3, 1).contiguous(), 2, -1)
+ ]
+ kspace = torch.cat(kspace, -1).permute(0, 3, 1, 2)
+ kspace = self.kspace_conv(kspace)
+
+ backward = [
+ self.backward_operator(
+ ks,
+ dim=self._spatial_dims,
+ )
+ for ks in torch.split(kspace.permute(0, 2, 3, 1).contiguous(), 2, -1)
+ ]
+ backward = torch.cat(backward, -1).permute(0, 3, 1, 2)
+
+ image = self.image_conv(image)
+ image = torch.cat([image, backward], dim=self._channels_dim)
+ return image
+
+
+class MultiDomainConvTranspose2d(nn.Module):
+ def __init__(
+ self,
+ forward_operator,
+ backward_operator,
+ in_channels,
+ out_channels,
+ **kwargs,
+ ):
+ super().__init__()
+
+ self.image_conv = nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels // 2, **kwargs)
+ self.kspace_conv = nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels // 2, **kwargs)
+ self.forward_operator = forward_operator
+ self.backward_operator = backward_operator
+ self._channels_dim = 1
+ self._spatial_dims = (1, 2)
+
+ def forward(self, image):
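+        # Same dual-domain scheme as MultiDomainConv2d, but with transposed convolutions
+        # in both branches for up-sampling.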
+ kspace = [
+ self.forward_operator(
+ im,
+ dim=self._spatial_dims,
+ )
+ for im in torch.split(image.permute(0, 2, 3, 1).contiguous(), 2, -1)
+ ]
+ kspace = torch.cat(kspace, -1).permute(0, 3, 1, 2)
+ kspace = self.kspace_conv(kspace)
+
+ backward = [
+ self.backward_operator(
+ ks,
+ dim=self._spatial_dims,
+ )
+ for ks in torch.split(kspace.permute(0, 2, 3, 1).contiguous(), 2, -1)
+ ]
+ backward = torch.cat(backward, -1).permute(0, 3, 1, 2)
+
+ image = self.image_conv(image)
+ return torch.cat([image, backward], dim=self._channels_dim)
+
+
+class MultiDomainConvBlock(nn.Module):
+ """
+ A multi-domain convolutional block that consists of two multi-domain convolution layers each followed by
+ instance normalization, LeakyReLU activation and dropout.
+ """
+
+ def __init__(
+ self, forward_operator, backward_operator, in_channels: int, out_channels: int, dropout_probability: float
+ ):
+ """
+
+ Parameters
+ ----------
+ in_channels : int
+ Number of input channels.
+ out_channels : int
+ Number of output channels.
+ dropout_probability : float
+ Dropout probability.
+ """
+ super().__init__()
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.dropout_probability = dropout_probability
+
+ self.layers = nn.Sequential(
+ MultiDomainConv2d(
+ forward_operator, backward_operator, in_channels, out_channels, kernel_size=3, padding=1, bias=False
+ ),
+ nn.InstanceNorm2d(out_channels),
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
+ nn.Dropout2d(dropout_probability),
+ MultiDomainConv2d(
+ forward_operator, backward_operator, out_channels, out_channels, kernel_size=3, padding=1, bias=False
+ ),
+ nn.InstanceNorm2d(out_channels),
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
+ nn.Dropout2d(dropout_probability),
+ )
+
+ def forward(self, input: torch.Tensor):
+ """
+
+ Parameters
+ ----------
+ input : torch.Tensor
+
+ Returns
+ -------
+ torch.Tensor
+ """
+ return self.layers(input)
+
+ def __repr__(self):
+ return (
+ f"MultiDomainConvBlock(in_channels={self.in_channels}, out_channels={self.out_channels}, "
+ f"dropout_probability={self.dropout_probability})"
+ )
+
+
+class TransposeMultiDomainConvBlock(nn.Module):
+ """
+    A transpose convolutional block that consists of one convolution transpose layer followed by
+    instance normalization and LeakyReLU activation.
+ """
+
+ def __init__(self, forward_operator, backward_operator, in_channels: int, out_channels: int):
+ """
+ Parameters
+ ----------
+ in_channels : int
+ Number of input channels.
+ out_channels : int
+ Number of output channels.
+ """
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.layers = nn.Sequential(
+ MultiDomainConvTranspose2d(
+ forward_operator, backward_operator, in_channels, out_channels, kernel_size=2, stride=2, bias=False
+ ),
+ nn.InstanceNorm2d(out_channels),
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
+ )
+
+ def forward(self, input: torch.Tensor):
+ """
+
+ Parameters
+ ----------
+ input : torch.Tensor
+
+ Returns
+ -------
+ torch.Tensor
+ """
+ return self.layers(input)
+
+ def __repr__(self):
+ return f"MultiDomainConvBlock(in_channels={self.in_channels}, out_channels={self.out_channels})"
+
+
+class MultiDomainUnet2d(nn.Module):
+ """
+    U-Net modification to be used with the multi-domain network, as in the AIRS Medical submission to the Fast MRI 2020 challenge.
+ """
+
+ def __init__(
+ self,
+ forward_operator: Callable,
+ backward_operator: Callable,
+ in_channels: int,
+ out_channels: int,
+ num_filters: int,
+ num_pool_layers: int,
+ dropout_probability: float,
+ ):
+ """
+
+ Parameters
+ ----------
+ forward_operator : Callable
+ Forward Operator.
+ backward_operator : Callable
+ Backward Operator.
+ in_channels : int
+ Number of input channels to the u-net.
+ out_channels : int
+            Number of output channels of the u-net.
+ num_filters : int
+ Number of output channels of the first convolutional layer.
+ num_pool_layers : int
+ Number of down-sampling and up-sampling layers (depth).
+ dropout_probability : float
+ Dropout probability.
+ """
+ super().__init__()
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.num_filters = num_filters
+ self.num_pool_layers = num_pool_layers
+ self.dropout_probability = dropout_probability
+
+ self.down_sample_layers = nn.ModuleList(
+ [MultiDomainConvBlock(forward_operator, backward_operator, in_channels, num_filters, dropout_probability)]
+ )
+ ch = num_filters
+ for _ in range(num_pool_layers - 1):
+ self.down_sample_layers += [
+ MultiDomainConvBlock(forward_operator, backward_operator, ch, ch * 2, dropout_probability)
+ ]
+ ch *= 2
+ self.conv = MultiDomainConvBlock(forward_operator, backward_operator, ch, ch * 2, dropout_probability)
+
+ self.up_conv = nn.ModuleList()
+ self.up_transpose_conv = nn.ModuleList()
+ for _ in range(num_pool_layers - 1):
+ self.up_transpose_conv += [TransposeMultiDomainConvBlock(forward_operator, backward_operator, ch * 2, ch)]
+ self.up_conv += [
+ MultiDomainConvBlock(forward_operator, backward_operator, ch * 2, ch, dropout_probability)
+ ]
+ ch //= 2
+
+ self.up_transpose_conv += [TransposeMultiDomainConvBlock(forward_operator, backward_operator, ch * 2, ch)]
+ self.up_conv += [
+ nn.Sequential(
+ MultiDomainConvBlock(forward_operator, backward_operator, ch * 2, ch, dropout_probability),
+ nn.Conv2d(ch, self.out_channels, kernel_size=1, stride=1),
+ )
+ ]
+
+ def forward(self, input: torch.Tensor):
+ """
+
+ Parameters
+ ----------
+ input : torch.Tensor
+
+ Returns
+ -------
+ torch.Tensor
+ """
+ stack = []
+ output = input
+
+ # Apply down-sampling layers
+        for layer in self.down_sample_layers:
+ output = layer(output)
+ stack.append(output)
+ output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0)
+
+ output = self.conv(output)
+
+ # Apply up-sampling layers
+ for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv):
+ downsample_layer = stack.pop()
+ output = transpose_conv(output)
+
+ # Reflect pad on the right/bottom if needed to handle odd input dimensions.
+ padding = [0, 0, 0, 0]
+ if output.shape[-1] != downsample_layer.shape[-1]:
+ padding[1] = 1 # Padding right
+ if output.shape[-2] != downsample_layer.shape[-2]:
+ padding[3] = 1 # Padding bottom
+ if sum(padding) != 0:
+ output = F.pad(output, padding, "reflect")
+
+ output = torch.cat([output, downsample_layer], dim=1)
+ output = conv(output)
+
+ return output
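To make the multi-domain convolution concrete, here is a minimal sketch (assuming `fft2`/`ifft2` from `direct.data.transforms` as the forward/backward operators, as the tests below do, and random data standing in for an image): half of the output filters are computed in image space, the other half in k-space, and the two feature sets are concatenated along the channel axis.

```python
import torch

from direct.data.transforms import fft2, ifft2
from direct.nn.multidomainnet.multidomain import MultiDomainConv2d

# One complex-valued image stored as two real channels: layout (N, C, H, W), C even,
# since channels are consumed pairwise as (real, imaginary).
image = torch.rand(1, 2, 32, 32)

conv = MultiDomainConv2d(fft2, ifft2, in_channels=2, out_channels=8, kernel_size=3, padding=1)

# out_channels // 2 filters act in image space and out_channels // 2 in k-space,
# then the results are concatenated along the channel axis.
features = conv(image)
print(features.shape)  # torch.Size([1, 8, 32, 32])
```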
diff --git a/direct/nn/multidomainnet/multidomainnet.py b/direct/nn/multidomainnet/multidomainnet.py
new file mode 100644
index 00000000..38969c22
--- /dev/null
+++ b/direct/nn/multidomainnet/multidomainnet.py
@@ -0,0 +1,131 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
+
+from typing import Callable
+
+import direct.data.transforms as T
+from direct.nn.multidomainnet.multidomain import MultiDomainUnet2d
+
+import torch
+import torch.nn as nn
+
+
+class StandardizationLayer(nn.Module):
+ """
+ Multi-channel data standardization method. Inspired by AIRS model submission to the Fast MRI 2020 challenge.
+ Given individual coil images :math: {x_i}_{i=1}^{N_c} and sensitivity coil maps :math: {S_i}_{i=1}^{N_c}
+ it returns
+ .. math::
+ {xres_i}_{i=1}^{N_c},
+ where :math: xres_i = [x_{sense}, xi - S_i \times x_{sense}]
+ and :math: x_{sense} = \sum_{i=1}^{N_c} {S_i}^{*} \times x_i.
+
+ """
+
+ def __init__(self, coil_dim=1, channel_dim=-1):
+ super().__init__()
+ self.coil_dim = coil_dim
+ self.channel_dim = channel_dim
+
+ def forward(self, coil_images: torch.Tensor, sensitivity_map: torch.Tensor) -> torch.Tensor:
+ combined_image = T.reduce_operator(coil_images, sensitivity_map, self.coil_dim)
+ residual_image = combined_image.unsqueeze(self.coil_dim) - T.complex_multiplication(
+ sensitivity_map, combined_image.unsqueeze(self.coil_dim)
+ )
+ concat = torch.cat(
+ [
+ torch.cat([combined_image, residual_image.select(self.coil_dim, idx)], self.channel_dim).unsqueeze(
+ self.coil_dim
+ )
+ for idx in range(coil_images.size(self.coil_dim))
+ ],
+ self.coil_dim,
+ )
+ return concat
+
+
+class MultiDomainNet(nn.Module):
+ """
+ Feature-level multi-domain module. Inspired by AIRS Medical submission to the Fast MRI 2020 challenge.
+ """
+
+ def __init__(
+ self,
+ forward_operator: Callable,
+ backward_operator: Callable,
+ standardization: bool = True,
+ num_filters: int = 16,
+ num_pool_layers: int = 4,
+ dropout_probability: float = 0.0,
+ **kwargs,
+ ):
+ """
+
+ Parameters
+ ----------
+ forward_operator : Callable
+ Forward Operator.
+ backward_operator : Callable
+ Backward Operator.
+ standardization : bool
+ If True standardization is used. Default: True.
+ num_filters : int
+ Number of filters for the MultiDomainUnet module. Default: 16.
+ num_pool_layers : int
+ Number of pooling layers for the MultiDomainUnet module. Default: 4.
+ dropout_probability : float
+ Dropout probability for the MultiDomainUnet module. Default: 0.0.
+ """
+ super().__init__()
+ self.forward_operator = forward_operator
+ self.backward_operator = backward_operator
+ self._coil_dim = 1
+ self._complex_dim = -1
+ self._spatial_dims = (2, 3)
+
+ if standardization:
+ self.standardization = StandardizationLayer(self._coil_dim, self._complex_dim)
+
+ self.unet = MultiDomainUnet2d(
+ forward_operator,
+ backward_operator,
+ in_channels=4 if standardization else 2, # if standardization, in_channels is 4 due to standardized input
+ out_channels=2,
+ num_filters=num_filters,
+ num_pool_layers=num_pool_layers,
+ dropout_probability=dropout_probability,
+ )
+
+ def _compute_model_per_coil(self, model, data):
+ """
+ Computes model per coil.
+ """
+ output = []
+ for idx in range(data.size(self._coil_dim)):
+ subselected_data = data.select(self._coil_dim, idx)
+ output.append(model(subselected_data))
+ output = torch.stack(output, dim=self._coil_dim)
+ return output
+
+ def forward(self, masked_kspace: torch.Tensor, sensitivity_map: torch.Tensor) -> torch.Tensor:
+ """
+
+ Parameters
+ ----------
+ masked_kspace : torch.Tensor
+ Masked k-space of shape (N, coil, height, width, complex=2).
+ sensitivity_map : torch.Tensor
+ Sensitivity map of shape (N, coil, height, width, complex=2).
+
+ Returns
+ -------
+ output_image : torch.Tensor
+ Multi-coil output image of shape (N, coil, height, width, complex=2).
+ """
+ input_image = self.backward_operator(masked_kspace, dim=self._spatial_dims)
+ if hasattr(self, "standardization"):
+ input_image = self.standardization(input_image, sensitivity_map)
+ output_image = self._compute_model_per_coil(self.unet, input_image.permute(0, 1, 4, 2, 3)).permute(
+ 0, 1, 3, 4, 2
+ )
+ return output_image
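A minimal sketch of the standardization step's shape bookkeeping, with random tensors standing in for coil images and sensitivity maps: every coil image is replaced by the pair `[x_sense, x_i - S_i * x_sense]`, which doubles the trailing dimension from 2 to 4 and explains the `in_channels=4` choice above.

```python
import torch

from direct.nn.multidomainnet.multidomainnet import StandardizationLayer

coil_images = torch.rand(1, 8, 32, 32, 2)      # (N, coil, height, width, complex)
sensitivity_map = torch.rand(1, 8, 32, 32, 2)  # (N, coil, height, width, complex)

standardized = StandardizationLayer(coil_dim=1, channel_dim=-1)(coil_images, sensitivity_map)

# Each coil now carries [x_sense, x_i - S_i * x_sense], hence the last dimension is 4,
# which is why MultiDomainNet sets in_channels=4 when standardization is enabled.
print(standardized.shape)  # torch.Size([1, 8, 32, 32, 4])
```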
diff --git a/direct/nn/multidomainnet/multidomainnet_engine.py b/direct/nn/multidomainnet/multidomainnet_engine.py
new file mode 100644
index 00000000..0f4c8b41
--- /dev/null
+++ b/direct/nn/multidomainnet/multidomainnet_engine.py
@@ -0,0 +1,470 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
+
+import time
+from collections import defaultdict
+from os import PathLike
+from typing import Callable, DefaultDict, Dict, List, Optional
+
+import numpy as np
+import torch
+from torch import nn
+from torch.cuda.amp import autocast
+from torch.nn import functional as F
+from torch.utils.data import DataLoader
+
+import direct.data.transforms as T
+from direct.config import BaseConfig
+from direct.engine import DoIterationOutput, Engine
+from direct.functionals import SSIMLoss
+from direct.utils import (
+ communication,
+ detach_dict,
+ dict_to_device,
+ merge_list_of_dicts,
+ multiply_function,
+ reduce_list_of_dicts,
+)
+from direct.utils.communication import reduce_tensor_dict
+
+
+class MultiDomainNetEngine(Engine):
+ """
+ Multi Domain Network Engine.
+ """
+
+ def __init__(
+ self,
+ cfg: BaseConfig,
+ model: nn.Module,
+ device: int,
+ forward_operator: Optional[Callable] = None,
+ backward_operator: Optional[Callable] = None,
+ mixed_precision: bool = False,
+ **models: nn.Module,
+ ):
+ super().__init__(
+ cfg,
+ model,
+ device,
+ forward_operator=forward_operator,
+ backward_operator=backward_operator,
+ mixed_precision=mixed_precision,
+ **models,
+ )
+
+ self._complex_dim = -1
+ self._coil_dim = 1
+ self._spatial_dims = (2, 3)
+
+ def _do_iteration(
+ self,
+ data: Dict[str, torch.Tensor],
+ loss_fns: Optional[Dict[str, Callable]] = None,
+ regularizer_fns: Optional[Dict[str, Callable]] = None,
+ ) -> DoIterationOutput:
+
+        # loss_fns can be None, e.g. during validation.
+ if loss_fns is None:
+ loss_fns = {}
+
+ if regularizer_fns is None:
+ regularizer_fns = {}
+
+ loss_dicts = []
+ regularizer_dicts = []
+
+ data = dict_to_device(data, self.device)
+
+ # sensitivity_map of shape (batch, coil, height, width, complex=2)
+ sensitivity_map = data["sensitivity_map"]
+
+ if "sensitivity_model" in self.models:
+
+ # Move channels to first axis
+ sensitivity_map = data["sensitivity_map"].permute(
+ (0, 1, 4, 2, 3)
+ ) # shape (batch, coil, complex=2, height, width)
+
+ sensitivity_map = self.compute_model_per_coil("sensitivity_model", sensitivity_map).permute(
+ (0, 1, 3, 4, 2)
+ ) # has channel last: shape (batch, coil, height, width, complex=2)
+
+            # The sensitivity map needs to be normalized such that
+            # \sum_{i \in \text{coils}} S_i S_i^* = 1.
+
+ sensitivity_map_norm = torch.sqrt(
+ ((sensitivity_map ** 2).sum(self._complex_dim)).sum(self._coil_dim)
+ ) # shape (batch, height, width)
+ sensitivity_map_norm = sensitivity_map_norm.unsqueeze(1).unsqueeze(-1)
+ data["sensitivity_map"] = T.safe_divide(sensitivity_map, sensitivity_map_norm)
+
+ with autocast(enabled=self.mixed_precision):
+
+ output_multicoil_image = self.model(
+ masked_kspace=data["masked_kspace"],
+ sensitivity_map=data["sensitivity_map"],
+ )
+
+ output_image = T.root_sum_of_squares(
+ output_multicoil_image, self._coil_dim, self._complex_dim
+ ) # shape (batch, height, width)
+
+ loss_dict = {k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in loss_fns.keys()}
+ regularizer_dict = {
+ k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in regularizer_fns.keys()
+ }
+
+ for key, value in loss_dict.items():
+ loss_dict[key] = value + loss_fns[key](
+ output_image,
+ **data,
+ reduction="mean",
+ )
+
+ for key, value in regularizer_dict.items():
+ regularizer_dict[key] = value + regularizer_fns[key](
+ output_image,
+ **data,
+ )
+
+ loss = sum(loss_dict.values()) + sum(regularizer_dict.values())
+
+ if self.model.training:
+ self._scaler.scale(loss).backward()
+
+ loss_dicts.append(detach_dict(loss_dict))
+ regularizer_dicts.append(
+ detach_dict(regularizer_dict)
+ ) # Need to detach dict as this is only used for logging.
+
+ # Add the loss dicts.
+ loss_dict = reduce_list_of_dicts(loss_dicts, mode="sum")
+ regularizer_dict = reduce_list_of_dicts(regularizer_dicts, mode="sum")
+
+ return DoIterationOutput(
+ output_image=output_image,
+ sensitivity_map=data["sensitivity_map"],
+ data_dict={**loss_dict, **regularizer_dict},
+ )
+
+ def build_loss(self, **kwargs) -> Dict:
+ # TODO: Cropper is a processing output tool.
+ def get_resolution(**data):
+ """Be careful that this will use the cropping size of the FIRST sample in the batch."""
+ return self.compute_resolution(self.cfg.training.loss.crop, data.get("reconstruction_size", None))
+
+ def l1_loss(source, reduction="mean", **data):
+ """
+ Calculate L1 loss given source and target.
+
+            Parameters
+            ----------
+            source : torch.Tensor
+                Has shape (batch, complex=2, height, width).
+            data : dict
+                Contains key "target" with value a tensor of shape (batch, height, width).
+
+ """
+ resolution = get_resolution(**data)
+ l1_loss = F.l1_loss(*self.cropper(source, data["target"], resolution), reduction=reduction)
+
+ return l1_loss
+
+ def l2_loss(source, reduction="mean", **data):
+ """
+ Calculate L2 loss (MSE) given source and target.
+
+            Parameters
+            ----------
+            source : torch.Tensor
+                Has shape (batch, complex=2, height, width).
+            data : dict
+                Contains key "target" with value a tensor of shape (batch, height, width).
+
+ """
+ resolution = get_resolution(**data)
+ l2_loss = F.mse_loss(*self.cropper(source, data["target"], resolution), reduction=reduction)
+
+ return l2_loss
+
+ def ssim_loss(source, reduction="mean", **data):
+ """
+ Calculate SSIM loss given source and target.
+
+            Parameters
+            ----------
+            source : torch.Tensor
+                Has shape (batch, complex=2, height, width).
+            data : dict
+                Contains key "target" with value a tensor of shape (batch, height, width).
+
+ """
+ resolution = get_resolution(**data)
+ if reduction != "mean":
+ raise AssertionError(
+ f"SSIM loss can only be computed with reduction == 'mean'." f" Got reduction == {reduction}."
+ )
+
+ source_abs, target_abs = self.cropper(source, data["target"], resolution)
+ data_range = torch.tensor([target_abs.max()], device=target_abs.device)
+
+ ssim_loss = SSIMLoss().to(source_abs.device).forward(source_abs, target_abs, data_range=data_range)
+
+ return ssim_loss
+
+ # Build losses
+ loss_dict = {}
+ for curr_loss in self.cfg.training.loss.losses: # type: ignore
+ loss_fn = curr_loss.function
+ if loss_fn == "l1_loss":
+ loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, l1_loss)
+ elif loss_fn == "l2_loss":
+ loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, l2_loss)
+ elif loss_fn == "ssim_loss":
+ loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, ssim_loss)
+ else:
+ raise ValueError(f"{loss_fn} not permissible.")
+
+ return loss_dict
+
+ @torch.no_grad()
+ def evaluate(
+ self,
+ data_loader: DataLoader,
+ loss_fns: Optional[Dict[str, Callable]],
+ regularizer_fns: Optional[Dict[str, Callable]] = None,
+ crop: Optional[str] = None,
+ is_validation_process: bool = True,
+ ):
+ """
+ Validation process. Assumes that each batch only contains slices of the same volume *AND* that these
+ are sequentially ordered.
+
+ Parameters
+ ----------
+ data_loader : DataLoader
+ loss_fns : Dict[str, Callable], optional
+ regularizer_fns : Dict[str, Callable], optional
+ crop : str, optional
+ is_validation_process : bool
+
+ Returns
+ -------
+ loss_dict, all_gathered_metrics, visualize_slices, visualize_target
+
+ """
+ self.models_to_device()
+ self.models_validation_mode()
+ torch.cuda.empty_cache()
+
+ # Variables required for evaluation.
+ volume_metrics = self.build_metrics(self.cfg.validation.metrics) # type: ignore
+
+ # filenames can be in the volume_indices attribute of the dataset
+ num_for_this_process = None
+ all_filenames = None
+ if hasattr(data_loader.dataset, "volume_indices"):
+ all_filenames = list(data_loader.dataset.volume_indices.keys())
+ num_for_this_process = len(list(data_loader.batch_sampler.sampler.volume_indices.keys()))
+ self.logger.info(
+ f"Reconstructing a total of {len(all_filenames)} volumes. "
+ f"This process has {num_for_this_process} volumes (world size: {communication.get_world_size()})."
+ )
+
+ filenames_seen = 0
+ reconstruction_output: DefaultDict = defaultdict(list)
+ if is_validation_process:
+ targets_output: DefaultDict = defaultdict(list)
+ val_losses = []
+ val_volume_metrics: Dict[PathLike, Dict] = defaultdict(dict)
+ last_filename = None
+
+        # Container for the slices which can be visualized in TensorBoard.
+ visualize_slices: List[np.ndarray] = []
+ visualize_target: List[np.ndarray] = []
+
+ extra_visualization_keys = (
+ self.cfg.logging.log_as_image if self.cfg.logging.log_as_image else [] # type: ignore
+ )
+
+ # Loop over dataset. This requires the use of direct.data.sampler.DistributedSequentialSampler as this sampler
+ # splits the data over the different processes, and outputs the slices linearly. The implicit assumption here is
+ # that the slices are outputted from the Dataset *sequentially* for each volume one by one, and each batch only
+ # contains data from one volume.
+ time_start = time.time()
+
+ for iter_idx, data in enumerate(data_loader):
+ filenames = data.pop("filename")
+ if len(set(filenames)) != 1:
+ raise ValueError(
+ f"Expected a batch during validation to only contain filenames of one case. "
+ f"Got {set(filenames)}."
+ )
+
+ slice_nos = data.pop("slice_no")
+ scaling_factors = data["scaling_factor"]
+
+ resolution = self.compute_resolution(
+ key=self.cfg.validation.crop, # type: ignore
+ reconstruction_size=data.get("reconstruction_size", None),
+ )
+
+ # Compute output and loss.
+ iteration_output = self._do_iteration(data, loss_fns, regularizer_fns=regularizer_fns)
+ output = iteration_output.output_image
+ loss_dict = iteration_output.data_dict
+
+ loss_dict = detach_dict(loss_dict)
+ output = output.detach()
+ val_losses.append(loss_dict)
+
+ # Output is complex-valued, and has to be cropped. This holds for both output and target.
+ # Output has shape (batch, complex, height, width)
+ output_abs = self.process_output(
+ output,
+ scaling_factors,
+ resolution=resolution,
+ )
+
+ if is_validation_process:
+ # Target has shape (batch, height, width)
+ target_abs = self.process_output(
+ data["target"].detach(),
+ scaling_factors,
+ resolution=resolution,
+ )
+ for key in extra_visualization_keys:
+ curr_data = data[key].detach()
+                    # Here we need to discover which keys are actually normalized or not;
+ # this requires a solution to issue #23: https://github.com/directgroup/direct/issues/23
+
+ del output # Explicitly call delete to clear memory.
+
+ # Aggregate volumes to be able to compute the metrics on complete volumes.
+ for idx, filename in enumerate(filenames):
+ if last_filename is None:
+ last_filename = filename # First iteration last_filename is not set.
+
+ curr_slice = output_abs[idx].detach()
+ slice_no = int(slice_nos[idx].numpy())
+
+ reconstruction_output[filename].append((slice_no, curr_slice.cpu()))
+
+ if is_validation_process:
+ targets_output[filename].append((slice_no, target_abs[idx].cpu()))
+
+ is_last_element_of_last_batch = iter_idx + 1 == len(data_loader) and idx + 1 == len(data["target"])
+ reconstruction_conditions = [filename != last_filename, is_last_element_of_last_batch]
+ for condition in reconstruction_conditions:
+ if condition:
+ filenames_seen += 1
+
+ # Now we can ditch the reconstruction dict by reconstructing the volume,
+ # will take too much memory otherwise.
+ volume = torch.stack([_[1] for _ in reconstruction_output[last_filename]])
+ if is_validation_process:
+ target = torch.stack([_[1] for _ in targets_output[last_filename]])
+ curr_metrics = {
+ metric_name: metric_fn(target, volume)
+ for metric_name, metric_fn in volume_metrics.items()
+ }
+ val_volume_metrics[last_filename] = curr_metrics
+ # Log the center slice of the volume
+ if len(visualize_slices) < self.cfg.logging.tensorboard.num_images: # type: ignore
+ visualize_slices.append(volume[volume.shape[0] // 2])
+ visualize_target.append(target[target.shape[0] // 2])
+
+ # Delete outputs from memory, and recreate dictionary.
+ # This is not needed when not in validation as we are actually interested
+ # in the iteration output.
+ del targets_output[last_filename]
+ del reconstruction_output[last_filename]
+
+ if all_filenames:
+ log_prefix = f"{filenames_seen} of {num_for_this_process} volumes reconstructed:"
+ else:
+ log_prefix = f"{iter_idx + 1} of {len(data_loader)} slices reconstructed:"
+
+ self.logger.info(
+ f"{log_prefix} {last_filename}"
+ f" (shape = {list(volume.shape)}) in {time.time() - time_start:.3f}s."
+ )
+ # restart timer
+ time_start = time.time()
+ last_filename = filename
+
+ # Average loss dict
+ loss_dict = reduce_list_of_dicts(val_losses)
+ reduce_tensor_dict(loss_dict)
+
+ communication.synchronize()
+ torch.cuda.empty_cache()
+
+ all_gathered_metrics = merge_list_of_dicts(communication.all_gather(val_volume_metrics))
+ if not is_validation_process:
+ return loss_dict, reconstruction_output
+
+ return loss_dict, all_gathered_metrics, visualize_slices, visualize_target
+
+ def process_output(self, data, scaling_factors=None, resolution=None):
+ # data is of shape (batch, complex=2, height, width)
+ if scaling_factors is not None:
+ data = data * scaling_factors.view(-1, *((1,) * (len(data.shape) - 1))).to(data.device)
+
+ data = T.modulus_if_complex(data)
+
+ if len(data.shape) == 3: # (batch, height, width)
+ data = data.unsqueeze(1) # Added channel dimension.
+
+ if resolution is not None:
+ data = T.center_crop(data, resolution).contiguous()
+
+ return data
+
+ @staticmethod
+ def compute_resolution(key, reconstruction_size):
+ if key == "header":
+ # This will be of the form [tensor(x_0, x_1, ...), tensor(y_0, y_1,...), tensor(z_0, z_1, ...)] over
+ # batches.
+ resolution = [_.detach().cpu().numpy().tolist() for _ in reconstruction_size]
+ # The volume sampler should give validation indices belonging to the *same* volume, so it should be
+            # safe taking the first element; the matrix sizes are in x,y,z (we work in z,x,y).
+ resolution = [_[0] for _ in resolution][:-1]
+ elif key == "training":
+ resolution = key
+ elif not key:
+ resolution = None
+ else:
+ raise ValueError(
+ "Cropping should be either set to `header` to get the values from the header or "
+ "`training` to take the same value as training."
+ )
+ return resolution
+
+ def cropper(self, source, target, resolution):
+ """
+        2D source/target cropper.
+
+        Parameters
+        ----------
+        source : torch.Tensor
+            Has shape (batch, height, width).
+        target : torch.Tensor
+            Has shape (batch, height, width).
+        resolution : tuple
+            Target resolution. If empty or all zeros, no cropping is applied.
+
+ """
+
+ if not resolution or all(_ == 0 for _ in resolution):
+ return source.unsqueeze(1), target.unsqueeze(1) # Added channel dimension.
+
+ source_abs = T.center_crop(source, resolution).unsqueeze(1) # Added channel dimension.
+ target_abs = T.center_crop(target, resolution).unsqueeze(1) # Added channel dimension.
+
+ return source_abs, target_abs
+
+ def compute_model_per_coil(self, model_name, data):
+ """
+ Computes model per coil.
+ """
+ output = []
+
+ for idx in range(data.size(self._coil_dim)):
+ subselected_data = data.select(self._coil_dim, idx)
+ output.append(self.models[model_name](subselected_data))
+ output = torch.stack(output, dim=self._coil_dim)
+
+ return output
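A quick, self-contained check of the sensitivity-map normalization performed in `_do_iteration`, with random tensors standing in for the sensitivity model output: after dividing by the root-sum-of-squares over coils, the coil-combined energy \sum_i S_i S_i^* equals one at every pixel.

```python
import torch

sensitivity_map = torch.rand(1, 8, 32, 32, 2)  # (N, coil, height, width, complex)

# Same computation as in _do_iteration: root-sum-of-squares over the complex and coil axes.
norm = torch.sqrt((sensitivity_map ** 2).sum(-1).sum(1))        # (N, height, width)
normalized = sensitivity_map / norm.unsqueeze(1).unsqueeze(-1)  # broadcast back over coils

# After normalization, the coil-combined energy is one everywhere.
print(torch.allclose((normalized ** 2).sum(-1).sum(1), torch.ones(1, 32, 32)))  # True
```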
diff --git a/direct/nn/multidomainnet/tests/__init__.py b/direct/nn/multidomainnet/tests/__init__.py
new file mode 100644
index 00000000..941752c9
--- /dev/null
+++ b/direct/nn/multidomainnet/tests/__init__.py
@@ -0,0 +1,2 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
diff --git a/direct/nn/multidomainnet/tests/test_multidomainnet.py b/direct/nn/multidomainnet/tests/test_multidomainnet.py
new file mode 100644
index 00000000..fecbdea4
--- /dev/null
+++ b/direct/nn/multidomainnet/tests/test_multidomainnet.py
@@ -0,0 +1,83 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
+
+import pytest
+import torch
+
+from direct.data.transforms import fft2, ifft2
+from direct.nn.multidomainnet.multidomain import MultiDomainUnet2d
+from direct.nn.multidomainnet.multidomainnet import MultiDomainNet
+
+
+def create_input(shape):
+
+ data = torch.rand(shape).float()
+
+ return data
+
+
+@pytest.mark.parametrize(
+ "shape",
+ [
+ [2, 2, 16, 16],
+ [4, 2, 16, 32],
+ [3, 2, 32, 32],
+ [3, 2, 40, 20],
+ ],
+)
+@pytest.mark.parametrize(
+ "num_filters",
+ [4, 8, 16], # powers of 2
+)
+@pytest.mark.parametrize(
+ "num_pool_layers",
+ [2, 3],
+)
+def test_multidomainunet2d(shape, num_filters, num_pool_layers):
+ model = MultiDomainUnet2d(
+ fft2,
+ ifft2,
+ shape[1],
+ shape[1],
+ num_filters=num_filters,
+ num_pool_layers=num_pool_layers,
+ dropout_probability=0.05,
+ ).cpu()
+
+ data = create_input(shape).cpu()
+
+ out = model(data)
+
+ assert list(out.shape) == shape
+
+
+@pytest.mark.parametrize(
+ "shape",
+ [
+ [2, 2, 16, 16],
+ [4, 2, 16, 32],
+ [3, 2, 32, 32],
+ [3, 2, 40, 20],
+ ],
+)
+@pytest.mark.parametrize("standardization", [True, False])
+@pytest.mark.parametrize(
+ "num_filters",
+ [4, 8], # powers of 2
+)
+@pytest.mark.parametrize(
+ "num_pool_layers",
+ [2, 3],
+)
+def test_multidomainnet(shape, standardization, num_filters, num_pool_layers):
+
+ model = MultiDomainNet(fft2, ifft2, standardization, num_filters, num_pool_layers)
+
+ shape = shape + [2]
+
+ kspace = create_input(shape).cpu()
+ sens = create_input(shape).cpu()
+
+ out = model(kspace, sens)
+
+ assert list(out.shape) == shape
diff --git a/direct/nn/mwcnn/__init__.py b/direct/nn/mwcnn/__init__.py
new file mode 100644
index 00000000..941752c9
--- /dev/null
+++ b/direct/nn/mwcnn/__init__.py
@@ -0,0 +1,2 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
diff --git a/direct/nn/mwcnn/mwcnn.py b/direct/nn/mwcnn/mwcnn.py
new file mode 100644
index 00000000..d6a27c73
--- /dev/null
+++ b/direct/nn/mwcnn/mwcnn.py
@@ -0,0 +1,328 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
+
+from collections import OrderedDict
+from typing import Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class DWT(nn.Module):
+ """
+ 2D Discrete Wavelet Transform as implemented in https://arxiv.org/abs/1805.07071.
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.requires_grad = False
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x01 = x[:, :, 0::2, :] / 2
+ x02 = x[:, :, 1::2, :] / 2
+ x1 = x01[:, :, :, 0::2]
+ x2 = x02[:, :, :, 0::2]
+ x3 = x01[:, :, :, 1::2]
+ x4 = x02[:, :, :, 1::2]
+ x_LL = x1 + x2 + x3 + x4
+ x_HL = -x1 - x2 + x3 + x4
+ x_LH = -x1 + x2 - x3 + x4
+ x_HH = x1 - x2 - x3 + x4
+
+ return torch.cat((x_LL, x_HL, x_LH, x_HH), 1)
+
+
+class IWT(nn.Module):
+ """
+ 2D Inverse Wavelet Transform as implemented in https://arxiv.org/abs/1805.07071.
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.requires_grad = False
+ self._r = 2
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+
+ batch, in_channel, in_height, in_width = x.size()
+        out_channel, out_height, out_width = in_channel // (self._r ** 2), self._r * in_height, self._r * in_width
+
+ x1 = x[:, 0:out_channel, :, :] / 2
+ x2 = x[:, out_channel : out_channel * 2, :, :] / 2
+ x3 = x[:, out_channel * 2 : out_channel * 3, :, :] / 2
+ x4 = x[:, out_channel * 3 : out_channel * 4, :, :] / 2
+
+ h = torch.zeros([batch, out_channel, out_height, out_width], dtype=x.dtype).to(x.device)
+
+ h[:, :, 0::2, 0::2] = x1 - x2 - x3 + x4
+ h[:, :, 1::2, 0::2] = x1 - x2 + x3 - x4
+ h[:, :, 0::2, 1::2] = x1 + x2 - x3 - x4
+ h[:, :, 1::2, 1::2] = x1 + x2 + x3 + x4
+
+ return h
+
+
+class ConvBlock(nn.Module):
+ """
+ Convolution Block for MWCNN as implemented in https://arxiv.org/abs/1805.07071.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ bias: bool = True,
+ batchnorm: bool = False,
+ activation: nn.Module = nn.ReLU(True),
+ scale: Optional[float] = 1.0,
+ ):
+ super().__init__()
+
+ net = []
+ net.append(
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ bias=bias,
+ padding=kernel_size // 2,
+ )
+ )
+ if batchnorm:
+ net.append(nn.BatchNorm2d(num_features=out_channels, eps=1e-4, momentum=0.95))
+ net.append(activation)
+
+ self.net = nn.Sequential(*net)
+ self.scale = scale
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ output = self.net(input) * self.scale
+ return output
+
+
+class DilatedConvBlock(nn.Module):
+ """
+    Double dilated convolution block for MWCNN as implemented in https://arxiv.org/abs/1805.07071.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ dilations: Tuple[int, int],
+ kernel_size: int,
+ out_channels: Optional[int] = None,
+ bias: bool = True,
+ batchnorm: bool = False,
+ activation: nn.Module = nn.ReLU(True),
+ scale: Optional[float] = 1.0,
+ ):
+ super().__init__()
+ net = []
+ net.append(
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ kernel_size=kernel_size,
+ bias=bias,
+ dilation=dilations[0],
+ padding=kernel_size // 2 + dilations[0] - 1,
+ )
+ )
+ if batchnorm:
+ net.append(nn.BatchNorm2d(num_features=in_channels, eps=1e-4, momentum=0.95))
+ net.append(activation)
+ if out_channels is None:
+ out_channels = in_channels
+ net.append(
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ bias=bias,
+ dilation=dilations[1],
+ padding=kernel_size // 2 + dilations[1] - 1,
+ )
+ )
+ if batchnorm:
+            net.append(nn.BatchNorm2d(num_features=out_channels, eps=1e-4, momentum=0.95))
+ net.append(activation)
+
+ self.net = nn.Sequential(*net)
+ self.scale = scale
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ output = self.net(input) * self.scale
+ return output
+
+
+class MWCNN(nn.Module):
+ """
+ Multi-level Wavelet CNN (MWCNN) implementation as implemented in https://arxiv.org/abs/1805.07071.
+ """
+
+ def __init__(
+ self,
+ input_channels: int,
+ first_conv_hidden_channels: int,
+ num_scales: int = 4,
+ bias: bool = True,
+ batchnorm: bool = False,
+ activation: nn.Module = nn.ReLU(True),
+ ):
+ """
+
+ Parameters
+ ----------
+ input_channels : int
+ Input channels dimension.
+ first_conv_hidden_channels : int
+ First convolution output channels dimension.
+ num_scales : int
+ Number of scales. Default: 4.
+ bias : bool
+ Convolution bias. If True, adds a learnable bias to the output. Default: True.
+ batchnorm : bool
+ If True, a batchnorm layer is added after each convolution. Default: False.
+ activation : nn.Module
+            Activation function applied after each convolution. Default: nn.ReLU(True).
+ """
+ super().__init__()
+ self._kernel_size = 3
+ self.DWT = DWT()
+ self.IWT = IWT()
+
+ self.down = nn.ModuleList()
+ for idx in range(0, num_scales):
+
+ in_channels = input_channels if idx == 0 else first_conv_hidden_channels * 2 ** (idx + 1)
+ out_channels = first_conv_hidden_channels * 2 ** idx
+ dilations = (2, 1) if idx != num_scales - 1 else (2, 3)
+ self.down.append(
+ nn.Sequential(
+ OrderedDict(
+ [
+ (
+ f"convblock{idx}",
+ ConvBlock(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=self._kernel_size,
+ bias=bias,
+ batchnorm=batchnorm,
+ activation=activation,
+ ),
+ ),
+ (
+ f"dilconvblock{idx}",
+ DilatedConvBlock(
+ in_channels=out_channels,
+ dilations=dilations,
+ kernel_size=self._kernel_size,
+ bias=bias,
+ batchnorm=batchnorm,
+ activation=activation,
+ ),
+ ),
+ ]
+ )
+ )
+ )
+ self.up = nn.ModuleList()
+ for idx in range(num_scales)[::-1]:
+
+ in_channels = first_conv_hidden_channels * 2 ** idx
+ out_channels = input_channels if idx == 0 else first_conv_hidden_channels * 2 ** (idx + 1)
+ dilations = (2, 1) if idx != num_scales - 1 else (3, 2)
+ self.up.append(
+ nn.Sequential(
+ OrderedDict(
+ [
+ (
+ f"invdilconvblock{num_scales - 2 - idx}",
+ DilatedConvBlock(
+ in_channels=in_channels,
+ dilations=dilations,
+ kernel_size=self._kernel_size,
+ bias=bias,
+ batchnorm=batchnorm,
+ activation=activation,
+ ),
+ ),
+ (
+ f"invconvblock{num_scales - 2 - idx}",
+ ConvBlock(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=self._kernel_size,
+ bias=bias,
+ batchnorm=batchnorm,
+ activation=activation,
+ ),
+ ),
+ ]
+ )
+ )
+ )
+ self.num_scales = num_scales
+
+ @staticmethod
+ def pad(x):
+ padding = [0, 0, 0, 0]
+
+        if x.shape[-2] % 2 != 0:
+            padding[3] = 1  # Padding bottom - height
+        if x.shape[-1] % 2 != 0:
+            padding[1] = 1  # Padding right - width
+ if sum(padding) != 0:
+ x = F.pad(x, padding, "reflect")
+ return x
+
+ @staticmethod
+ def crop_to_shape(x, shape):
+ h, w = x.shape[-2:]
+
+ if h > shape[0]:
+ x = x[:, :, : shape[0], :]
+ if w > shape[1]:
+ x = x[:, :, :, : shape[1]]
+ return x
+
+ def forward(self, input: torch.Tensor, res: bool = False) -> torch.Tensor:
+ """
+
+ Parameters
+ ----------
+ input : torch.Tensor
+ Input tensor.
+ res : bool
+ If True, residual connection is applied to the output. Default: False.
+
+ Returns
+ -------
+ torch.Tensor
+ """
+ res_values = []
+ x = self.pad(input.clone())
+ for idx in range(self.num_scales):
+ if idx == 0:
+ x = self.pad(self.down[idx](x))
+ res_values.append(x)
+ elif idx == self.num_scales - 1:
+ x = self.down[idx](self.DWT(x))
+ else:
+ x = self.pad(self.down[idx](self.DWT(x)))
+ res_values.append(x)
+
+ for idx in range(self.num_scales):
+ if idx != self.num_scales - 1:
+ x = (
+ self.crop_to_shape(self.IWT(self.up[idx](x)), res_values[self.num_scales - 2 - idx].shape[-2:])
+ + res_values[self.num_scales - 2 - idx]
+ )
+ else:
+ x = self.crop_to_shape(self.up[idx](x), input.shape[-2:])
+ if res:
+ x += input
+ return x
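The `DWT`/`IWT` pair implements the 2x2 Haar decomposition used throughout MWCNN: `DWT` stacks the four sub-bands along the channel axis while halving each spatial size, and `IWT` inverts it exactly. A small sketch with random data (even spatial sizes are assumed, which `MWCNN.pad` guarantees internally):

```python
import torch

from direct.nn.mwcnn.mwcnn import DWT, IWT

x = torch.rand(1, 2, 32, 32)

subbands = DWT()(x)              # torch.Size([1, 8, 16, 16]): LL, HL, LH, HH along the channel axis
reconstructed = IWT()(subbands)  # torch.Size([1, 2, 32, 32])

print(torch.allclose(reconstructed, x, atol=1e-6))  # True: the transform pair is lossless
```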
diff --git a/direct/nn/mwcnn/tests/__init__.py b/direct/nn/mwcnn/tests/__init__.py
new file mode 100644
index 00000000..941752c9
--- /dev/null
+++ b/direct/nn/mwcnn/tests/__init__.py
@@ -0,0 +1,2 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
diff --git a/direct/nn/mwcnn/tests/test_mwcnn.py b/direct/nn/mwcnn/tests/test_mwcnn.py
new file mode 100644
index 00000000..dd05f7ec
--- /dev/null
+++ b/direct/nn/mwcnn/tests/test_mwcnn.py
@@ -0,0 +1,52 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
+
+import pytest
+import torch
+import torch.nn as nn
+
+from direct.nn.mwcnn.mwcnn import MWCNN
+
+
+def create_input(shape):
+
+ data = torch.rand(shape).float()
+
+ return data
+
+
+@pytest.mark.parametrize(
+ "shape",
+ [
+ [3, 2, 32, 32],
+ [3, 2, 20, 34],
+ ],
+)
+@pytest.mark.parametrize(
+ "first_conv_hidden_channels",
+ [4, 8],
+)
+@pytest.mark.parametrize(
+ "n_scales",
+ [2, 3],
+)
+@pytest.mark.parametrize(
+ "bias",
+ [True, False],
+)
+@pytest.mark.parametrize(
+ "batchnorm",
+ [True, False],
+)
+@pytest.mark.parametrize(
+ "act",
+ [nn.ReLU(), nn.PReLU()],
+)
+def test_mwcnn(shape, first_conv_hidden_channels, n_scales, bias, batchnorm, act):
+ model = MWCNN(shape[1], first_conv_hidden_channels, n_scales, bias, batchnorm, act)
+
+ data = create_input(shape).cpu()
+
+ out = model(data)
+
+ assert list(out.shape) == shape
diff --git a/direct/nn/recurrent/__init__.py b/direct/nn/recurrent/__init__.py
new file mode 100644
index 00000000..941752c9
--- /dev/null
+++ b/direct/nn/recurrent/__init__.py
@@ -0,0 +1,2 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
diff --git a/direct/nn/recurrent/recurrent.py b/direct/nn/recurrent/recurrent.py
new file mode 100644
index 00000000..05234934
--- /dev/null
+++ b/direct/nn/recurrent/recurrent.py
@@ -0,0 +1,163 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
+
+from typing import List, Optional, Tuple
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class Conv2dGRU(nn.Module):
+ """
+ 2D Convolutional GRU Network.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ hidden_channels: int,
+ out_channels: Optional[int] = None,
+ num_layers: int = 2,
+        gru_kernel_size: int = 1,
+ orthogonal_initialization: bool = True,
+ instance_norm: bool = False,
+ dense_connect: int = 0,
+ replication_padding: bool = True,
+ ):
+ """
+
+ Parameters
+ ----------
+ in_channels : int
+ Number of input channels.
+ hidden_channels : int
+ Number of hidden channels.
+ out_channels : Optional[int]
+ Number of output channels. If None, same as in_channels. Default: None.
+ num_layers : int
+ Number of layers. Default: 2.
+ gru_kernel_size : int
+ Size of the GRU kernel. Default: 1.
+ orthogonal_initialization : bool
+ Orthogonal initialization is used if set to True. Default: True.
+ instance_norm : bool
+ Instance norm is used if set to True. Default: False.
+ dense_connect : int
+            Number of dense connections. Default: 0.
+        replication_padding : bool
+            If set to True, replication padding is applied. Default: True.
+ """
+ super().__init__()
+
+ if out_channels is None:
+ out_channels = in_channels
+
+ self.num_layers = num_layers
+ self.hidden_channels = hidden_channels
+ self.dense_connect = dense_connect
+
+ self.reset_gates = nn.ModuleList([])
+ self.update_gates = nn.ModuleList([])
+ self.out_gates = nn.ModuleList([])
+ self.conv_blocks = nn.ModuleList([])
+
+ # Create convolutional blocks
+ for idx in range(num_layers + 1):
+ in_ch = in_channels if idx == 0 else (1 + min(idx, dense_connect)) * hidden_channels
+ out_ch = hidden_channels if idx < num_layers else out_channels
+ padding = 0 if replication_padding else (2 if idx == 0 else 1)
+ block = []
+ if replication_padding:
+ if idx == 1:
+ block.append(nn.ReplicationPad2d(2))
+ else:
+ block.append(nn.ReplicationPad2d(2 if idx == 0 else 1))
+ block.append(
+ nn.Conv2d(
+ in_channels=in_ch,
+ out_channels=out_ch,
+ kernel_size=5 if idx == 0 else 3,
+ dilation=(2 if idx == 1 else 1),
+ padding=padding,
+ )
+ )
+ self.conv_blocks.append(nn.Sequential(*block))
+
+ # Create GRU blocks
+ for idx in range(num_layers):
+ for gru_part in [self.reset_gates, self.update_gates, self.out_gates]:
+ block = []
+ if instance_norm:
+ block.append(nn.InstanceNorm2d(2 * hidden_channels))
+ block.append(
+ nn.Conv2d(
+ in_channels=2 * hidden_channels,
+ out_channels=hidden_channels,
+ kernel_size=gru_kernel_size,
+ padding=gru_kernel_size // 2,
+ )
+ )
+ gru_part.append(nn.Sequential(*block))
+
+ if orthogonal_initialization:
+ for reset_gate, update_gate, out_gate in zip(self.reset_gates, self.update_gates, self.out_gates):
+ nn.init.orthogonal_(reset_gate[-1].weight)
+ nn.init.orthogonal_(update_gate[-1].weight)
+ nn.init.orthogonal_(out_gate[-1].weight)
+ nn.init.constant_(reset_gate[-1].bias, -1.0)
+ nn.init.constant_(update_gate[-1].bias, 0.0)
+ nn.init.constant_(out_gate[-1].bias, 0.0)
+
+ def forward(
+ self,
+ cell_input: torch.Tensor,
+        previous_state: Optional[torch.Tensor],
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+
+ Parameters
+ ----------
+ cell_input : torch.Tensor
+ Reconstruction input
+        previous_state : Optional[torch.Tensor]
+            Tensor of previous states. If None, the state is initialized with zeros.
+
+ Returns
+ -------
+ (torch.Tensor, torch.Tensor)
+ """
+ new_states: List[torch.Tensor] = []
+ conv_skip: List[torch.Tensor] = []
+
+ if previous_state is None:
+ batch_size, spatial_size = cell_input.size(0), (cell_input.size(2), cell_input.size(3))
+ state_size = [batch_size, self.hidden_channels] + list(spatial_size) + [self.num_layers]
+ previous_state = torch.zeros(*state_size, dtype=cell_input.dtype).to(cell_input.device)
+
+ for idx in range(self.num_layers):
+ if len(conv_skip) > 0:
+ cell_input = F.relu(
+ self.conv_blocks[idx](torch.cat([*conv_skip[-self.dense_connect :], cell_input], dim=1)),
+ inplace=True,
+ )
+ else:
+ cell_input = F.relu(self.conv_blocks[idx](cell_input), inplace=True)
+ if self.dense_connect > 0:
+ conv_skip.append(cell_input)
+
+ stacked_inputs = torch.cat([cell_input, previous_state[:, :, :, :, idx]], dim=1)
+
+ update = torch.sigmoid(self.update_gates[idx](stacked_inputs))
+ reset = torch.sigmoid(self.reset_gates[idx](stacked_inputs))
+ delta = torch.tanh(
+ self.out_gates[idx](torch.cat([cell_input, previous_state[:, :, :, :, idx] * reset], dim=1))
+ )
+ cell_input = previous_state[:, :, :, :, idx] * (1 - update) + delta * update
+ new_states.append(cell_input)
+ cell_input = F.relu(cell_input, inplace=False)
+ if len(conv_skip) > 0:
+ out = self.conv_blocks[self.num_layers](torch.cat([*conv_skip[-self.dense_connect :], cell_input], dim=1))
+ else:
+ out = self.conv_blocks[self.num_layers](cell_input)
+
+ return out, torch.stack(new_states, dim=-1)
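A minimal sketch of driving `Conv2dGRU` across unrolled steps, with random data and default settings: passing `None` as the previous state triggers zero initialization, and the returned state of shape `(N, hidden_channels, height, width, num_layers)` is carried over to the next call, which mirrors how the RIM below reuses its hidden state.

```python
import torch

from direct.nn.recurrent.recurrent import Conv2dGRU

cell = Conv2dGRU(in_channels=2, hidden_channels=8, out_channels=2, num_layers=2)

x = torch.rand(1, 2, 32, 32)

out, state = cell(x, None)     # out: (1, 2, 32, 32), state: (1, 8, 32, 32, 2)
out, state = cell(out, state)  # the hidden state is carried over to the next unrolled step

print(out.shape, state.shape)
```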
diff --git a/direct/nn/recurrent/tests/__init__.py b/direct/nn/recurrent/tests/__init__.py
new file mode 100644
index 00000000..941752c9
--- /dev/null
+++ b/direct/nn/recurrent/tests/__init__.py
@@ -0,0 +1,2 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
diff --git a/direct/nn/recurrent/tests/test_recurrent.py b/direct/nn/recurrent/tests/test_recurrent.py
new file mode 100644
index 00000000..aad78200
--- /dev/null
+++ b/direct/nn/recurrent/tests/test_recurrent.py
@@ -0,0 +1,34 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
+
+import pytest
+import torch
+
+from direct.nn.recurrent.recurrent import Conv2dGRU
+
+
+def create_input(shape):
+
+ data = torch.rand(shape).float()
+
+ return data
+
+
+@pytest.mark.parametrize(
+ "shape",
+ [
+ [3, 2, 32, 32],
+ [3, 2, 16, 16],
+ ],
+)
+@pytest.mark.parametrize(
+ "hidden_channels",
+ [4, 8],
+)
+def test_conv2dgru(shape, hidden_channels):
+ model = Conv2dGRU(shape[1], hidden_channels, shape[1])
+ data = create_input(shape).cpu()
+
+ out = model(data, None)[0]
+
+ assert list(out.shape) == shape
diff --git a/direct/nn/rim/rim.py b/direct/nn/rim/rim.py
index 6c20522c..6d903744 100644
--- a/direct/nn/rim/rim.py
+++ b/direct/nn/rim/rim.py
@@ -2,7 +2,7 @@
# Copyright (c) DIRECT Contributors
import warnings
-from typing import Callable, List, Optional, Tuple, Union
+from typing import Callable, Optional, Tuple
import numpy as np
import torch
@@ -11,136 +11,18 @@
from direct.data import transforms as T
from direct.utils.asserts import assert_positive_integer
+from direct.nn.recurrent.recurrent import Conv2dGRU
-class ConvGRUCell(nn.Module):
+class MRILogLikelihood(nn.Module):
"""
- Convolutional GRU Cell to be used with RIM (Recurrent Inference Machines).
+ Defines the MRI loglikelihood assuming one noise vector for the complex images for all coils.
+ .. math::
+        \frac{1}{\sigma^2} \sum_{i}^{\text{num coils}}
+        S_i^{\text{H}} \mathcal{F}^{-1} P^T (P \mathcal{F} S_i x_\tau - y_\tau)
+    for each time step :math:`\tau`.
"""
- def __init__(
- self,
- x_channels: int,
- hidden_channels,
- depth=2,
- gru_kernel_size=1,
- ortho_init: bool = True,
- instance_norm: bool = False,
- dense_connect=0,
- replication_padding=False,
- ):
- super().__init__()
- self.depth = depth
- self.x_channels = x_channels
- self.hidden_channels = hidden_channels
- self.instance_norm = instance_norm
- self.dense_connect = dense_connect
- self.repl_pad = replication_padding
-
- self.reset_gates = nn.ModuleList([])
- self.update_gates = nn.ModuleList([])
- self.out_gates = nn.ModuleList([])
- self.conv_blocks = nn.ModuleList([])
-
- # Create convolutional blocks of RIM cell
- for idx in range(depth + 1):
- in_ch = x_channels + 2 if idx == 0 else (1 + min(idx, dense_connect)) * hidden_channels
- out_ch = hidden_channels if idx < depth else x_channels
- pad = 0 if replication_padding else (2 if idx == 0 else 1)
- block = []
- if replication_padding:
- if idx == 1:
- block.append(nn.ReplicationPad2d(2))
- else:
- block.append(nn.ReplicationPad2d(2 if idx == 0 else 1))
- block.append(
- nn.Conv2d(
- in_ch,
- out_ch,
- 5 if idx == 0 else 3,
- dilation=(2 if idx == 1 else 1),
- padding=pad,
- )
- )
- self.conv_blocks.append(nn.Sequential(*block))
-
- # Create GRU blocks of RIM cell
- for idx in range(depth):
- for gru_part in [self.reset_gates, self.update_gates, self.out_gates]:
- block = []
- if instance_norm:
- block.append(nn.InstanceNorm2d(2 * hidden_channels))
- block.append(
- nn.Conv2d(
- 2 * hidden_channels,
- hidden_channels,
- gru_kernel_size,
- padding=gru_kernel_size // 2,
- )
- )
- gru_part.append(nn.Sequential(*block))
-
- if ortho_init:
- for reset_gate, update_gate, out_gate in zip(self.reset_gates, self.update_gates, self.out_gates):
- nn.init.orthogonal_(reset_gate[-1].weight)
- nn.init.orthogonal_(update_gate[-1].weight)
- nn.init.orthogonal_(out_gate[-1].weight)
- nn.init.constant_(reset_gate[-1].bias, -1.0)
- nn.init.constant_(update_gate[-1].bias, 0.0)
- nn.init.constant_(out_gate[-1].bias, 0.0)
-
- def forward(
- self,
- cell_input: torch.Tensor,
- previous_state: torch.Tensor,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """
-
- Parameters
- ----------
- cell_input : torch.Tensor
- Reconstruction input
- previous_state : torch.Tensor
- Tensor of previous stats.
-
- Returns
- -------
- (torch.Tensor, torch.Tensor)
- """
-
- new_states: List[torch.Tensor] = []
- conv_skip: List[torch.Tensor] = []
-
- for idx in range(self.depth):
- if len(conv_skip) > 0:
- cell_input = F.relu(
- self.conv_blocks[idx](torch.cat([*conv_skip[-self.dense_connect :], cell_input], dim=1)),
- inplace=True,
- )
- else:
- cell_input = F.relu(self.conv_blocks[idx](cell_input), inplace=True)
- if self.dense_connect > 0:
- conv_skip.append(cell_input)
-
- stacked_inputs = torch.cat([cell_input, previous_state[:, :, :, :, idx]], dim=1)
-
- update = torch.sigmoid(self.update_gates[idx](stacked_inputs))
- reset = torch.sigmoid(self.reset_gates[idx](stacked_inputs))
- delta = torch.tanh(
- self.out_gates[idx](torch.cat([cell_input, previous_state[:, :, :, :, idx] * reset], dim=1))
- )
- cell_input = previous_state[:, :, :, :, idx] * (1 - update) + delta * update
- new_states.append(cell_input)
- cell_input = F.relu(cell_input, inplace=False)
- if len(conv_skip) > 0:
- out = self.conv_blocks[self.depth](torch.cat([*conv_skip[-self.dense_connect :], cell_input], dim=1))
- else:
- out = self.conv_blocks[self.depth](cell_input)
-
- return out, torch.stack(new_states, dim=-1)
-
-
-class MRILogLikelihood(nn.Module):
def __init__(
self,
forward_operator: Callable,
@@ -151,11 +33,8 @@ def __init__(
self.forward_operator = forward_operator
self.backward_operator = backward_operator
- # TODO UGLY
- self.ndim = 2
-
self._coil_dim = 1
- self._spatial_dims = (2, 3) if self.ndim == 2 else (2, 3, 4)
+ self._spatial_dims = (2, 3)
def forward(
self,
@@ -165,21 +44,17 @@ def forward(
sampling_mask,
loglikelihood_scaling=None,
) -> torch.Tensor:
- r"""
- Defines the MRI loglikelihood assuming one noise vector for the complex images for all coils.
- $$ \frac{1}{\sigma^2} \sum_{i}^{\text{num coils}}
- {S}_i^\{text{H}} \mathcal{F}^{-1} P^T (P \mathcal{F} S_i x_\tau - y_\tau)$$
- for each time step $\tau$
+ """
Parameters
----------
input_image : torch.tensor
Initial or previous iteration of image with complex first
- of shape (batch, complex, [slice,] height, width).
+ of shape (N, complex, height, width).
masked_kspace : torch.tensor
- Masked k-space of shape (batch, coil, [slice,] height, width, complex).
+ Masked k-space of shape (N, coil, height, width, complex).
sensitivity_map : torch.tensor
- Sensitivity Map of shape (batch, coil, [slice,] height, width, complex).
+ Sensitivity Map of shape (N, coil, height, width, complex).
sampling_mask : torch.tensor
loglikelihood_scaling : torch.tensor
Multiplier for loglikelihood, for instance for the k-space noise, of shape (1,).
@@ -188,54 +63,57 @@ def forward(
-------
torch.Tensor
"""
- if input_image.ndim == 5:
- self.ndim = 3
- input_image = input_image.permute(
- (0, 2, 3, 1) if self.ndim == 2 else (0, 2, 3, 4, 1)
- ) # shape (batch, [slice,] height, width, complex)
+ input_image = input_image.permute(0, 2, 3, 1) # shape (N, height, width, complex)
+        if loglikelihood_scaling is None:
+            loglikelihood_scaling = torch.tensor([1.0], dtype=masked_kspace.dtype).to(masked_kspace.device)
loglikelihood_scaling = loglikelihood_scaling.reshape(
list(torch.ones(len(sensitivity_map.shape)).int())
- ) # shape (1, 1, 1, [1,] 1, 1)
+ ) # shape (1, 1, 1, 1, 1)
# We multiply by the loglikelihood_scaling here to prevent fp16 information loss,
# as this value is typically <<1, and the operators are linear.
mul = loglikelihood_scaling * T.complex_multiplication(
- sensitivity_map, input_image.unsqueeze(1) # (batch, 1, [slice,] height, width, complex)
- ) # shape (batch, coil, [slice,] height, width, complex)
+ sensitivity_map, input_image.unsqueeze(1) # (N, 1, height, width, complex)
+ ) # shape (N, coil, height, width, complex)
mr_forward = torch.where(
sampling_mask == 0,
torch.tensor([0.0], dtype=masked_kspace.dtype).to(masked_kspace.device),
self.forward_operator(mul, dim=self._spatial_dims),
- ) # shape (batch, coil, [slice], height, width, complex)
+ ) # shape (N, coil, height, width, complex)
error = mr_forward - loglikelihood_scaling * torch.where(
sampling_mask == 0,
torch.tensor([0.0], dtype=masked_kspace.dtype).to(masked_kspace.device),
masked_kspace,
- ) # shape (batch, coil, [slice], height, width, complex)
+ ) # shape (N, coil, height, width, complex)
- mr_backward = self.backward_operator(
- error, dim=self._spatial_dims
- ) # shape (batch, coil, [slice], height, width, complex)
+ mr_backward = self.backward_operator(error, dim=self._spatial_dims) # shape (N, coil, height, width, complex)
if sensitivity_map is not None:
out = T.complex_multiplication(T.conjugate(sensitivity_map), mr_backward).sum(self._coil_dim)
else:
out = mr_backward.sum(self._coil_dim)
- # out has shape (batch, complex=2, [slice], height, width)
+        # out has shape (N, height, width, complex=2)
- out = (
- out.permute(0, 3, 1, 2) if self.ndim == 2 else out.permute(0, 4, 1, 2, 3)
- ) # complex first: shape (batch, [slice], height, width, complex=2)
+        out = out.permute(0, 3, 1, 2)  # complex first: shape (N, complex=2, height, width)
return out
class RIMInit(nn.Module):
+ """
+ Learned initializer for RIM, based on multi-scale context aggregation with dilated convolutions, that replaces
+ zero initializer for the RIM hidden vector.
+
+ Inspired by "Multi-Scale Context Aggregation by Dilated Convolutions" (https://arxiv.org/abs/1511.07122)
+ """
+
def __init__(
self,
x_ch: int,
@@ -246,10 +124,6 @@ def __init__(
multiscale_depth: int = 1,
):
"""
- Learned initializer for RIM, based on multi-scale context aggregation with dilated convolutions, that replaces
- zero initializer for the RIM hidden vector.
-
- Inspired by "Multi-Scale Context Aggregation by Dilated Convolutions" (https://arxiv.org/abs/1511.07122)
Parameters
----------
@@ -268,6 +142,7 @@ def __init__(
"""
super().__init__()
+
self.conv_blocks = nn.ModuleList()
self.out_blocks = nn.ModuleList()
self.depth = depth
@@ -373,10 +248,11 @@ def __init__(
self.no_parameter_sharing = no_parameter_sharing
for _ in range(length if no_parameter_sharing else 1):
self.cell_list.append(
- ConvGRUCell(
- x_channels,
- hidden_channels,
- depth=depth,
+ Conv2dGRU(
+ in_channels=x_channels * 2, # double channels as input is concatenated image and gradient
+ out_channels=x_channels,
+ hidden_channels=hidden_channels,
+ num_layers=depth,
instance_norm=instance_norm,
dense_connect=dense_connect,
replication_padding=replication_padding,
@@ -386,18 +262,21 @@ def __init__(
self.length = length
self.depth = depth
- def compute_sense_init(self, kspace, sensitivity_map, spatial_dims=(2, 3), coil_dim=1):
- # kspace is of shape: (batch, coil, [slice,] height, width, complex)
- # sensitivity_map is of shape (batch, coil, [slice,] height, width, complex)
+ self._coil_dim = 1
+ self._spatial_dims = (2, 3)
+
+ def compute_sense_init(self, kspace, sensitivity_map):
+ # kspace is of shape: (N, coil, height, width, complex)
+ # sensitivity_map is of shape (N, coil, height, width, complex)
input_image = T.complex_multiplication(
T.conjugate(sensitivity_map),
- self.backward_operator(kspace, dim=spatial_dims),
- ) # shape (batch, coil, [slice,] height, width, complex=2)
+ self.backward_operator(kspace, dim=self._spatial_dims),
+ ) # shape (N, coil, height, width, complex=2)
- input_image = input_image.sum(coil_dim)
+ input_image = input_image.sum(self._coil_dim)
- # shape (batch, [slice,] height, width, complex=2)
+ # shape (N, height, width, complex=2)
return input_image
def forward(
@@ -414,13 +293,13 @@ def forward(
Parameters
----------
input_image : torch.Tensor
- Initial or intermediate guess of input. Has shape (batch, [slice,] height, width, complex=2).
+ Initial or intermediate guess of input. Has shape (N, height, width, complex=2).
masked_kspace : torch.Tensor
- Kspace masked by the sampling mask.
+ Masked k-space of shape (N, coil, height, width, complex=2).
sensitivity_map : torch.Tensor
- Coil sensitivities.
+ Sensitivity map of shape (N, coil, height, width, complex=2).
sampling_mask : torch.Tensor
- Sampling mask.
+ Sampling mask of shape (N, 1, height, width, 1).
previous_state : torch.Tensor
loglikelihood_scaling : torch.Tensor
Float tensor of shape (1,).
@@ -429,13 +308,11 @@ def forward(
-------
torch.Tensor
"""
-
if input_image is None:
if self.image_initialization == "sense":
input_image = self.compute_sense_init(
kspace=masked_kspace,
sensitivity_map=sensitivity_map,
- spatial_dims=(3, 4) if masked_kspace.ndim == 6 else (2, 3),
)
elif self.image_initialization == "input_kspace":
if "initial_kspace" not in kwargs:
@@ -445,7 +322,6 @@ def forward(
input_image = self.compute_sense_init(
kspace=kwargs["initial_kspace"],
sensitivity_map=sensitivity_map,
- spatial_dims=(3, 4) if kwargs["initial_kspace"].ndim == 6 else (2, 3),
)
elif self.image_initialization == "input_image":
if "initial_image" not in kwargs:
@@ -462,34 +338,25 @@ def forward(
f"Unknown image_initialization. Expected `sense`, `input_kspace`, `'input_image` or `zero_filled`. "
f"Got {self.image_initialization}."
)
-
# Provide an initialization for the first hidden state.
if (self.initializer is not None) and (previous_state is None):
previous_state = self.initializer(
- input_image.permute((0, 4, 1, 2, 3) if input_image.ndim == 5 else (0, 3, 1, 2))
- ) # permute to (batch, complex, [slice], height, width),
+ input_image.permute(0, 3, 1, 2)
+ ) # permute to (N, complex, height, width),
# TODO: This has to be made contiguous
- # TODO(gy): Do 3D data pass from here? If not remove if statements below and [slice,] from comments.
- input_image = input_image.permute(
- (0, 4, 1, 2, 3) if input_image.ndim == 5 else (0, 3, 1, 2)
- ).contiguous() # shape (batch, , complex=2, [slice,] height, width)
+ input_image = input_image.permute(0, 3, 1, 2).contiguous() # shape (N, complex=2, height, width)
batch_size = input_image.size(0)
- spatial_shape = (
- [input_image.size(-3), input_image.size(-2), input_image.size(-1)]
- if input_image.ndim == 5
- else [input_image.size(-2), input_image.size(-1)]
- )
-
+ spatial_shape = [input_image.size(self._spatial_dims[0]), input_image.size(self._spatial_dims[1])]
# Initialize zero state for RIM
state_size = [batch_size, self.hidden_channels] + list(spatial_shape) + [self.depth]
if previous_state is None:
- # shape (batch, hidden_channels, [slice,] height, width, depth)
+ # shape (N, hidden_channels, height, width, depth)
previous_state = torch.zeros(*state_size, dtype=input_image.dtype).to(input_image.device)
cell_outputs = []
- intermediate_image = input_image # shape (batch, , complex=2, [slice,] height, width)
+ intermediate_image = input_image # shape (N, complex=2, height, width)
for cell_idx in range(self.length):
cell = self.cell_list[cell_idx] if self.no_parameter_sharing else self.cell_list[0]
@@ -500,7 +367,7 @@ def forward(
sensitivity_map,
sampling_mask,
loglikelihood_scaling,
- ) # shape (batch, , complex=2, [slice,] height, width)
+ ) # shape (N, complex=2, height, width)
if grad_loglikelihood.abs().max() > 150.0:
warnings.warn(
@@ -511,16 +378,16 @@ def forward(
cell_input = torch.cat(
[intermediate_image, grad_loglikelihood],
dim=1,
- ) # shape (batch, , complex=4, [slice,] height, width)
+ ) # shape (N, complex=4, height, width)
cell_output, previous_state = cell(cell_input, previous_state)
- # shapes (batch, complex=2, [slice,] height, width), (batch, hidden_channels, [slice,] height, width, depth)
+ # shapes (N, complex=2, height, width), (N, hidden_channels, height, width, depth)
if self.skip_connections:
- # shape (batch, complex=2, [slice,] height, width)
+ # shape (N, complex=2, height, width)
intermediate_image = intermediate_image + cell_output
else:
- # shape (batch, complex=2, [slice,] height, width)
+ # shape (N, complex=2, height, width)
intermediate_image = cell_output
if not self.training:
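
A quick, self-contained sketch of the shape bookkeeping that the rewritten `forward` above relies on: the channel-last complex image is permuted to channel-first before entering the RIM cells, and the zero hidden state is sized from the permuted spatial dimensions. Plain PyTorch only, with illustrative sizes of my own choosing; it is not code from this diff.

```python
import torch

# Illustrative sizes only (not taken from the diff).
N, H, W = 2, 16, 16
hidden_channels, depth = 8, 2

# Channel-last complex image, shape (N, height, width, complex=2).
input_image = torch.rand(N, H, W, 2)

# Permute to channel-first, shape (N, complex=2, height, width), as done before the RIM cells.
input_image = input_image.permute(0, 3, 1, 2).contiguous()

# Zero hidden state, shape (N, hidden_channels, height, width, depth).
spatial_shape = [input_image.size(-2), input_image.size(-1)]
previous_state = torch.zeros(N, hidden_channels, *spatial_shape, depth, dtype=input_image.dtype)

assert list(previous_state.shape) == [N, hidden_channels, H, W, depth]
```
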
diff --git a/direct/nn/rim/rim_engine.py b/direct/nn/rim/rim_engine.py
index d30a70e9..522d24be 100644
--- a/direct/nn/rim/rim_engine.py
+++ b/direct/nn/rim/rim_engine.py
@@ -232,7 +232,7 @@ def ssim_loss(source, reduction="mean", **data):
source_abs, target_abs = self.cropper(source, data["target"], resolution)
data_range = torch.tensor([target_abs.max()], device=target_abs.device)
- ssim_loss = SSIMLoss().to(source_abs.device)(source_abs, target_abs, data_range=data_range)
+ ssim_loss = SSIMLoss().to(source_abs.device).forward(source_abs, target_abs, data_range=data_range)
return ssim_loss
diff --git a/direct/nn/rim/tests/test_mri_models.py b/direct/nn/rim/tests/test_mri_models.py
deleted file mode 100644
index 7301dc3a..00000000
--- a/direct/nn/rim/tests/test_mri_models.py
+++ /dev/null
@@ -1,93 +0,0 @@
-# coding=utf-8
-# Copyright (c) DIRECT Contributors
-
-import warnings
-
-import numpy as np
-import torch
-
-from direct.data import transforms as T
-from direct.data.transforms import tensor_to_complex_numpy
-
-warnings.filterwarnings("ignore")
-
-
-def numpy_fft(data, dims=(-2, -1)):
- """
- Fast Fourier Transform.
- """
- data = np.fft.ifftshift(data, dims)
- out = np.fft.fft2(data, norm="ortho")
- out = np.fft.fftshift(out, dims)
-
- return out
-
-
-def numpy_ifft(data, dims=(-2, -1)):
- """
- Inverse Fast Fourier Transform.
- """
- data = np.fft.ifftshift(data, dims)
- out = np.fft.ifft2(data, norm="ortho")
- out = np.fft.fftshift(out, dims)
-
- return out
-
-
-def create_input(shape):
- data = np.arange(np.product(shape)).reshape(shape).copy()
- data = torch.from_numpy(data).float()
-
- return data
-
-
-batch, coil, height, width, complex = 3, 15, 100, 80, 2
-
-input_image = create_input([batch, height, width, complex])
-sensitivity_map = create_input([batch, coil, height, width, complex]) * 0.1
-masked_kspace = create_input([batch, coil, height, width, complex]) + 0.33
-sampling_mask = torch.from_numpy(np.random.binomial(size=(batch, 1, height, width, 1), n=1, p=0.5))
-
-input_image_numpy = tensor_to_complex_numpy(input_image)
-sensitivity_map_numpy = tensor_to_complex_numpy(sensitivity_map)
-masked_kspace_numpy = tensor_to_complex_numpy(masked_kspace)
-sampling_mask_numpy = sampling_mask.numpy()[..., 0]
-
-mul = T.complex_multiplication(sensitivity_map, input_image.unsqueeze(1))
-
-dims = (2, 3)
-mr_forward = torch.where(
- sampling_mask == 0,
- torch.tensor([0.0], dtype=masked_kspace.dtype).to(masked_kspace.device),
- T.fft2(mul, dim=dims),
-)
-
-error = mr_forward - torch.where(
- sampling_mask == 0,
- torch.tensor([0.0], dtype=masked_kspace.dtype).to(masked_kspace.device),
- masked_kspace,
-)
-
-mr_backward = T.ifft2(error, dim=dims)
-
-coil_dim = 1
-out = T.complex_multiplication(T.conjugate(sensitivity_map), mr_backward).sum(coil_dim)
-
-
-# numpy
-mul_numpy = sensitivity_map_numpy * input_image_numpy.reshape(batch, 1, height, width)
-mr_forward_numpy = sampling_mask_numpy * numpy_fft(mul_numpy)
-error_numpy = mr_forward_numpy - sampling_mask_numpy * masked_kspace_numpy
-mr_backward_numpy = numpy_ifft(error_numpy)
-out_numpy = (sensitivity_map_numpy.conjugate() * mr_backward_numpy).sum(1)
-
-np.allclose(tensor_to_complex_numpy(out), out_numpy)
-
-# numpy 2
-mr_backward_numpy = numpy_ifft(
- sampling_mask_numpy * numpy_fft(sensitivity_map_numpy * input_image_numpy[:, np.newaxis, ...])
- - sampling_mask_numpy * masked_kspace_numpy
-)
-out_numpy = (sensitivity_map_numpy.conjugate() * mr_backward_numpy).sum(1)
-
-np.allclose(tensor_to_complex_numpy(out), out_numpy)
diff --git a/direct/nn/rim/tests/test_rim.py b/direct/nn/rim/tests/test_rim.py
new file mode 100644
index 00000000..3e5cfba8
--- /dev/null
+++ b/direct/nn/rim/tests/test_rim.py
@@ -0,0 +1,92 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
+
+import pytest
+import torch
+
+from direct.data.transforms import fft2, ifft2
+from direct.nn.rim.rim import RIM
+
+
+def create_input(shape):
+
+ data = torch.rand(shape).float()
+
+ return data
+
+
+@pytest.mark.parametrize(
+ "shape",
+ [
+ [3, 3, 16, 16],
+ [2, 5, 16, 32],
+ ],
+)
+@pytest.mark.parametrize(
+ "hidden_channels",
+ [4, 8],
+)
+@pytest.mark.parametrize(
+ "length",
+ [3],
+)
+@pytest.mark.parametrize(
+ "depth",
+ [1, 2],
+)
+@pytest.mark.parametrize(
+ "no_parameter_sharing",
+ [True, False],
+)
+@pytest.mark.parametrize(
+ "instance_norm",
+ [True, False],
+)
+@pytest.mark.parametrize(
+ "dense_connect",
+ [True, False],
+)
+@pytest.mark.parametrize(
+ "skip_connections",
+ [True, False],
+)
+@pytest.mark.parametrize(
+ "image_init",
+ [
+ "zero-filled",
+ "sense",
+ "input-kspace",
+ ],
+)
+def test_rim(
+ shape,
+ hidden_channels,
+ length,
+ depth,
+ no_parameter_sharing,
+ instance_norm,
+ dense_connect,
+ skip_connections,
+ image_init,
+):
+ model = RIM(
+ fft2,
+ ifft2,
+ hidden_channels=hidden_channels,
+ length=length,
+ depth=depth,
+ no_parameter_sharing=no_parameter_sharing,
+ instance_norm=instance_norm,
+ dense_connect=dense_connect,
+ skip_connections=skip_connections,
+ image_initialization=image_init,
+ ).cpu()
+
+ img = create_input([shape[0]] + shape[2:] + [2]).cpu()
+ kspace = create_input(shape + [2]).cpu()
+ sens = create_input(shape + [2]).cpu()
+ mask = create_input([shape[0]] + [1] + shape[2:] + [1]).round().int().cpu()
+
+ out = model(img, kspace, mask, sens)[0][-1]
+
+ assert list(out.shape) == [shape[0]] + [2] + shape[2:]
diff --git a/direct/nn/unet/config.py b/direct/nn/unet/config.py
index 4e3f7683..aa3e46d8 100644
--- a/direct/nn/unet/config.py
+++ b/direct/nn/unet/config.py
@@ -12,3 +12,13 @@ class UnetModel2dConfig(ModelConfig):
num_filters: int = 16
num_pool_layers: int = 4
dropout_probability: float = 0.0
+
+
+@dataclass
+class Unet2dConfig(ModelConfig):
+ num_filters: int = 16
+ num_pool_layers: int = 4
+ dropout_probability: float = 0.0
+ skip_connection: bool = False
+ normalized: bool = False
+ image_initialization: str = "zero_filled"
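
The new `Unet2dConfig` mirrors, field for field, the constructor of the `Unet2d` model added further down in this diff. As a sketch only, bypassing DIRECT's config/builder machinery and wiring the defaults by hand (the operators `fft2`/`ifft2` are the same ones the new tests use):

```python
from direct.data.transforms import fft2, ifft2
from direct.nn.unet.unet_2d import Unet2d

# Hand-written, field-for-field mirror of the Unet2dConfig defaults above.
model = Unet2d(
    forward_operator=fft2,
    backward_operator=ifft2,
    num_filters=16,
    num_pool_layers=4,
    dropout_probability=0.0,
    skip_connection=False,
    normalized=False,
    image_initialization="zero_filled",
)
```
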
diff --git a/direct/nn/unet/tests/__init__.py b/direct/nn/unet/tests/__init__.py
new file mode 100644
index 00000000..941752c9
--- /dev/null
+++ b/direct/nn/unet/tests/__init__.py
@@ -0,0 +1,2 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
diff --git a/direct/nn/unet/tests/test_unet_2d.py b/direct/nn/unet/tests/test_unet_2d.py
new file mode 100644
index 00000000..bf912643
--- /dev/null
+++ b/direct/nn/unet/tests/test_unet_2d.py
@@ -0,0 +1,61 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
+
+
+import numpy as np
+import pytest
+import torch
+
+from direct.data.transforms import fft2, ifft2
+from direct.nn.unet.unet_2d import Unet2d, NormUnetModel2d
+
+
+def create_input(shape):
+ data = np.random.randn(*shape).copy()
+ data = torch.from_numpy(data).float()
+
+ return data
+
+
+@pytest.mark.parametrize(
+ "shape",
+ [
+ [2, 3, 16, 16],
+ [4, 5, 16, 32],
+ [3, 4, 32, 32],
+ [3, 4, 40, 20],
+ ],
+)
+@pytest.mark.parametrize(
+ "num_filters",
+ [4, 6, 8],
+)
+@pytest.mark.parametrize(
+ "num_pool_layers",
+ [2, 3],
+)
+@pytest.mark.parametrize(
+ "skip",
+ [True, False],
+)
+@pytest.mark.parametrize(
+ "normalized",
+ [True, False],
+)
+def test_unet_2d(shape, num_filters, num_pool_layers, skip, normalized):
+ model = Unet2d(
+ fft2,
+ ifft2,
+ num_filters=num_filters,
+ num_pool_layers=num_pool_layers,
+ skip_connection=skip,
+ normalized=normalized,
+ dropout_probability=0.05,
+ ).cpu()
+
+ data = create_input(shape + [2]).cpu()
+ sens = create_input(shape + [2]).cpu()
+
+ out = model(data, sens)
+
+ assert list(out.shape) == [shape[0]] + shape[2:] + [2]
diff --git a/direct/nn/unet/unet_2d.py b/direct/nn/unet/unet_2d.py
index 859c1916..18d27c88 100644
--- a/direct/nn/unet/unet_2d.py
+++ b/direct/nn/unet/unet_2d.py
@@ -1,12 +1,16 @@
# coding=utf-8
# Copyright (c) DIRECT Contributors
-# Code borrowed / edited from: https://github.com/facebookresearch/fastMRI/blob/master/models/unet/unet_model.py
+# Code borrowed / edited from: https://github.com/facebookresearch/fastMRI/blob/
+import math
+from typing import Callable, List, Optional, Tuple
import torch
from torch import nn
from torch.nn import functional as F
+from direct.data import transforms as T
+
class ConvBlock(nn.Module):
"""
@@ -156,7 +160,7 @@ def __init__(
self.up_conv = nn.ModuleList()
self.up_transpose_conv = nn.ModuleList()
- for i in range(num_pool_layers - 1):
+ for _ in range(num_pool_layers - 1):
self.up_transpose_conv += [TransposeConvBlock(ch * 2, ch)]
self.up_conv += [ConvBlock(ch * 2, ch, dropout_probability)]
ch //= 2
@@ -209,3 +213,228 @@ def forward(self, input: torch.Tensor):
output = conv(output)
return output
+
+
+class NormUnetModel2d(nn.Module):
+ """
+ Implementation of a Normalized U-Net model.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ num_filters: int,
+ num_pool_layers: int,
+ dropout_probability: float,
+ norm_groups: int = 2,
+ ):
+ """
+
+ Parameters
+ ----------
+ in_channels : int
+ Number of input channels to the u-net.
+ out_channels : int
+ Number of output channels to the u-net.
+ num_filters : int
+ Number of output channels of the first convolutional layer.
+ num_pool_layers : int
+ Number of down-sampling and up-sampling layers (depth).
+ dropout_probability : float
+ Dropout probability.
+        norm_groups : int
+ Number of normalization groups.
+ """
+ super().__init__()
+
+ self.unet2d = UnetModel2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ num_filters=num_filters,
+ num_pool_layers=num_pool_layers,
+ dropout_probability=dropout_probability,
+ )
+
+ self.norm_groups = norm_groups
+
+ @staticmethod
+ def norm(input: torch.Tensor, groups: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ # group norm
+ b, c, h, w = input.shape
+ input = input.reshape(b, groups, -1)
+
+ mean = input.mean(-1, keepdim=True)
+ std = input.std(-1, keepdim=True)
+
+ output = (input - mean) / std
+ output = output.reshape(b, c, h, w)
+
+ return output, mean, std
+
+ @staticmethod
+ def unnorm(input: torch.Tensor, mean: torch.Tensor, std: torch.Tensor, groups: int) -> torch.Tensor:
+ b, c, h, w = input.shape
+ input = input.reshape(b, groups, -1)
+ return (input * std + mean).reshape(b, c, h, w)
+
+ @staticmethod
+ def pad(input: torch.Tensor) -> Tuple[torch.Tensor, Tuple[List[int], List[int], int, int]]:
+ _, _, h, w = input.shape
+ w_mult = ((w - 1) | 15) + 1
+ h_mult = ((h - 1) | 15) + 1
+ w_pad = [math.floor((w_mult - w) / 2), math.ceil((w_mult - w) / 2)]
+ h_pad = [math.floor((h_mult - h) / 2), math.ceil((h_mult - h) / 2)]
+
+ output = F.pad(input, w_pad + h_pad)
+ return output, (h_pad, w_pad, h_mult, w_mult)
+
+ @staticmethod
+ def unpad(
+ input: torch.Tensor,
+ h_pad: List[int],
+ w_pad: List[int],
+ h_mult: int,
+ w_mult: int,
+ ) -> torch.Tensor:
+
+ return input[..., h_pad[0] : h_mult - h_pad[1], w_pad[0] : w_mult - w_pad[1]]
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ """
+
+ Parameters
+ ----------
+ input : torch.Tensor
+
+ Returns
+ -------
+ torch.Tensor
+ """
+
+ output, mean, std = self.norm(input, self.norm_groups)
+ output, pad_sizes = self.pad(output)
+ output = self.unet2d(output)
+
+ output = self.unpad(output, *pad_sizes)
+ output = self.unnorm(output, mean, std, self.norm_groups)
+
+ return output
+
+
+class Unet2d(nn.Module):
+ """
+ PyTorch implementation of a U-Net model for MRI Reconstruction.
+ """
+
+ def __init__(
+ self,
+ forward_operator: Callable,
+ backward_operator: Callable,
+ num_filters: int,
+ num_pool_layers: int,
+ dropout_probability: float,
+ skip_connection: bool = False,
+ normalized: bool = False,
+ image_initialization: str = "zero_filled",
+ **kwargs,
+ ):
+ """
+
+ Parameters
+ ----------
+ forward_operator : Callable
+ Forward Operator.
+ backward_operator : Callable
+ Backward Operator.
+ num_filters : int
+ Number of first layer filters.
+ num_pool_layers : int
+ Number of pooling layers.
+ dropout_probability : float
+ Dropout probability.
+ skip_connection : bool
+ If True, skip connection is used for the output. Default: False.
+ normalized : bool
+ If True, Normalized Unet is used. Default: False.
+ image_initialization : str
+            Type of image initialization. Default: "zero_filled".
+ kwargs: dict
+ """
+ super().__init__()
+ extra_keys = kwargs.keys()
+ for extra_key in extra_keys:
+ if extra_key not in [
+ "sensitivity_map_model",
+ "model_name",
+ ]:
+ raise ValueError(f"{type(self).__name__} got key `{extra_key}` which is not supported.")
+ if normalized:
+ self.unet = NormUnetModel2d(
+ in_channels=2,
+ out_channels=2,
+ num_filters=num_filters,
+ num_pool_layers=num_pool_layers,
+ dropout_probability=dropout_probability,
+ )
+ else:
+ self.unet = UnetModel2d(
+ in_channels=2,
+ out_channels=2,
+ num_filters=num_filters,
+ num_pool_layers=num_pool_layers,
+ dropout_probability=dropout_probability,
+ )
+ self.forward_operator = forward_operator
+ self.backward_operator = backward_operator
+ self.skip_connection = skip_connection
+ self.image_initialization = image_initialization
+ self._coil_dim = 1
+ self._spatial_dims = (2, 3)
+
+ def compute_sense_init(self, kspace, sensitivity_map):
+ input_image = T.complex_multiplication(
+ T.conjugate(sensitivity_map),
+ self.backward_operator(kspace, dim=self._spatial_dims),
+ )
+ input_image = input_image.sum(self._coil_dim)
+ return input_image
+
+ def forward(
+ self,
+ masked_kspace: torch.Tensor,
+ sensitivity_map: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ """
+
+ Parameters
+ ----------
+ masked_kspace : torch.Tensor
+ Masked k-space of shape (N, coil, height, width, complex=2).
+ sensitivity_map : torch.Tensor
+ Sensitivity map of shape (N, coil, height, width, complex=2). Default: None.
+
+ Returns
+ -------
+ torch.Tensor
+ Output image of shape (N, height, width, complex=2).
+ """
+ if self.image_initialization == "sense":
+ if sensitivity_map is None:
+ raise ValueError("Expected sensitivity_map not to be None with 'sense' image_initialization.")
+ input_image = self.compute_sense_init(
+ kspace=masked_kspace,
+ sensitivity_map=sensitivity_map,
+ )
+ elif self.image_initialization == "zero_filled":
+ input_image = self.backward_operator(masked_kspace).sum(self._coil_dim)
+ else:
+ raise ValueError(
+ f"Unknown image_initialization. Expected `sense` or `zero_filled`. "
+ f"Got {self.image_initialization}."
+ )
+
+ output = self.unet(input_image.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
+ if self.skip_connection:
+ output += input_image
+ return output
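
One detail of `NormUnetModel2d.pad` worth spelling out: `((w - 1) | 15) + 1` rounds a spatial size up to the next multiple of 16, which keeps the padded input divisible through the U-Net's pooling stages. A standalone sanity check of that identity; the helper name `next_multiple_of_16` is mine and purely illustrative.

```python
import math

def next_multiple_of_16(n: int) -> int:
    # Same bit trick as in NormUnetModel2d.pad: set the low 4 bits of (n - 1), then add 1.
    return ((n - 1) | 15) + 1

for n in (1, 15, 16, 17, 100, 320):
    assert next_multiple_of_16(n) == 16 * math.ceil(n / 16)
```
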
diff --git a/direct/nn/unet/unet_engine.py b/direct/nn/unet/unet_engine.py
new file mode 100644
index 00000000..2a4b7ef8
--- /dev/null
+++ b/direct/nn/unet/unet_engine.py
@@ -0,0 +1,469 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
+
+import time
+from collections import defaultdict
+from os import PathLike
+from typing import Callable, DefaultDict, Dict, List, Optional
+
+import numpy as np
+import torch
+from torch import nn
+from torch.cuda.amp import autocast
+from torch.nn import functional as F
+from torch.utils.data import DataLoader
+
+import direct.data.transforms as T
+from direct.config import BaseConfig
+from direct.engine import DoIterationOutput, Engine
+from direct.functionals import SSIMLoss
+from direct.utils import (
+ communication,
+ detach_dict,
+ dict_to_device,
+ merge_list_of_dicts,
+ multiply_function,
+ reduce_list_of_dicts,
+)
+from direct.utils.communication import reduce_tensor_dict
+
+
+class Unet2dEngine(Engine):
+ """
+ Unet2d Model Engine.
+ """
+
+ def __init__(
+ self,
+ cfg: BaseConfig,
+ model: nn.Module,
+ device: int,
+ forward_operator: Optional[Callable] = None,
+ backward_operator: Optional[Callable] = None,
+ mixed_precision: bool = False,
+ **models: nn.Module,
+ ):
+ super().__init__(
+ cfg,
+ model,
+ device,
+ forward_operator=forward_operator,
+ backward_operator=backward_operator,
+ mixed_precision=mixed_precision,
+ **models,
+ )
+
+ self._complex_dim = -1
+ self._coil_dim = 1
+ self._spatial_dims = (2, 3)
+
+ def _do_iteration(
+ self,
+ data: Dict[str, torch.Tensor],
+ loss_fns: Optional[Dict[str, Callable]] = None,
+ regularizer_fns: Optional[Dict[str, Callable]] = None,
+ ) -> DoIterationOutput:
+
+ # loss_fns can be done, e.g. during validation
+ if loss_fns is None:
+ loss_fns = {}
+
+ if regularizer_fns is None:
+ regularizer_fns = {}
+
+ loss_dicts = []
+ regularizer_dicts = []
+
+ data = dict_to_device(data, self.device)
+
+ if self.cfg.model.image_initialization == "sense":
+ # sensitivity_map of shape (batch, coil, height, width, complex=2)
+ sensitivity_map = data["sensitivity_map"]
+
+ # Some things can be done with the sensitivity map here, e.g. apply a u-net
+ if "sensitivity_model" in self.models:
+ # Move channels to first axis
+ sensitivity_map = data["sensitivity_map"].permute(
+ (0, 1, 4, 2, 3)
+ ) # shape (batch, coil, complex=2, height, width)
+
+ sensitivity_map = self.compute_model_per_coil("sensitivity_model", sensitivity_map).permute(
+ (0, 1, 3, 4, 2)
+ ) # has channel last: shape (batch, coil, height, width, complex=2)
+
+ # The sensitivity map needs to be normalized such that
+ # So \sum_{i \in \text{coils}} S_i S_i^* = 1
+
+ sensitivity_map_norm = torch.sqrt(((sensitivity_map ** 2).sum(self._complex_dim)).sum(self._coil_dim))
+ # shape (batch, 1, height, width, 1)
+ sensitivity_map_norm = sensitivity_map_norm.unsqueeze(self._coil_dim).unsqueeze(self._complex_dim)
+ data["sensitivity_map"] = T.safe_divide(sensitivity_map, sensitivity_map_norm)
+
+ with autocast(enabled=self.mixed_precision):
+
+ output_image = self.model(
+ masked_kspace=data["masked_kspace"],
+ sensitivity_map=data["sensitivity_map"] if self.cfg.model.image_initialization == "sense" else None,
+ )
+ output_image = T.modulus(output_image)
+
+ loss_dict = {k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in loss_fns.keys()}
+ regularizer_dict = {
+ k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in regularizer_fns.keys()
+ }
+
+ for key, value in loss_dict.items():
+ loss_dict[key] = value + loss_fns[key](
+ output_image,
+ **data,
+ reduction="mean",
+ )
+
+ for key, value in regularizer_dict.items():
+ regularizer_dict[key] = value + regularizer_fns[key](
+ output_image,
+ **data,
+ )
+
+ loss = sum(loss_dict.values()) + sum(regularizer_dict.values())
+
+ if self.model.training:
+ self._scaler.scale(loss).backward()
+
+ loss_dicts.append(detach_dict(loss_dict))
+ regularizer_dicts.append(
+ detach_dict(regularizer_dict)
+ ) # Need to detach dict as this is only used for logging.
+
+ # Add the loss dicts.
+ loss_dict = reduce_list_of_dicts(loss_dicts, mode="sum")
+ regularizer_dict = reduce_list_of_dicts(regularizer_dicts, mode="sum")
+
+ return DoIterationOutput(
+ output_image=output_image,
+ sensitivity_map=data["sensitivity_map"],
+ data_dict={**loss_dict, **regularizer_dict},
+ )
+
+ def build_loss(self, **kwargs) -> Dict:
+ # TODO: Cropper is a processing output tool.
+ def get_resolution(**data):
+ """Be careful that this will use the cropping size of the FIRST sample in the batch."""
+ return self.compute_resolution(self.cfg.training.loss.crop, data.get("reconstruction_size", None))
+
+ def l1_loss(source, reduction="mean", **data):
+ """
+ Calculate L1 loss given source and target.
+
+ Parameters:
+ -----------
+ Source: shape (batch, complex=2, height, width)
+ Data: Contains key "target" with value a tensor of shape (batch, height, width)
+
+ """
+ resolution = get_resolution(**data)
+ l1_loss = F.l1_loss(*self.cropper(source, data["target"], resolution), reduction=reduction)
+
+ return l1_loss
+
+ def l2_loss(source, reduction="mean", **data):
+ """
+ Calculate L2 loss (MSE) given source and target.
+
+ Parameters:
+ -----------
+ Source: shape (batch, complex=2, height, width)
+ Data: Contains key "target" with value a tensor of shape (batch, height, width)
+
+ """
+ resolution = get_resolution(**data)
+ l2_loss = F.mse_loss(*self.cropper(source, data["target"], resolution), reduction=reduction)
+
+ return l2_loss
+
+ def ssim_loss(source, reduction="mean", **data):
+ """
+ Calculate SSIM loss given source and target.
+
+ Parameters:
+ -----------
+ Source: shape (batch, complex=2, height, width)
+ Data: Contains key "target" with value a tensor of shape (batch, height, width)
+
+ """
+ resolution = get_resolution(**data)
+ if reduction != "mean":
+ raise AssertionError(
+ f"SSIM loss can only be computed with reduction == 'mean'." f" Got reduction == {reduction}."
+ )
+
+ source_abs, target_abs = self.cropper(source, data["target"], resolution)
+ data_range = torch.tensor([target_abs.max()], device=target_abs.device)
+
+ ssim_loss = SSIMLoss().to(source_abs.device).forward(source_abs, target_abs, data_range=data_range)
+
+ return ssim_loss
+
+ # Build losses
+ loss_dict = {}
+ for curr_loss in self.cfg.training.loss.losses: # type: ignore
+ loss_fn = curr_loss.function
+ if loss_fn == "l1_loss":
+ loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, l1_loss)
+ elif loss_fn == "l2_loss":
+ loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, l2_loss)
+ elif loss_fn == "ssim_loss":
+ loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, ssim_loss)
+ else:
+ raise ValueError(f"{loss_fn} not permissible.")
+
+ return loss_dict
+
+ @torch.no_grad()
+ def evaluate(
+ self,
+ data_loader: DataLoader,
+ loss_fns: Optional[Dict[str, Callable]],
+ regularizer_fns: Optional[Dict[str, Callable]] = None,
+ crop: Optional[str] = None,
+ is_validation_process: bool = True,
+ ):
+ """
+ Validation process. Assumes that each batch only contains slices of the same volume *AND* that these
+ are sequentially ordered.
+
+ Parameters
+ ----------
+ data_loader : DataLoader
+ loss_fns : Dict[str, Callable], optional
+ regularizer_fns : Dict[str, Callable], optional
+ crop : str, optional
+ is_validation_process : bool
+
+ Returns
+ -------
+ loss_dict, all_gathered_metrics, visualize_slices, visualize_target
+
+ """
+ self.models_to_device()
+ self.models_validation_mode()
+ torch.cuda.empty_cache()
+
+ # Variables required for evaluation.
+ volume_metrics = self.build_metrics(self.cfg.validation.metrics) # type: ignore
+
+ # filenames can be in the volume_indices attribute of the dataset
+ num_for_this_process = None
+ all_filenames = None
+ if hasattr(data_loader.dataset, "volume_indices"):
+ all_filenames = list(data_loader.dataset.volume_indices.keys())
+ num_for_this_process = len(list(data_loader.batch_sampler.sampler.volume_indices.keys()))
+ self.logger.info(
+ f"Reconstructing a total of {len(all_filenames)} volumes. "
+ f"This process has {num_for_this_process} volumes (world size: {communication.get_world_size()})."
+ )
+
+ filenames_seen = 0
+ reconstruction_output: DefaultDict = defaultdict(list)
+ if is_validation_process:
+ targets_output: DefaultDict = defaultdict(list)
+ val_losses = []
+ val_volume_metrics: Dict[PathLike, Dict] = defaultdict(dict)
+ last_filename = None
+
+        # Container for the slices which can be visualized in TensorBoard.
+ visualize_slices: List[np.ndarray] = []
+ visualize_target: List[np.ndarray] = []
+ # visualizations = {}
+
+ extra_visualization_keys = (
+ self.cfg.logging.log_as_image if self.cfg.logging.log_as_image else [] # type: ignore
+ )
+
+ # Loop over dataset. This requires the use of direct.data.sampler.DistributedSequentialSampler as this sampler
+ # splits the data over the different processes, and outputs the slices linearly. The implicit assumption here is
+ # that the slices are outputted from the Dataset *sequentially* for each volume one by one, and each batch only
+ # contains data from one volume.
+ time_start = time.time()
+
+ for iter_idx, data in enumerate(data_loader):
+ filenames = data.pop("filename")
+ if len(set(filenames)) != 1:
+ raise ValueError(
+ f"Expected a batch during validation to only contain filenames of one case. "
+ f"Got {set(filenames)}."
+ )
+
+ slice_nos = data.pop("slice_no")
+ scaling_factors = data["scaling_factor"]
+
+ resolution = self.compute_resolution(
+ key=self.cfg.validation.crop, # type: ignore
+ reconstruction_size=data.get("reconstruction_size", None),
+ )
+
+ # Compute output and loss.
+ iteration_output = self._do_iteration(data, loss_fns, regularizer_fns=regularizer_fns)
+ output = iteration_output.output_image
+ loss_dict = iteration_output.data_dict
+
+ loss_dict = detach_dict(loss_dict)
+ output = output.detach()
+ val_losses.append(loss_dict)
+
+ # Output is complex-valued, and has to be cropped. This holds for both output and target.
+ # Output has shape (batch, complex, height, width)
+ output_abs = self.process_output(
+ output,
+ scaling_factors,
+ resolution=resolution,
+ )
+
+ if is_validation_process:
+ # Target has shape (batch, height, width)
+ target_abs = self.process_output(
+ data["target"].detach(),
+ scaling_factors,
+ resolution=resolution,
+ )
+ for key in extra_visualization_keys:
+ curr_data = data[key].detach()
+ # Here we need to discover which keys are actually normalized or not
+ # this requires a solution to issue #23: https://github.com/directgroup/direct/issues/23
+
+ del output # Explicitly call delete to clear memory.
+
+ # Aggregate volumes to be able to compute the metrics on complete volumes.
+ for idx, filename in enumerate(filenames):
+ if last_filename is None:
+ last_filename = filename # First iteration last_filename is not set.
+
+ curr_slice = output_abs[idx].detach()
+ slice_no = int(slice_nos[idx].numpy())
+
+ reconstruction_output[filename].append((slice_no, curr_slice.cpu()))
+
+ if is_validation_process:
+ targets_output[filename].append((slice_no, target_abs[idx].cpu()))
+
+ is_last_element_of_last_batch = iter_idx + 1 == len(data_loader) and idx + 1 == len(data["target"])
+ reconstruction_conditions = [filename != last_filename, is_last_element_of_last_batch]
+ for condition in reconstruction_conditions:
+ if condition:
+ filenames_seen += 1
+
+ # Now we can ditch the reconstruction dict by reconstructing the volume,
+ # will take too much memory otherwise.
+ volume = torch.stack([_[1] for _ in reconstruction_output[last_filename]])
+ if is_validation_process:
+ target = torch.stack([_[1] for _ in targets_output[last_filename]])
+ curr_metrics = {
+ metric_name: metric_fn(target, volume)
+ for metric_name, metric_fn in volume_metrics.items()
+ }
+ val_volume_metrics[last_filename] = curr_metrics
+ # Log the center slice of the volume
+ if len(visualize_slices) < self.cfg.logging.tensorboard.num_images: # type: ignore
+ visualize_slices.append(volume[volume.shape[0] // 2])
+ visualize_target.append(target[target.shape[0] // 2])
+
+ # Delete outputs from memory, and recreate dictionary.
+ # This is not needed when not in validation as we are actually interested
+ # in the iteration output.
+ del targets_output[last_filename]
+ del reconstruction_output[last_filename]
+
+ if all_filenames:
+ log_prefix = f"{filenames_seen} of {num_for_this_process} volumes reconstructed:"
+ else:
+ log_prefix = f"{iter_idx + 1} of {len(data_loader)} slices reconstructed:"
+
+ self.logger.info(
+ f"{log_prefix} {last_filename}"
+ f" (shape = {list(volume.shape)}) in {time.time() - time_start:.3f}s."
+ )
+ # restart timer
+ time_start = time.time()
+ last_filename = filename
+
+ # Average loss dict
+ loss_dict = reduce_list_of_dicts(val_losses)
+ reduce_tensor_dict(loss_dict)
+
+ communication.synchronize()
+ torch.cuda.empty_cache()
+
+ all_gathered_metrics = merge_list_of_dicts(communication.all_gather(val_volume_metrics))
+ if not is_validation_process:
+ return loss_dict, reconstruction_output
+
+ return loss_dict, all_gathered_metrics, visualize_slices, visualize_target
+
+ def process_output(self, data, scaling_factors=None, resolution=None):
+ # data is of shape (batch, complex=2, height, width)
+ if scaling_factors is not None:
+ data = data * scaling_factors.view(-1, *((1,) * (len(data.shape) - 1))).to(data.device)
+
+ data = T.modulus_if_complex(data)
+
+ if len(data.shape) == 3: # (batch, height, width)
+ data = data.unsqueeze(1) # Added channel dimension.
+
+ if resolution is not None:
+ data = T.center_crop(data, resolution).contiguous()
+
+ return data
+
+ @staticmethod
+ def compute_resolution(key, reconstruction_size):
+ if key == "header":
+ # This will be of the form [tensor(x_0, x_1, ...), tensor(y_0, y_1,...), tensor(z_0, z_1, ...)] over
+ # batches.
+ resolution = [_.detach().cpu().numpy().tolist() for _ in reconstruction_size]
+ # The volume sampler should give validation indices belonging to the *same* volume, so it should be
+            # safe to take the first element; the matrix sizes are in x,y,z (we work in z,x,y).
+ resolution = [_[0] for _ in resolution][:-1]
+ elif key == "training":
+ resolution = key
+ elif not key:
+ resolution = None
+ else:
+ raise ValueError(
+ "Cropping should be either set to `header` to get the values from the header or "
+ "`training` to take the same value as training."
+ )
+ return resolution
+
+ def cropper(self, source, target, resolution):
+ """
+ 2D source/target cropper
+
+ Parameters:
+ -----------
+ Source has shape (batch, height, width)
+ Target has shape (batch, height, width)
+
+ """
+
+ if not resolution or all(_ == 0 for _ in resolution):
+ return source.unsqueeze(1), target.unsqueeze(1) # Added channel dimension.
+
+ source_abs = T.center_crop(source, resolution).unsqueeze(1) # Added channel dimension.
+ target_abs = T.center_crop(target, resolution).unsqueeze(1) # Added channel dimension.
+
+ return source_abs, target_abs
+
+ def compute_model_per_coil(self, model_name, data):
+ """
+ Computes model per coil.
+ """
+ # data is of shape (batch, coil, complex=2, height, width)
+ output = []
+
+ for idx in range(data.size(self._coil_dim)):
+ subselected_data = data.select(self._coil_dim, idx)
+ output.append(self.models[model_name](subselected_data))
+ output = torch.stack(output, dim=self._coil_dim)
+
+ # output is of shape (batch, coil, complex=2, height, width)
+ return output
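
The sensitivity-map normalization in `_do_iteration` (divide by the root of the coil-summed squared magnitude so that \sum_i S_i S_i^* = 1) is compact enough to be easy to misread. Below is a minimal, stand-alone recreation on random data, with `T.safe_divide` replaced by a plain division since the random maps are strictly positive; it is an illustrative sketch, not the engine code.

```python
import torch

N, coil, H, W = 2, 4, 8, 8
coil_dim, complex_dim = 1, -1

# Random channel-last sensitivity maps, shape (N, coil, height, width, complex=2).
sensitivity_map = torch.rand(N, coil, H, W, 2) + 0.1

# Root of the squared magnitude summed over complex and coil dimensions, shape (N, height, width).
norm = torch.sqrt((sensitivity_map ** 2).sum(complex_dim).sum(coil_dim))
norm = norm.unsqueeze(coil_dim).unsqueeze(complex_dim)  # shape (N, 1, height, width, 1)

normalized = sensitivity_map / norm

# After normalization the coil-combined magnitude is one at every pixel.
check = (normalized ** 2).sum(complex_dim).sum(coil_dim)
assert torch.allclose(check, torch.ones_like(check), atol=1e-5)
```
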
diff --git a/direct/nn/varnet/__init__.py b/direct/nn/varnet/__init__.py
new file mode 100644
index 00000000..941752c9
--- /dev/null
+++ b/direct/nn/varnet/__init__.py
@@ -0,0 +1,2 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
diff --git a/direct/nn/varnet/config.py b/direct/nn/varnet/config.py
new file mode 100644
index 00000000..14cc5e75
--- /dev/null
+++ b/direct/nn/varnet/config.py
@@ -0,0 +1,13 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
+from dataclasses import dataclass
+
+from direct.config.defaults import ModelConfig
+
+
+@dataclass
+class EndToEndVarNetConfig(ModelConfig):
+ num_layers: int = 8
+ regularizer_num_filters: int = 18
+ regularizer_num_pull_layers: int = 4
+ regularizer_dropout: float = 0.0
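
As with the U-Net config above, the fields of `EndToEndVarNetConfig` mirror the `EndToEndVarNet` constructor introduced below. A hand-wired sketch using the dataclass defaults, again bypassing the config machinery:

```python
from direct.data.transforms import fft2, ifft2
from direct.nn.varnet.varnet import EndToEndVarNet

# Field-for-field mirror of the EndToEndVarNetConfig defaults above.
model = EndToEndVarNet(
    forward_operator=fft2,
    backward_operator=ifft2,
    num_layers=8,
    regularizer_num_filters=18,
    regularizer_num_pull_layers=4,
    regularizer_dropout=0.0,
)
```
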
diff --git a/direct/nn/varnet/tests/__init__.py b/direct/nn/varnet/tests/__init__.py
new file mode 100644
index 00000000..941752c9
--- /dev/null
+++ b/direct/nn/varnet/tests/__init__.py
@@ -0,0 +1,2 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
diff --git a/direct/nn/varnet/tests/test_varnet.py b/direct/nn/varnet/tests/test_varnet.py
new file mode 100644
index 00000000..67cbd59f
--- /dev/null
+++ b/direct/nn/varnet/tests/test_varnet.py
@@ -0,0 +1,43 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
+
+import pytest
+import torch
+
+from direct.data.transforms import fft2, ifft2
+from direct.nn.varnet.varnet import EndToEndVarNet
+
+
+def create_input(shape):
+
+ data = torch.rand(shape).float()
+
+ return data
+
+
+@pytest.mark.parametrize(
+ "shape",
+ [[4, 3, 32, 32], [4, 5, 40, 20]],
+)
+@pytest.mark.parametrize(
+ "num_layers",
+ [2, 3, 6],
+)
+@pytest.mark.parametrize(
+ "num_filters",
+ [2, 4],
+)
+@pytest.mark.parametrize(
+ "num_pull_layers",
+ [2, 4],
+)
+def test_varnet(shape, num_layers, num_filters, num_pull_layers):
+ model = EndToEndVarNet(fft2, ifft2, num_layers, num_filters, num_pull_layers, in_channels=2).cpu()
+
+ kspace = create_input(shape + [2]).cpu()
+ mask = create_input([shape[0]] + [1] + shape[2:] + [1]).round().int().cpu()
+ sens = create_input(shape + [2]).cpu()
+
+ out = model(kspace, mask, sens)
+
+ assert list(out.shape) == shape + [2]
diff --git a/direct/nn/varnet/varnet.py b/direct/nn/varnet/varnet.py
new file mode 100644
index 00000000..00ade925
--- /dev/null
+++ b/direct/nn/varnet/varnet.py
@@ -0,0 +1,176 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
+
+from typing import Callable
+
+import torch
+import torch.nn as nn
+
+from direct.data.transforms import expand_operator, reduce_operator
+from direct.nn.unet import UnetModel2d
+
+
+class EndToEndVarNet(nn.Module):
+ """
+ End-to-End Variational Network as in https://arxiv.org/abs/2004.06688.
+ """
+
+ def __init__(
+ self,
+ forward_operator: Callable,
+ backward_operator: Callable,
+ num_layers: int,
+ regularizer_num_filters: int = 18,
+ regularizer_num_pull_layers: int = 4,
+ regularizer_dropout: float = 0.0,
+ in_channels: int = 2,
+ **kwargs,
+ ):
+ """
+ Parameters:
+ -----------
+ forward_operator : Callable
+ Forward Operator.
+ backward_operator : Callable
+ Backward Operator.
+ num_layers : int
+ Number of cascades.
+ regularizer_num_filters : int
+ Regularizer model number of filters.
+ regularizer_num_pull_layers : int
+            Regularizer model number of pooling layers.
+ regularizer_dropout : float
+ Regularizer model dropout probability.
+
+ """
+ super().__init__()
+ extra_keys = kwargs.keys()
+ for extra_key in extra_keys:
+ if extra_key not in [
+ "model_name",
+ ]:
+ raise ValueError(f"{type(self).__name__} got key `{extra_key}` which is not supported.")
+
+ self.layers_list = nn.ModuleList()
+
+ for _ in range(num_layers):
+ self.layers_list.append(
+ EndToEndVarNetBlock(
+ forward_operator=forward_operator,
+ backward_operator=backward_operator,
+ regularizer_model=UnetModel2d(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ num_filters=regularizer_num_filters,
+ num_pool_layers=regularizer_num_pull_layers,
+ dropout_probability=regularizer_dropout,
+ ),
+ )
+ )
+
+ def forward(
+ self, masked_kspace: torch.Tensor, sampling_mask: torch.Tensor, sensitivity_map: torch.Tensor
+ ) -> torch.Tensor:
+ """
+ Parameters
+ ----------
+ masked_kspace : torch.Tensor
+ Masked k-space of shape (N, coil, height, width, complex=2).
+ sampling_mask : torch.Tensor
+ Sampling mask of shape (N, 1, height, width, 1).
+ sensitivity_map : torch.Tensor
+ Sensitivity map of shape (N, coil, height, width, complex=2).
+
+ Returns
+ -------
+ kspace_prediction : torch.Tensor
+ K-space prediction of shape (N, coil, height, width, complex=2).
+ """
+
+ kspace_prediction = masked_kspace.clone()
+ for layer in self.layers_list:
+ kspace_prediction = layer(kspace_prediction, masked_kspace, sampling_mask, sensitivity_map)
+ return kspace_prediction
+
+
+class EndToEndVarNetBlock(nn.Module):
+ """
+ End-to-End Variational Network block.
+ """
+
+ def __init__(
+ self,
+ forward_operator: Callable,
+ backward_operator: Callable,
+ regularizer_model: nn.Module,
+ ):
+ """
+
+ Parameters:
+ -----------
+ forward_operator : Callable
+ Forward Operator.
+ backward_operator : Callable
+ Backward Operator.
+ regularizer_model : nn.Module
+ Regularizer model.
+ """
+ super().__init__()
+ self.regularizer_model = regularizer_model
+ self.forward_operator = forward_operator
+ self.backward_operator = backward_operator
+ self.learning_rate = nn.Parameter(torch.tensor([1.0]))
+ self._coil_dim = 1
+ self._complex_dim = -1
+ self._spatial_dims = (2, 3)
+
+ def forward(
+ self,
+ current_kspace: torch.Tensor,
+ masked_kspace: torch.Tensor,
+ sampling_mask: torch.Tensor,
+ sensitivity_map: torch.Tensor,
+ ) -> torch.Tensor:
+ """
+
+ Parameters
+ ----------
+ current_kspace : torch.Tensor
+ Current k-space prediction of shape (N, coil, height, width, complex=2).
+ masked_kspace : torch.Tensor
+ Masked k-space of shape (N, coil, height, width, complex=2).
+ sampling_mask : torch.Tensor
+ Sampling mask of shape (N, 1, height, width, 1).
+ sensitivity_map : torch.Tensor
+ Sensitivity map of shape (N, coil, height, width, complex=2).
+
+ Returns
+ -------
+ torch.Tensor
+ Next k-space prediction of shape (N, coil, height, width, complex=2).
+ """
+ kspace_error = torch.where(
+ sampling_mask == 0,
+ torch.tensor([0.0], dtype=masked_kspace.dtype).to(masked_kspace.device),
+ current_kspace - masked_kspace,
+ )
+ regularization_term = torch.cat(
+ [
+ reduce_operator(
+ self.backward_operator(kspace, dim=self._spatial_dims), sensitivity_map, dim=self._coil_dim
+ )
+ for kspace in torch.split(current_kspace, 2, self._complex_dim)
+ ],
+ dim=self._complex_dim,
+ ).permute(0, 3, 1, 2)
+ regularization_term = self.regularizer_model(regularization_term).permute(0, 2, 3, 1)
+ regularization_term = torch.cat(
+ [
+ self.forward_operator(
+ expand_operator(image, sensitivity_map, dim=self._coil_dim), dim=self._spatial_dims
+ )
+ for image in torch.split(regularization_term, 2, self._complex_dim)
+ ],
+ dim=self._complex_dim,
+ )
+ return current_kspace - self.learning_rate * kspace_error + regularization_term
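
Written out as a formula, the update performed by `EndToEndVarNetBlock.forward` is the following (notation mine: k_t is the current k-space estimate, \tilde{y} the measured masked k-space, M the sampling mask, \eta the learnable `learning_rate`, \mathcal{F} the forward operator, \mathcal{E}/\mathcal{R} the expand/reduce coil operators and \mathrm{R}_\theta the U-Net regularizer):

```latex
k_{t+1} \;=\; k_t \;-\; \eta\, M \odot \bigl(k_t - \tilde{y}\bigr)
        \;+\; \mathcal{F}\,\mathcal{E}\Bigl(\mathrm{R}_\theta\bigl(\mathcal{R}\bigl(\mathcal{F}^{-1} k_t\bigr)\bigr)\Bigr)
```
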
diff --git a/direct/nn/varnet/varnet_engine.py b/direct/nn/varnet/varnet_engine.py
new file mode 100644
index 00000000..2323f394
--- /dev/null
+++ b/direct/nn/varnet/varnet_engine.py
@@ -0,0 +1,473 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
+
+import time
+from collections import defaultdict
+from os import PathLike
+from typing import Callable, DefaultDict, Dict, List, Optional
+
+import numpy as np
+import torch
+from torch import nn
+from torch.cuda.amp import autocast
+from torch.nn import functional as F
+from torch.utils.data import DataLoader
+
+import direct.data.transforms as T
+from direct.config import BaseConfig
+from direct.engine import DoIterationOutput, Engine
+from direct.functionals import SSIMLoss
+from direct.utils import (
+ communication,
+ detach_dict,
+ dict_to_device,
+ merge_list_of_dicts,
+ multiply_function,
+ reduce_list_of_dicts,
+)
+from direct.utils.communication import reduce_tensor_dict
+
+
+class EndToEndVarNetEngine(Engine):
+ """
+ End-to-End Variational Network Engine.
+ """
+
+ def __init__(
+ self,
+ cfg: BaseConfig,
+ model: nn.Module,
+ device: int,
+ forward_operator: Optional[Callable] = None,
+ backward_operator: Optional[Callable] = None,
+ mixed_precision: bool = False,
+ **models: nn.Module,
+ ):
+ super().__init__(
+ cfg,
+ model,
+ device,
+ forward_operator=forward_operator,
+ backward_operator=backward_operator,
+ mixed_precision=mixed_precision,
+ **models,
+ )
+
+ self._complex_dim = -1
+ self._coil_dim = 1
+ self._spatial_dims = (2, 3)
+
+ def _do_iteration(
+ self,
+ data: Dict[str, torch.Tensor],
+ loss_fns: Optional[Dict[str, Callable]] = None,
+ regularizer_fns: Optional[Dict[str, Callable]] = None,
+ ) -> DoIterationOutput:
+
+ # loss_fns can be done, e.g. during validation
+ if loss_fns is None:
+ loss_fns = {}
+
+ if regularizer_fns is None:
+ regularizer_fns = {}
+
+ loss_dicts = []
+ regularizer_dicts = []
+
+ data = dict_to_device(data, self.device)
+
+ # sensitivity_map of shape (batch, coil, height, width, complex=2)
+ sensitivity_map = data["sensitivity_map"]
+
+ if "sensitivity_model" in self.models:
+
+ # Move channels to first axis
+ sensitivity_map = data["sensitivity_map"].permute(
+ (0, 1, 4, 2, 3)
+ ) # shape (batch, coil, complex=2, height, width)
+
+ sensitivity_map = self.compute_model_per_coil("sensitivity_model", sensitivity_map).permute(
+ (0, 1, 3, 4, 2)
+ ) # has channel last: shape (batch, coil, height, width, complex=2)
+
+ # The sensitivity map needs to be normalized such that
+ # So \sum_{i \in \text{coils}} S_i S_i^* = 1
+
+ sensitivity_map_norm = torch.sqrt(
+ ((sensitivity_map ** 2).sum(self._complex_dim)).sum(self._coil_dim)
+ ) # shape (batch, height, width)
+ sensitivity_map_norm = sensitivity_map_norm.unsqueeze(1).unsqueeze(-1)
+ data["sensitivity_map"] = T.safe_divide(sensitivity_map, sensitivity_map_norm)
+
+ with autocast(enabled=self.mixed_precision):
+
+ output_kspace = self.model(
+ masked_kspace=data["masked_kspace"],
+ sampling_mask=data["sampling_mask"],
+ sensitivity_map=data["sensitivity_map"],
+ )
+
+ output_image = T.root_sum_of_squares(
+ self.backward_operator(output_kspace, dim=self._spatial_dims), dim=self._coil_dim
+ ) # shape (batch, height, width)
+
+ loss_dict = {k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in loss_fns.keys()}
+ regularizer_dict = {
+ k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in regularizer_fns.keys()
+ }
+
+ for key, value in loss_dict.items():
+ loss_dict[key] = value + loss_fns[key](
+ output_image,
+ **data,
+ reduction="mean",
+ )
+
+ for key, value in regularizer_dict.items():
+ regularizer_dict[key] = value + regularizer_fns[key](
+ output_image,
+ **data,
+ )
+
+ loss = sum(loss_dict.values()) + sum(regularizer_dict.values())
+
+ if self.model.training:
+ self._scaler.scale(loss).backward()
+
+ loss_dicts.append(detach_dict(loss_dict))
+ regularizer_dicts.append(
+ detach_dict(regularizer_dict)
+ ) # Need to detach dict as this is only used for logging.
+
+ # Add the loss dicts.
+ loss_dict = reduce_list_of_dicts(loss_dicts, mode="sum")
+ regularizer_dict = reduce_list_of_dicts(regularizer_dicts, mode="sum")
+
+ return DoIterationOutput(
+ output_image=output_image,
+ sensitivity_map=data["sensitivity_map"],
+ data_dict={**loss_dict, **regularizer_dict},
+ )
+
+ def build_loss(self, **kwargs) -> Dict:
+ # TODO: Cropper is a processing output tool.
+ def get_resolution(**data):
+ """Be careful that this will use the cropping size of the FIRST sample in the batch."""
+ return self.compute_resolution(self.cfg.training.loss.crop, data.get("reconstruction_size", None))
+
+ def l1_loss(source, reduction="mean", **data):
+ """
+ Calculate L1 loss given source and target.
+
+ Parameters:
+ -----------
+ Source: shape (batch, complex=2, height, width)
+ Data: Contains key "target" with value a tensor of shape (batch, height, width)
+
+ """
+ resolution = get_resolution(**data)
+ l1_loss = F.l1_loss(*self.cropper(source, data["target"], resolution), reduction=reduction)
+
+ return l1_loss
+
+ def l2_loss(source, reduction="mean", **data):
+ """
+ Calculate L2 loss (MSE) given source and target.
+
+ Parameters:
+ -----------
+ Source: shape (batch, complex=2, height, width)
+ Data: Contains key "target" with value a tensor of shape (batch, height, width)
+
+ """
+ resolution = get_resolution(**data)
+ l2_loss = F.mse_loss(*self.cropper(source, data["target"], resolution), reduction=reduction)
+
+ return l2_loss
+
+ def ssim_loss(source, reduction="mean", **data):
+ """
+ Calculate SSIM loss given source and target.
+
+ Parameters:
+ -----------
+ Source: shape (batch, complex=2, height, width)
+ Data: Contains key "target" with value a tensor of shape (batch, height, width)
+
+ """
+ resolution = get_resolution(**data)
+ if reduction != "mean":
+ raise AssertionError(
+ f"SSIM loss can only be computed with reduction == 'mean'." f" Got reduction == {reduction}."
+ )
+
+ source_abs, target_abs = self.cropper(source, data["target"], resolution)
+ data_range = torch.tensor([target_abs.max()], device=target_abs.device)
+
+ ssim_loss = SSIMLoss().to(source_abs.device).forward(source_abs, target_abs, data_range=data_range)
+
+ return ssim_loss
+
+ # Build losses
+ loss_dict = {}
+ for curr_loss in self.cfg.training.loss.losses: # type: ignore
+ loss_fn = curr_loss.function
+ if loss_fn == "l1_loss":
+ loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, l1_loss)
+ elif loss_fn == "l2_loss":
+ loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, l2_loss)
+ elif loss_fn == "ssim_loss":
+ loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, ssim_loss)
+ else:
+ raise ValueError(f"{loss_fn} not permissible.")
+
+ return loss_dict
+
+ @torch.no_grad()
+ def evaluate(
+ self,
+ data_loader: DataLoader,
+ loss_fns: Optional[Dict[str, Callable]],
+ regularizer_fns: Optional[Dict[str, Callable]] = None,
+ crop: Optional[str] = None,
+ is_validation_process: bool = True,
+ ):
+ """
+ Validation process. Assumes that each batch only contains slices of the same volume *AND* that these
+ are sequentially ordered.
+
+ Parameters
+ ----------
+ data_loader : DataLoader
+ loss_fns : Dict[str, Callable], optional
+ regularizer_fns : Dict[str, Callable], optional
+ crop : str, optional
+ is_validation_process : bool
+
+ Returns
+ -------
+ loss_dict, all_gathered_metrics, visualize_slices, visualize_target
+
+ """
+ self.models_to_device()
+ self.models_validation_mode()
+ torch.cuda.empty_cache()
+
+ # Variables required for evaluation.
+ volume_metrics = self.build_metrics(self.cfg.validation.metrics) # type: ignore
+
+ # filenames can be in the volume_indices attribute of the dataset
+ num_for_this_process = None
+ all_filenames = None
+ if hasattr(data_loader.dataset, "volume_indices"):
+ all_filenames = list(data_loader.dataset.volume_indices.keys())
+ num_for_this_process = len(list(data_loader.batch_sampler.sampler.volume_indices.keys()))
+ self.logger.info(
+ f"Reconstructing a total of {len(all_filenames)} volumes. "
+ f"This process has {num_for_this_process} volumes (world size: {communication.get_world_size()})."
+ )
+
+ filenames_seen = 0
+ reconstruction_output: DefaultDict = defaultdict(list)
+ if is_validation_process:
+ targets_output: DefaultDict = defaultdict(list)
+ val_losses = []
+ val_volume_metrics: Dict[PathLike, Dict] = defaultdict(dict)
+ last_filename = None
+
+        # Container for the slices which can be visualized in TensorBoard.
+ visualize_slices: List[np.ndarray] = []
+ visualize_target: List[np.ndarray] = []
+ # visualizations = {}
+
+ extra_visualization_keys = (
+ self.cfg.logging.log_as_image if self.cfg.logging.log_as_image else [] # type: ignore
+ )
+
+ # Loop over dataset. This requires the use of direct.data.sampler.DistributedSequentialSampler as this sampler
+ # splits the data over the different processes, and outputs the slices linearly. The implicit assumption here is
+ # that the slices are outputted from the Dataset *sequentially* for each volume one by one, and each batch only
+ # contains data from one volume.
+ time_start = time.time()
+
+ for iter_idx, data in enumerate(data_loader):
+ filenames = data.pop("filename")
+ if len(set(filenames)) != 1:
+ raise ValueError(
+ f"Expected a batch during validation to only contain filenames of one case. "
+ f"Got {set(filenames)}."
+ )
+
+ slice_nos = data.pop("slice_no")
+ scaling_factors = data["scaling_factor"]
+
+ resolution = self.compute_resolution(
+ key=self.cfg.validation.crop, # type: ignore
+ reconstruction_size=data.get("reconstruction_size", None),
+ )
+
+ # Compute output and loss.
+ iteration_output = self._do_iteration(data, loss_fns, regularizer_fns=regularizer_fns)
+ output = iteration_output.output_image
+ loss_dict = iteration_output.data_dict
+
+ loss_dict = detach_dict(loss_dict)
+ output = output.detach()
+ val_losses.append(loss_dict)
+
+ # Output is complex-valued, and has to be cropped. This holds for both output and target.
+ # Output has shape (batch, complex, height, width)
+ output_abs = self.process_output(
+ output,
+ scaling_factors,
+ resolution=resolution,
+ )
+
+ if is_validation_process:
+ # Target has shape (batch, height, width)
+ target_abs = self.process_output(
+ data["target"].detach(),
+ scaling_factors,
+ resolution=resolution,
+ )
+ for key in extra_visualization_keys:
+ curr_data = data[key].detach()
+ # Here we need to discover which keys are actually normalized or not
+ # this requires a solution to issue #23: https://github.com/directgroup/direct/issues/23
+
+ del output # Explicitly call delete to clear memory.
+
+ # Aggregate volumes to be able to compute the metrics on complete volumes.
+ for idx, filename in enumerate(filenames):
+ if last_filename is None:
+ last_filename = filename # First iteration last_filename is not set.
+
+ curr_slice = output_abs[idx].detach()
+ slice_no = int(slice_nos[idx].numpy())
+
+ reconstruction_output[filename].append((slice_no, curr_slice.cpu()))
+
+ if is_validation_process:
+ targets_output[filename].append((slice_no, target_abs[idx].cpu()))
+
+ is_last_element_of_last_batch = iter_idx + 1 == len(data_loader) and idx + 1 == len(data["target"])
+ reconstruction_conditions = [filename != last_filename, is_last_element_of_last_batch]
+ for condition in reconstruction_conditions:
+ if condition:
+ filenames_seen += 1
+
+ # Now we can ditch the reconstruction dict by reconstructing the volume,
+ # will take too much memory otherwise.
+ volume = torch.stack([_[1] for _ in reconstruction_output[last_filename]])
+ if is_validation_process:
+ target = torch.stack([_[1] for _ in targets_output[last_filename]])
+ curr_metrics = {
+ metric_name: metric_fn(target, volume)
+ for metric_name, metric_fn in volume_metrics.items()
+ }
+ val_volume_metrics[last_filename] = curr_metrics
+ # Log the center slice of the volume
+ if len(visualize_slices) < self.cfg.logging.tensorboard.num_images: # type: ignore
+ visualize_slices.append(volume[volume.shape[0] // 2])
+ visualize_target.append(target[target.shape[0] // 2])
+
+ # Delete outputs from memory, and recreate dictionary.
+ # This is not needed when not in validation as we are actually interested
+ # in the iteration output.
+ del targets_output[last_filename]
+ del reconstruction_output[last_filename]
+
+ if all_filenames:
+ log_prefix = f"{filenames_seen} of {num_for_this_process} volumes reconstructed:"
+ else:
+ log_prefix = f"{iter_idx + 1} of {len(data_loader)} slices reconstructed:"
+
+ self.logger.info(
+ f"{log_prefix} {last_filename}"
+ f" (shape = {list(volume.shape)}) in {time.time() - time_start:.3f}s."
+ )
+ # restart timer
+ time_start = time.time()
+ last_filename = filename
+
+ # Average loss dict
+ loss_dict = reduce_list_of_dicts(val_losses)
+ reduce_tensor_dict(loss_dict)
+
+ communication.synchronize()
+ torch.cuda.empty_cache()
+
+ all_gathered_metrics = merge_list_of_dicts(communication.all_gather(val_volume_metrics))
+ if not is_validation_process:
+ return loss_dict, reconstruction_output
+
+ return loss_dict, all_gathered_metrics, visualize_slices, visualize_target
+
+ def process_output(self, data, scaling_factors=None, resolution=None):
+ # data is of shape (batch, complex=2, height, width)
+ if scaling_factors is not None:
+ data = data * scaling_factors.view(-1, *((1,) * (len(data.shape) - 1))).to(data.device)
+
+ data = T.modulus_if_complex(data)
+
+ if len(data.shape) == 3: # (batch, height, width)
+ data = data.unsqueeze(1) # Added channel dimension.
+
+ if resolution is not None:
+ data = T.center_crop(data, resolution).contiguous()
+
+ return data
+
+ @staticmethod
+ def compute_resolution(key, reconstruction_size):
+ if key == "header":
+ # This will be of the form [tensor(x_0, x_1, ...), tensor(y_0, y_1,...), tensor(z_0, z_1, ...)] over
+ # batches.
+ resolution = [_.detach().cpu().numpy().tolist() for _ in reconstruction_size]
+ # The volume sampler should give validation indices belonging to the *same* volume, so it should be
+            # safe to take the first element; the matrix sizes are in x,y,z (we work in z,x,y).
+ resolution = [_[0] for _ in resolution][:-1]
+ elif key == "training":
+ resolution = key
+ elif not key:
+ resolution = None
+ else:
+ raise ValueError(
+ "Cropping should be either set to `header` to get the values from the header or "
+ "`training` to take the same value as training."
+ )
+ return resolution
+
+ def cropper(self, source, target, resolution):
+ """
+ 2D source/target cropper
+
+ Parameters:
+ -----------
+ Source has shape (batch, height, width)
+ Target has shape (batch, height, width)
+
+ """
+
+ if not resolution or all(_ == 0 for _ in resolution):
+ return source.unsqueeze(1), target.unsqueeze(1) # Added channel dimension.
+
+ source_abs = T.center_crop(source, resolution).unsqueeze(1) # Added channel dimension.
+ target_abs = T.center_crop(target, resolution).unsqueeze(1) # Added channel dimension.
+
+ return source_abs, target_abs
+
+ def compute_model_per_coil(self, model_name, data):
+ """
+ Computes model per coil.
+ """
+ # data is of shape (batch, coil, complex=2, height, width)
+ output = []
+
+ for idx in range(data.size(self._coil_dim)):
+ subselected_data = data.select(self._coil_dim, idx)
+ output.append(self.models[model_name](subselected_data))
+ output = torch.stack(output, dim=self._coil_dim)
+
+ # output is of shape (batch, coil, complex=2, height, width)
+ return output
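
For reference while reviewing the engine's output path: the coil combination in `_do_iteration` (`T.root_sum_of_squares` of the backward-transformed k-space over the coil dimension) amounts to the computation below on a channel-last complex tensor. This is a plain-PyTorch sketch of the definition with made-up sizes, not DIRECT's implementation.

```python
import torch

N, coil, H, W = 2, 4, 8, 8

# Per-coil complex images, channel-last, shape (N, coil, height, width, complex=2).
coil_images = torch.rand(N, coil, H, W, 2)

# Root-sum-of-squares over coils: sqrt(sum_i |x_i|^2), output shape (N, height, width).
rss = torch.sqrt((coil_images ** 2).sum(dim=-1).sum(dim=1))

assert list(rss.shape) == [N, H, W]
```
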
diff --git a/direct/nn/xpdnet/__init__.py b/direct/nn/xpdnet/__init__.py
new file mode 100644
index 00000000..941752c9
--- /dev/null
+++ b/direct/nn/xpdnet/__init__.py
@@ -0,0 +1,2 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
diff --git a/direct/nn/xpdnet/config.py b/direct/nn/xpdnet/config.py
new file mode 100644
index 00000000..0bf070a3
--- /dev/null
+++ b/direct/nn/xpdnet/config.py
@@ -0,0 +1,26 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
+
+from dataclasses import dataclass
+
+from direct.config.defaults import ModelConfig
+
+
+@dataclass
+class XPDNetConfig(ModelConfig):
+ num_primal: int = 5
+ num_dual: int = 1
+ num_iter: int = 10
+ use_primal_only: bool = True
+ kspace_model_architecture: str = "CONV"
+ dual_conv_hidden_channels: int = 16
+ dual_conv_n_convs: int = 4
+ dual_conv_batchnorm: bool = False
+ dual_didn_hidden_channels: int = 64
+ dual_didn_num_dubs: int = 6
+ dual_didn_num_convs_recon: int = 9
+ mwcnn_hidden_channels: int = 16
+ mwcnn_num_scales: int = 4
+ mwcnn_bias: bool = True
+ mwcnn_batchnorm: bool = False
+ normalize: bool = False
diff --git a/direct/nn/xpdnet/tests/__init__.py b/direct/nn/xpdnet/tests/__init__.py
new file mode 100644
index 00000000..941752c9
--- /dev/null
+++ b/direct/nn/xpdnet/tests/__init__.py
@@ -0,0 +1,2 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
diff --git a/direct/nn/xpdnet/tests/test_xpdnet.py b/direct/nn/xpdnet/tests/test_xpdnet.py
new file mode 100644
index 00000000..7fb1aaf7
--- /dev/null
+++ b/direct/nn/xpdnet/tests/test_xpdnet.py
@@ -0,0 +1,76 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
+
+import pytest
+import torch
+
+from direct.data.transforms import fft2, ifft2
+from direct.nn.xpdnet.xpdnet import XPDNet
+
+
+def create_input(shape):
+
+ data = torch.rand(shape).float()
+
+ return data
+
+
+@pytest.mark.parametrize(
+ "shape",
+ [
+ [3, 3, 32, 32],
+ ],
+)
+@pytest.mark.parametrize(
+ "num_iter",
+ [2, 3],
+)
+@pytest.mark.parametrize(
+ "num_primal",
+ [2, 3],
+)
+@pytest.mark.parametrize(
+ "image_model_architecture",
+ ["MWCNN"],
+)
+@pytest.mark.parametrize(
+ "primal_only, kspace_model_architecture, num_dual",
+ [
+ [True, None, 1],
+ [False, "CONV", 3],
+ [False, "DIDN", 2],
+ ],
+)
+@pytest.mark.parametrize(
+ "normalize",
+ [True, False],
+)
+def test_xpdnet(
+ shape,
+ num_iter,
+ num_primal,
+ num_dual,
+ image_model_architecture,
+ kspace_model_architecture,
+ primal_only,
+ normalize,
+):
+ model = XPDNet(
+ fft2,
+ ifft2,
+ num_iter=num_iter,
+ num_primal=num_primal,
+ num_dual=num_dual,
+ image_model_architecture=image_model_architecture,
+ kspace_model_architecture=kspace_model_architecture,
+ use_primal_only=primal_only,
+ normalize=normalize,
+ ).cpu()
+
+ kspace = create_input(shape + [2]).cpu()
+ sens = create_input(shape + [2]).cpu()
+ mask = create_input([shape[0]] + [1] + shape[2:] + [1]).round().int().cpu()
+
+ out = model(kspace, mask, sens)
+
+ assert list(out.shape) == [shape[0]] + shape[2:] + [2]
diff --git a/direct/nn/xpdnet/xpdnet.py b/direct/nn/xpdnet/xpdnet.py
new file mode 100644
index 00000000..a81987df
--- /dev/null
+++ b/direct/nn/xpdnet/xpdnet.py
@@ -0,0 +1,127 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
+
+from typing import Callable, Optional
+
+import torch.nn as nn
+
+from direct.nn.crossdomain.crossdomain import CrossDomainNetwork
+from direct.nn.crossdomain.multicoil import MultiCoil
+from direct.nn.conv.conv import Conv2d
+from direct.nn.didn.didn import DIDN
+from direct.nn.mwcnn.mwcnn import MWCNN
+
+
+class XPDNet(CrossDomainNetwork):
+ """
+ XPDNet as implemented in https://arxiv.org/abs/2010.07290.
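+
+    A minimal usage sketch (illustrative only; in practice the model is built from an
+    XPDNetConfig by the training pipeline), with fft2/ifft2 from direct.data.transforms:
+
+        model = XPDNet(fft2, ifft2, num_iter=10, num_primal=5, use_primal_only=True)
+        # masked_kspace and sensitivity_map have shape (batch, coil, height, width, complex=2),
+        # sampling_mask has shape (batch, 1, height, width, 1).
+        image = model(masked_kspace, sampling_mask, sensitivity_map)  # (batch, height, width, complex=2)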
+ """
+
+ def __init__(
+ self,
+ forward_operator: Callable,
+ backward_operator: Callable,
+ num_primal: int = 5,
+ num_dual: int = 1,
+ num_iter: int = 10,
+ use_primal_only: bool = True,
+ image_model_architecture: str = "MWCNN",
+ kspace_model_architecture: Optional[str] = None,
+ normalize: bool = False,
+ **kwargs,
+ ):
+ """
+
+ Parameters
+ ----------
+ forward_operator : Callable
+ Forward Operator.
+ backward_operator : Callable
+ Backward Operator.
+ num_primal : int
+ Number of primal networks.
+ num_dual : int
+ Number of dual networks.
+ num_iter : int
+ Number of unrolled iterations.
+ use_primal_only : bool
+            If set to True, no dual-kspace model is used. Default: True.
+ image_model_architecture : str
+ Primal-image model architecture. Currently only implemented for MWCNN. Default: 'MWCNN'.
+        kspace_model_architecture : str, optional
+ Dual-kspace model architecture. Currently only implemented for CONV and DIDN.
+ normalize : bool
+ Normalize input. Default: False.
+ kwargs : dict
+ Keyword arguments for model architectures.
+ """
+ if use_primal_only:
+ kspace_model_list = None
+ num_dual = 1
+ elif kspace_model_architecture == "CONV":
+ kspace_model_list = nn.ModuleList(
+ [
+ MultiCoil(
+ Conv2d(
+ 2 * (num_dual + num_primal + 1),
+ 2 * num_dual,
+ kwargs.get("dual_conv_hidden_channels", 16),
+ kwargs.get("dual_conv_n_convs", 4),
+ batchnorm=kwargs.get("dual_conv_batchnorm", False),
+ )
+ )
+ for _ in range(num_iter)
+ ]
+ )
+ elif kspace_model_architecture == "DIDN":
+ kspace_model_list = nn.ModuleList(
+ [
+ MultiCoil(
+ DIDN(
+ in_channels=2 * (num_dual + num_primal + 1),
+ out_channels=2 * num_dual,
+ hidden_channels=kwargs.get("dual_didn_hidden_channels", 16),
+ num_dubs=kwargs.get("dual_didn_num_dubs", 6),
+ num_convs_recon=kwargs.get("dual_didn_num_convs_recon", 9),
+ )
+ )
+ for _ in range(num_iter)
+ ]
+ )
+
+ else:
+ raise NotImplementedError(
+ f"XPDNet is currently implemented for kspace_model_architecture == 'CONV' or 'DIDN'."
+ f"Got kspace_model_architecture == {kspace_model_architecture}."
+ )
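+        # Build one primal (image) correction network per unrolled iteration.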
+ if image_model_architecture == "MWCNN":
+ image_model_list = nn.ModuleList(
+ [
+ nn.Sequential(
+ MWCNN(
+ input_channels=2 * (num_primal + num_dual),
+ first_conv_hidden_channels=kwargs.get("mwcnn_hidden_channels", 32),
+ num_scales=kwargs.get("mwcnn_num_scales", 4),
+ bias=kwargs.get("mwcnn_bias", False),
+ batchnorm=kwargs.get("mwcnn_batchnorm", False),
+ ),
+ nn.Conv2d(2 * (num_primal + num_dual), 2 * num_primal, kernel_size=3, padding=1),
+ )
+ for _ in range(num_iter)
+ ]
+ )
+ else:
+ raise NotImplementedError(
+ f"XPDNet is currently implemented only with image_model_architecture == 'MWCNN'."
+ f"Got {image_model_architecture}."
+ )
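+        # The cross-domain network alternates one k-space (K) and one image (I) correction per unrolled iteration.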
+ super().__init__(
+ forward_operator=forward_operator,
+ backward_operator=backward_operator,
+ image_model_list=image_model_list,
+ kspace_model_list=kspace_model_list,
+ domain_sequence="KI" * num_iter,
+ image_buffer_size=num_primal,
+ kspace_buffer_size=num_dual,
+ normalize_image=normalize,
+ )
diff --git a/direct/nn/xpdnet/xpdnet_engine.py b/direct/nn/xpdnet/xpdnet_engine.py
new file mode 100644
index 00000000..5bae217b
--- /dev/null
+++ b/direct/nn/xpdnet/xpdnet_engine.py
@@ -0,0 +1,472 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
+
+import time
+from collections import defaultdict
+from os import PathLike
+from typing import Callable, DefaultDict, Dict, List, Optional
+
+import numpy as np
+import torch
+from torch import nn
+from torch.cuda.amp import autocast
+from torch.nn import functional as F
+from torch.utils.data import DataLoader
+
+import direct.data.transforms as T
+from direct.config import BaseConfig
+from direct.engine import DoIterationOutput, Engine
+from direct.functionals import SSIMLoss
+from direct.utils import (
+ communication,
+ detach_dict,
+ dict_to_device,
+ merge_list_of_dicts,
+ multiply_function,
+ reduce_list_of_dicts,
+)
+from direct.utils.communication import reduce_tensor_dict
+
+
+class XPDNetEngine(Engine):
+ """
+ XPDNet Engine.
+ """
+
+ def __init__(
+ self,
+ cfg: BaseConfig,
+ model: nn.Module,
+ device: int,
+ forward_operator: Optional[Callable] = None,
+ backward_operator: Optional[Callable] = None,
+ mixed_precision: bool = False,
+ **models: nn.Module,
+ ):
+ super().__init__(
+ cfg,
+ model,
+ device,
+ forward_operator=forward_operator,
+ backward_operator=backward_operator,
+ mixed_precision=mixed_precision,
+ **models,
+ )
+
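+        # Data layout: (batch, coil, height, width, complex=2); the coil axis is dim 1,
+        # the complex axis is last, and the spatial axes are dims 2 and 3.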
+ self._complex_dim = -1
+ self._coil_dim = 1
+ self._spatial_dims = (2, 3)
+
+ def _do_iteration(
+ self,
+ data: Dict[str, torch.Tensor],
+ loss_fns: Optional[Dict[str, Callable]] = None,
+ regularizer_fns: Optional[Dict[str, Callable]] = None,
+ ) -> DoIterationOutput:
+
+        # loss_fns can be None, e.g. during validation.
+ if loss_fns is None:
+ loss_fns = {}
+
+ if regularizer_fns is None:
+ regularizer_fns = {}
+
+ loss_dicts = []
+ regularizer_dicts = []
+
+ data = dict_to_device(data, self.device)
+
+ # sensitivity_map of shape (batch, coil, height, width, complex=2)
+ sensitivity_map = data["sensitivity_map"]
+
+ if "sensitivity_model" in self.models:
+
+ # Move channels to first axis
+ sensitivity_map = data["sensitivity_map"].permute(
+ (0, 1, 4, 2, 3)
+ ) # shape (batch, coil, complex=2, height, width)
+
+ sensitivity_map = self.compute_model_per_coil("sensitivity_model", sensitivity_map).permute(
+ (0, 1, 3, 4, 2)
+ ) # has channel last: shape (batch, coil, height, width, complex=2)
+
+ # The sensitivity map needs to be normalized such that
+            # \sum_{i \in \text{coils}} S_i S_i^* = 1.
+
+ sensitivity_map_norm = torch.sqrt(
+ ((sensitivity_map ** 2).sum(self._complex_dim)).sum(self._coil_dim)
+ ) # shape (batch, height, width)
+ sensitivity_map_norm = sensitivity_map_norm.unsqueeze(1).unsqueeze(-1)
+ data["sensitivity_map"] = T.safe_divide(sensitivity_map, sensitivity_map_norm)
+
+ with autocast(enabled=self.mixed_precision):
+
+ output_image = self.model(
+ masked_kspace=data["masked_kspace"],
+ sampling_mask=data["sampling_mask"],
+ sensitivity_map=data["sensitivity_map"],
+ scaling_factor=data["scaling_factor"],
+ ) # shape (batch, height, width, complex=2)
+
+ output_image = T.modulus(output_image) # shape (batch, height, width)
+
+ loss_dict = {k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in loss_fns.keys()}
+ regularizer_dict = {
+ k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in regularizer_fns.keys()
+ }
+
+ for key, value in loss_dict.items():
+ loss_dict[key] = value + loss_fns[key](
+ output_image,
+ **data,
+ reduction="mean",
+ )
+
+ for key, value in regularizer_dict.items():
+ regularizer_dict[key] = value + regularizer_fns[key](
+ output_image,
+ **data,
+ )
+
+ loss = sum(loss_dict.values()) + sum(regularizer_dict.values())
+
+ if self.model.training:
+ self._scaler.scale(loss).backward()
+
+ loss_dicts.append(detach_dict(loss_dict))
+ regularizer_dicts.append(
+ detach_dict(regularizer_dict)
+ ) # Need to detach dict as this is only used for logging.
+
+ # Add the loss dicts.
+ loss_dict = reduce_list_of_dicts(loss_dicts, mode="sum")
+ regularizer_dict = reduce_list_of_dicts(regularizer_dicts, mode="sum")
+
+ return DoIterationOutput(
+ output_image=output_image,
+ sensitivity_map=data["sensitivity_map"],
+ data_dict={**loss_dict, **regularizer_dict},
+ )
+
+ def build_loss(self, **kwargs) -> Dict:
+ # TODO: Cropper is a processing output tool.
+ def get_resolution(**data):
+ """Be careful that this will use the cropping size of the FIRST sample in the batch."""
+ return self.compute_resolution(self.cfg.training.loss.crop, data.get("reconstruction_size", None))
+
+ def l1_loss(source, reduction="mean", **data):
+ """
+ Calculate L1 loss given source and target.
+
+            Parameters
+            ----------
+            source : torch.Tensor
+                Has shape (batch, height, width).
+            data : dict
+                Contains key "target" with value a tensor of shape (batch, height, width).
+
+ """
+ resolution = get_resolution(**data)
+ l1_loss = F.l1_loss(*self.cropper(source, data["target"], resolution), reduction=reduction)
+
+ return l1_loss
+
+ def l2_loss(source, reduction="mean", **data):
+ """
+ Calculate L2 loss (MSE) given source and target.
+
+            Parameters
+            ----------
+            source : torch.Tensor
+                Has shape (batch, height, width).
+            data : dict
+                Contains key "target" with value a tensor of shape (batch, height, width).
+
+ """
+ resolution = get_resolution(**data)
+ l2_loss = F.mse_loss(*self.cropper(source, data["target"], resolution), reduction=reduction)
+
+ return l2_loss
+
+ def ssim_loss(source, reduction="mean", **data):
+ """
+ Calculate SSIM loss given source and target.
+
+            Parameters
+            ----------
+            source : torch.Tensor
+                Has shape (batch, height, width).
+            data : dict
+                Contains key "target" with value a tensor of shape (batch, height, width).
+
+ """
+ resolution = get_resolution(**data)
+ if reduction != "mean":
+ raise AssertionError(
+ f"SSIM loss can only be computed with reduction == 'mean'." f" Got reduction == {reduction}."
+ )
+
+ source_abs, target_abs = self.cropper(source, data["target"], resolution)
+ data_range = torch.tensor([target_abs.max()], device=target_abs.device)
+
+ ssim_loss = SSIMLoss().to(source_abs.device).forward(source_abs, target_abs, data_range=data_range)
+
+ return ssim_loss
+
+ # Build losses
+ loss_dict = {}
+ for curr_loss in self.cfg.training.loss.losses: # type: ignore
+ loss_fn = curr_loss.function
+ if loss_fn == "l1_loss":
+ loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, l1_loss)
+ elif loss_fn == "l2_loss":
+ loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, l2_loss)
+ elif loss_fn == "ssim_loss":
+ loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, ssim_loss)
+ else:
+ raise ValueError(f"{loss_fn} not permissible.")
+
+ return loss_dict
+
+ @torch.no_grad()
+ def evaluate(
+ self,
+ data_loader: DataLoader,
+ loss_fns: Optional[Dict[str, Callable]],
+ regularizer_fns: Optional[Dict[str, Callable]] = None,
+ crop: Optional[str] = None,
+ is_validation_process: bool = True,
+ ):
+ """
+ Validation process. Assumes that each batch only contains slices of the same volume *AND* that these
+ are sequentially ordered.
+
+ Parameters
+ ----------
+ data_loader : DataLoader
+ loss_fns : Dict[str, Callable], optional
+ regularizer_fns : Dict[str, Callable], optional
+ crop : str, optional
+ is_validation_process : bool
+
+ Returns
+ -------
+ loss_dict, all_gathered_metrics, visualize_slices, visualize_target
+
+ """
+ self.models_to_device()
+ self.models_validation_mode()
+ torch.cuda.empty_cache()
+
+ # Variables required for evaluation.
+ volume_metrics = self.build_metrics(self.cfg.validation.metrics) # type: ignore
+
+ # filenames can be in the volume_indices attribute of the dataset
+ num_for_this_process = None
+ all_filenames = None
+ if hasattr(data_loader.dataset, "volume_indices"):
+ all_filenames = list(data_loader.dataset.volume_indices.keys())
+ num_for_this_process = len(list(data_loader.batch_sampler.sampler.volume_indices.keys()))
+ self.logger.info(
+ f"Reconstructing a total of {len(all_filenames)} volumes. "
+ f"This process has {num_for_this_process} volumes (world size: {communication.get_world_size()})."
+ )
+
+ filenames_seen = 0
+ reconstruction_output: DefaultDict = defaultdict(list)
+ if is_validation_process:
+ targets_output: DefaultDict = defaultdict(list)
+ val_losses = []
+ val_volume_metrics: Dict[PathLike, Dict] = defaultdict(dict)
+ last_filename = None
+
+            # Containers for the slices that can be visualized in TensorBoard.
+ visualize_slices: List[np.ndarray] = []
+ visualize_target: List[np.ndarray] = []
+ # visualizations = {}
+
+ extra_visualization_keys = (
+ self.cfg.logging.log_as_image if self.cfg.logging.log_as_image else [] # type: ignore
+ )
+
+ # Loop over dataset. This requires the use of direct.data.sampler.DistributedSequentialSampler as this sampler
+ # splits the data over the different processes, and outputs the slices linearly. The implicit assumption here is
+ # that the slices are outputted from the Dataset *sequentially* for each volume one by one, and each batch only
+ # contains data from one volume.
+ time_start = time.time()
+
+ for iter_idx, data in enumerate(data_loader):
+ filenames = data.pop("filename")
+ if len(set(filenames)) != 1:
+ raise ValueError(
+ f"Expected a batch during validation to only contain filenames of one case. "
+ f"Got {set(filenames)}."
+ )
+
+ slice_nos = data.pop("slice_no")
+ scaling_factors = data["scaling_factor"]
+
+ resolution = self.compute_resolution(
+ key=self.cfg.validation.crop, # type: ignore
+ reconstruction_size=data.get("reconstruction_size", None),
+ )
+
+ # Compute output and loss.
+ iteration_output = self._do_iteration(data, loss_fns, regularizer_fns=regularizer_fns)
+ output = iteration_output.output_image
+ loss_dict = iteration_output.data_dict
+
+ loss_dict = detach_dict(loss_dict)
+ output = output.detach()
+ val_losses.append(loss_dict)
+
+            # The output and target still have to be scaled back and cropped before computing metrics.
+            # Output has shape (batch, height, width).
+ output_abs = self.process_output(
+ output,
+ scaling_factors,
+ resolution=resolution,
+ )
+
+ if is_validation_process:
+ # Target has shape (batch, height, width)
+ target_abs = self.process_output(
+ data["target"].detach(),
+ scaling_factors,
+ resolution=resolution,
+ )
+ for key in extra_visualization_keys:
+ curr_data = data[key].detach()
+ # Here we need to discover which keys are actually normalized or not
+ # this requires a solution to issue #23: https://github.com/directgroup/direct/issues/23
+
+ del output # Explicitly call delete to clear memory.
+
+ # Aggregate volumes to be able to compute the metrics on complete volumes.
+ for idx, filename in enumerate(filenames):
+ if last_filename is None:
+ last_filename = filename # First iteration last_filename is not set.
+
+ curr_slice = output_abs[idx].detach()
+ slice_no = int(slice_nos[idx].numpy())
+
+ reconstruction_output[filename].append((slice_no, curr_slice.cpu()))
+
+ if is_validation_process:
+ targets_output[filename].append((slice_no, target_abs[idx].cpu()))
+
+ is_last_element_of_last_batch = iter_idx + 1 == len(data_loader) and idx + 1 == len(data["target"])
+ reconstruction_conditions = [filename != last_filename, is_last_element_of_last_batch]
+ for condition in reconstruction_conditions:
+ if condition:
+ filenames_seen += 1
+
+ # Now we can ditch the reconstruction dict by reconstructing the volume,
+ # will take too much memory otherwise.
+ volume = torch.stack([_[1] for _ in reconstruction_output[last_filename]])
+ if is_validation_process:
+ target = torch.stack([_[1] for _ in targets_output[last_filename]])
+ curr_metrics = {
+ metric_name: metric_fn(target, volume)
+ for metric_name, metric_fn in volume_metrics.items()
+ }
+ val_volume_metrics[last_filename] = curr_metrics
+ # Log the center slice of the volume
+ if len(visualize_slices) < self.cfg.logging.tensorboard.num_images: # type: ignore
+ visualize_slices.append(volume[volume.shape[0] // 2])
+ visualize_target.append(target[target.shape[0] // 2])
+
+ # Delete outputs from memory, and recreate dictionary.
+ # This is not needed when not in validation as we are actually interested
+ # in the iteration output.
+ del targets_output[last_filename]
+ del reconstruction_output[last_filename]
+
+ if all_filenames:
+ log_prefix = f"{filenames_seen} of {num_for_this_process} volumes reconstructed:"
+ else:
+ log_prefix = f"{iter_idx + 1} of {len(data_loader)} slices reconstructed:"
+
+ self.logger.info(
+ f"{log_prefix} {last_filename}"
+ f" (shape = {list(volume.shape)}) in {time.time() - time_start:.3f}s."
+ )
+ # restart timer
+ time_start = time.time()
+ last_filename = filename
+
+ # Average loss dict
+ loss_dict = reduce_list_of_dicts(val_losses)
+ reduce_tensor_dict(loss_dict)
+
+ communication.synchronize()
+ torch.cuda.empty_cache()
+
+ all_gathered_metrics = merge_list_of_dicts(communication.all_gather(val_volume_metrics))
+ if not is_validation_process:
+ return loss_dict, reconstruction_output
+
+ return loss_dict, all_gathered_metrics, visualize_slices, visualize_target
+
+ def process_output(self, data, scaling_factors=None, resolution=None):
+        # data has shape (batch, height, width) or, if still complex-valued, (batch, complex=2, height, width).
+ if scaling_factors is not None:
+ data = data * scaling_factors.view(-1, *((1,) * (len(data.shape) - 1))).to(data.device)
+
+ data = T.modulus_if_complex(data)
+
+ if len(data.shape) == 3: # (batch, height, width)
+ data = data.unsqueeze(1) # Added channel dimension.
+
+ if resolution is not None:
+ data = T.center_crop(data, resolution).contiguous()
+
+ return data
+
+ @staticmethod
+ def compute_resolution(key, reconstruction_size):
+ if key == "header":
+ # This will be of the form [tensor(x_0, x_1, ...), tensor(y_0, y_1,...), tensor(z_0, z_1, ...)] over
+ # batches.
+ resolution = [_.detach().cpu().numpy().tolist() for _ in reconstruction_size]
+ # The volume sampler should give validation indices belonging to the *same* volume, so it should be
+            # safe to take the first element; the matrix sizes are in x,y,z (we work in z,x,y).
+ resolution = [_[0] for _ in resolution][:-1]
+ elif key == "training":
+ resolution = key
+ elif not key:
+ resolution = None
+ else:
+ raise ValueError(
+ "Cropping should be either set to `header` to get the values from the header or "
+ "`training` to take the same value as training."
+ )
+ return resolution
+
+ def cropper(self, source, target, resolution):
+ """
+ 2D source/target cropper
+
+        Parameters
+        ----------
+        source : torch.Tensor
+            Has shape (batch, height, width).
+        target : torch.Tensor
+            Has shape (batch, height, width).
+        resolution : list or tuple
+            Resolution to center-crop to; if empty or all zeros, no cropping is applied.
+
+ """
+
+ if not resolution or all(_ == 0 for _ in resolution):
+ return source.unsqueeze(1), target.unsqueeze(1) # Added channel dimension.
+
+ source_abs = T.center_crop(source, resolution).unsqueeze(1) # Added channel dimension.
+ target_abs = T.center_crop(target, resolution).unsqueeze(1) # Added channel dimension.
+
+ return source_abs, target_abs
+
+ def compute_model_per_coil(self, model_name, data):
+ """
+        Applies the model named model_name to each coil of data separately.
+ """
+ # data is of shape (batch, coil, complex=2, height, width)
+ output = []
+
+ for idx in range(data.size(self._coil_dim)):
+ subselected_data = data.select(self._coil_dim, idx)
+ output.append(self.models[model_name](subselected_data))
+ output = torch.stack(output, dim=self._coil_dim)
+
+ # output is of shape (batch, coil, complex=2, height, width)
+ return output
diff --git a/direct/utils/__init__.py b/direct/utils/__init__.py
index 820739d5..f46254f4 100644
--- a/direct/utils/__init__.py
+++ b/direct/utils/__init__.py
@@ -18,6 +18,52 @@
logger = logging.getLogger(__name__)
+def is_complex_data(data: torch.Tensor, complex_last: bool = True) -> bool:
+ """
+    Returns True if data is a complex tensor, i.e. has a complex axis of size 2, and False otherwise.
+
+ Parameters
+ ----------
+ data : torch.Tensor
+ For 2D data the shape is assumed ([batch], [coil], height, width, [complex])
+ or ([batch], [coil], [complex], height, width).
+ For 3D data the shape is assumed ([batch], [coil], slice, height, width, [complex])
+ or ([batch], [coil], [complex], slice, height, width).
+ complex_last : bool
+        If True, the complex axis is required to be the last axis.
+
+    Returns
+    -------
+    bool
+        True if data has a complex axis of size 2, False otherwise.
+ """
+ if 2 not in data.shape:
+ return False
+ if complex_last:
+ if data.size(-1) != 2:
+ return False
+ else:
+ if data.ndim == 6:
+            if data.size(2) != 2 and data.size(-1) != 2: # (B, C, 2, S, H, W) or (B, C, S, H, W, 2)
+ return False
+
+ elif data.ndim == 5:
+ # (B, 2, S, H, W) or (B, C, 2, H, W) or (B, S, H, W, 2) or (B, C, H, W, 2)
+ if data.size(1) != 2 and data.size(2) != 2 and data.size(-1) != 2:
+ return False
+
+ elif data.ndim == 4:
+ if data.size(1) != 2 and data.size(-1) != 2: # (B, 2, H, W) or (B, H, W, 2) or (S, H, W, 2)
+ return False
+
+ elif data.ndim == 3:
+ if data.size(-1) != 2: # (H, W, 2)
+ return False
+
+ else:
+ raise ValueError(f"Not compatible number of dimensions for complex data. Got {data.ndim}.")
+
+ return True
+
+
def is_power_of_two(number: int) -> bool:
"""Check if input is a power of 2
@@ -404,9 +450,9 @@ def chunks(list_to_chunk, number_of_chunks):
From https://stackoverflow.com/a/54802737
"""
d, r = divmod(len(list_to_chunk), number_of_chunks)
- for i in range(number_of_chunks):
- si = (d + 1) * (i if i < r else r) + d * (0 if i < r else i - r)
- yield list_to_chunk[si : si + (d + 1 if i < r else d)]
+ for idx in range(number_of_chunks):
+ si = (d + 1) * (idx if idx < r else r) + d * (0 if idx < r else idx - r)
+ yield list_to_chunk[si : si + (d + 1 if idx < r else d)]
def remove_keys(input_dict, keys):
diff --git a/direct/utils/asserts.py b/direct/utils/asserts.py
index 2c98423e..763ec751 100644
--- a/direct/utils/asserts.py
+++ b/direct/utils/asserts.py
@@ -3,6 +3,8 @@
import inspect
from typing import List
+from direct.utils import is_complex_data
+
import torch
@@ -45,7 +47,7 @@ def assert_same_shape(data_list: List[torch.Tensor]):
def assert_complex(data: torch.Tensor, complex_last: bool = True) -> None:
"""
- Assert if a tensor is a complex named tensor.
+    Assert that a tensor is a complex tensor.
Parameters
----------
@@ -61,29 +63,5 @@ def assert_complex(data: torch.Tensor, complex_last: bool = True) -> None:
"""
# TODO: This is because ifft and fft or torch expect the last dimension to represent the complex axis.
-
- if 2 not in data.shape:
- raise ValueError(f"No complex dimension (2) found. Got shape {data.shape}.")
- if complex_last:
- if data.size(-1) != 2:
- raise ValueError(f"Last dimension assumed to be 2 (complex valued). Got {data.size(-1)}.")
- else:
- if data.ndim == 6 or data.ndim == 3:
- if data.size(1) != 2 and data.size(-1) != 2:
- raise ValueError(
- f"Complex dimension assumed to be 2 (complex valued), but not found in shape {data.shape}."
- )
-
- elif data.ndim == 5:
- if data.size(1) != 2 and data.size(2) != 2 and data.size(-1) != 2:
- raise ValueError(
- f"Complex dimension assumed to be 2 (complex valued), but not found in shape {data.shape}."
- )
-
- elif data.ndim == 4:
- if data.size(1) != 2 and data.size(-1) != 2:
- raise ValueError(
- f"Complex dimension assumed to be 2 (complex valued), but not found in shape {data.shape}."
- )
- else:
- raise ValueError(f"Data of shape {data.shape} is not complex.")
+ if not is_complex_data(data, complex_last):
+ raise ValueError(f"Complex dimension assumed to be 2 (complex valued), but not found in shape {data.shape}.")
diff --git a/direct/utils/events.py b/direct/utils/events.py
index af5d9220..6a0042f6 100644
--- a/direct/utils/events.py
+++ b/direct/utils/events.py
@@ -390,7 +390,8 @@ def step(self):
correct iteration number.
"""
self._iter += 1
- self._latest_scalars = {}
+ # TODO: This clears validation metrics.
+ # self._latest_scalars = {}
@property
def vis_data(self):
diff --git a/direct/utils/logging.py b/direct/utils/logging.py
index 3b5425a5..c2b3b9f7 100644
--- a/direct/utils/logging.py
+++ b/direct/utils/logging.py
@@ -38,7 +38,7 @@ def setup(
root = logging.getLogger()
root.setLevel(log_level)
- for name in logging.root.manager.loggerDict: # type: ignore
+ for name in logging.root.manager.loggerDict: # pylint: disable = E1101 # type: ignore
if name.startswith("torch"):
logging.getLogger(name).setLevel("WARNING")
diff --git a/direct/utils/tests/__init__.py b/direct/utils/tests/__init__.py
new file mode 100644
index 00000000..941752c9
--- /dev/null
+++ b/direct/utils/tests/__init__.py
@@ -0,0 +1,2 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
diff --git a/direct/utils/tests/test_utils.py b/direct/utils/tests/test_utils.py
new file mode 100644
index 00000000..da0ba2d1
--- /dev/null
+++ b/direct/utils/tests/test_utils.py
@@ -0,0 +1,45 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
+"""Tests for the direct.utils module"""
+
+import numpy as np
+import pytest
+import torch
+
+from direct.utils import is_power_of_two, is_complex_data
+from direct.data.transforms import tensor_to_complex_numpy
+
+
+def create_input(shape):
+ data = np.random.randn(*shape).copy()
+ data = torch.from_numpy(data).float()
+
+ return data
+
+
+@pytest.mark.parametrize(
+ "shape",
+ [
+ [3, 3, 2],
+ [5, 8, 4, 2],
+ [5, 2, 8, 4],
+ [3, 5, 8, 4, 2],
+ [3, 5, 2, 8, 4],
+ [3, 2, 5, 8, 4],
+ [3, 3, 5, 8, 4, 2],
+ [3, 3, 2, 5, 8, 4],
+ ],
+)
+def test_is_complex_data(shape):
+ data = create_input(shape)
+
+ assert is_complex_data(data, False)
+
+
+@pytest.mark.parametrize(
+ "num",
+ [1, 2, 4, 32, 128, 1024],
+)
+def test_is_power_of_two(num):
+
+ assert is_power_of_two(num)
diff --git a/projects/calgary_campinas/configs/base_jointicnet.yaml b/projects/calgary_campinas/configs/base_jointicnet.yaml
new file mode 100644
index 00000000..a97f150d
--- /dev/null
+++ b/projects/calgary_campinas/configs/base_jointicnet.yaml
@@ -0,0 +1,100 @@
+physics:
+ forward_operator: fft2(centered=False)
+ backward_operator: ifft2(centered=False)
+training:
+ datasets:
+    # Two datasets; the only difference is the shape, so the data can be collated into larger batches
+ - name: CalgaryCampinas
+ lists:
+ - ../lists/train/12x218x170_train.lst
+ transforms:
+ crop: null
+ estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS
+ scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum
+ image_center_crop: false
+ masking:
+ name: CalgaryCampinas
+ accelerations: [5, 10]
+ crop_outer_slices: true
+ - name: CalgaryCampinas
+ lists:
+ - ../lists/train/12x218x180_train.lst
+ transforms:
+ crop: null
+ estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS
+ scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum
+ image_center_crop: false
+ masking:
+ name: CalgaryCampinas
+ accelerations: [5, 10]
+ crop_outer_slices: true
+ batch_size: 4 # This is the batch size per GPU!
+ optimizer: Adam
+ lr: 0.0005
+ weight_decay: 0.0
+ lr_step_size: 50000
+ lr_gamma: 0.2
+ lr_warmup_iter: 1000
+ num_iterations: 500000
+ gradient_steps: 1
+ gradient_clipping: 0.0
+ gradient_debug: false
+ checkpointer:
+ checkpoint_steps: 500
+ validation_steps: 500
+ loss:
+ crop: null
+ losses:
+ - function: l1_loss
+ multiplier: 1.0
+ - function: ssim_loss
+ multiplier: 1.0
+validation:
+ datasets:
+    # The same dataset twice, but with different acceleration factors
+ - name: CalgaryCampinas
+ transforms:
+ crop: null
+ estimate_sensitivity_maps: true
+ scaling_key: masked_kspace
+ masking:
+ name: CalgaryCampinas
+ accelerations: [5]
+ crop_outer_slices: true
+ text_description: 5x # Description for logging
+ - name: CalgaryCampinas
+ transforms:
+ crop: null
+ estimate_sensitivity_maps: true
+ scaling_key: masked_kspace
+ masking:
+ name: CalgaryCampinas
+ accelerations: [10]
+ crop_outer_slices: true
+ text_description: 10x # Description for logging
+ crop: null # This sets the cropping for the DoIterationOutput
+ metrics: # These are obtained from direct.functionals
+ - calgary_campinas_psnr
+ - calgary_campinas_ssim
+ - calgary_campinas_vif
+model:
+ model_name: jointicnet.jointicnet.JointICNet
+ num_iter: 12
+ use_norm_unet: True
+ image_unet_num_filters: 32
+ kspace_unet_num_filters: 32
+ sens_unet_num_filters: 8
+
+logging:
+ tensorboard:
+ num_images: 4
+inference:
+ batch_size: 8
+ dataset:
+ name: CalgaryCampinas
+ crop_outer_slices: true
+ text_description: inference
+ transforms:
+ crop: null
+ estimate_sensitivity_maps: true
+ scaling_key: masked_kspace
diff --git a/projects/calgary_campinas/configs/base_kikinet.yaml b/projects/calgary_campinas/configs/base_kikinet.yaml
new file mode 100644
index 00000000..a657dbaa
--- /dev/null
+++ b/projects/calgary_campinas/configs/base_kikinet.yaml
@@ -0,0 +1,110 @@
+physics:
+ forward_operator: fft2(centered=False)
+ backward_operator: ifft2(centered=False)
+training:
+ datasets:
+    # Two datasets; the only difference is the shape, so the data can be collated into larger batches
+ - name: CalgaryCampinas
+ lists:
+ - ../lists/train/12x218x170_train.lst
+ transforms:
+ crop: null
+ estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS
+ scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum
+ image_center_crop: false
+ masking:
+ name: CalgaryCampinas
+ accelerations: [5, 10]
+ crop_outer_slices: true
+ - name: CalgaryCampinas
+ lists:
+ - ../lists/train/12x218x180_train.lst
+ transforms:
+ crop: null
+ estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS
+ scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum
+ image_center_crop: false
+ masking:
+ name: CalgaryCampinas
+ accelerations: [5, 10]
+ crop_outer_slices: true
+ batch_size: 8 # This is the batch size per GPU!
+ optimizer: Adam
+ lr: 0.0002
+ weight_decay: 0.0
+ lr_step_size: 50000
+ lr_gamma: 0.2
+ lr_warmup_iter: 1000
+ num_iterations: 500000
+ gradient_steps: 1
+ gradient_clipping: 0.0
+ gradient_debug: false
+ checkpointer:
+ checkpoint_steps: 500
+ validation_steps: 500
+ loss:
+ crop: null
+ losses:
+ - function: l1_loss
+ multiplier: 1.0
+ - function: ssim_loss
+ multiplier: 1.0
+validation:
+ datasets:
+    # The same dataset twice, but with different acceleration factors
+ - name: CalgaryCampinas
+ transforms:
+ crop: null
+ estimate_sensitivity_maps: true
+ scaling_key: masked_kspace
+ masking:
+ name: CalgaryCampinas
+ accelerations: [5]
+ crop_outer_slices: true
+ text_description: 5x # Description for logging
+ - name: CalgaryCampinas
+ transforms:
+ crop: null
+ estimate_sensitivity_maps: true
+ scaling_key: masked_kspace
+ masking:
+ name: CalgaryCampinas
+ accelerations: [10]
+ crop_outer_slices: true
+ text_description: 10x # Description for logging
+ crop: null # This sets the cropping for the DoIterationOutput
+ metrics: # These are obtained from direct.functionals
+ - calgary_campinas_psnr
+ - calgary_campinas_ssim
+ - calgary_campinas_vif
+model:
+ model_name: kikinet.kikinet.KIKINet
+ num_iter: 2
+ image_model_architecture: UNET
+ image_unet_num_filters: 16
+ image_unet_num_pool_layers: 4
+ kspace_model_architecture: UNET
+ kspace_unet_num_filters: 16
+ kspace_unet_num_pool_layers: 4
+
+additional_models:
+ sensitivity_model:
+ model_name: unet.unet_2d.UnetModel2d
+ in_channels: 2
+ out_channels: 2
+ num_filters: 8
+ num_pool_layers: 4
+ dropout_probability: 0.0
+logging:
+ tensorboard:
+ num_images: 4
+inference:
+ batch_size: 8
+ dataset:
+ name: CalgaryCampinas
+ crop_outer_slices: true
+ text_description: inference
+ transforms:
+ crop: null
+ estimate_sensitivity_maps: true
+ scaling_key: masked_kspace
diff --git a/projects/calgary_campinas/configs/base_multidomainnet.yaml b/projects/calgary_campinas/configs/base_multidomainnet.yaml
new file mode 100644
index 00000000..ec90112d
--- /dev/null
+++ b/projects/calgary_campinas/configs/base_multidomainnet.yaml
@@ -0,0 +1,105 @@
+physics:
+ forward_operator: fft2(centered=False)
+ backward_operator: ifft2(centered=False)
+training:
+ datasets:
+    # Two datasets; the only difference is the shape, so the data can be collated into larger batches
+ - name: CalgaryCampinas
+ lists:
+ - ../lists/train/12x218x170_train.lst
+ transforms:
+ crop: null
+ estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS
+ scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum
+ image_center_crop: false
+ masking:
+ name: CalgaryCampinas
+ accelerations: [5, 10]
+ crop_outer_slices: true
+ - name: CalgaryCampinas
+ lists:
+ - ../lists/train/12x218x180_train.lst
+ transforms:
+ crop: null
+ estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS
+ scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum
+ image_center_crop: false
+ masking:
+ name: CalgaryCampinas
+ accelerations: [5, 10]
+ crop_outer_slices: true
+ batch_size: 20
+ optimizer: Adam
+ lr: 0.001
+ weight_decay: 0.0
+ lr_step_size: 50000
+ lr_gamma: 0.2
+ lr_warmup_iter: 1000
+ num_iterations: 500000
+ gradient_steps: 1
+ gradient_clipping: 0.0
+ gradient_debug: false
+ checkpointer:
+ checkpoint_steps: 500
+ validation_steps: 500
+ loss:
+ crop: null
+ losses:
+ - function: l1_loss
+ multiplier: 1.0
+ - function: ssim_loss
+ multiplier: 1.0
+validation:
+ datasets:
+    # The same dataset twice, but with different acceleration factors
+ - name: CalgaryCampinas
+ transforms:
+ crop: null
+ estimate_sensitivity_maps: true
+ scaling_key: masked_kspace
+ masking:
+ name: CalgaryCampinas
+ accelerations: [5]
+ crop_outer_slices: true
+ text_description: 5x # Description for logging
+ - name: CalgaryCampinas
+ transforms:
+ crop: null
+ estimate_sensitivity_maps: true
+ scaling_key: masked_kspace
+ masking:
+ name: CalgaryCampinas
+ accelerations: [10]
+ crop_outer_slices: true
+ text_description: 10x # Description for logging
+ crop: null # This sets the cropping for the DoIterationOutput
+ metrics: # These are obtained from direct.functionals
+ - calgary_campinas_psnr
+ - calgary_campinas_ssim
+ - calgary_campinas_vif
+model:
+ model_name: multidomainnet.multidomainnet.MultiDomainNet
+ num_filters: 16
+ num_pool_layers: 4
+ dropout_probability: 0.05
+additional_models:
+ sensitivity_model:
+ model_name: unet.unet_2d.UnetModel2d
+ in_channels: 2
+ out_channels: 2
+ num_filters: 16
+ num_pool_layers: 4
+ dropout_probability: 0.05
+logging:
+ tensorboard:
+ num_images: 4
+inference:
+ batch_size: 8
+ dataset:
+ name: CalgaryCampinas
+ crop_outer_slices: true
+ text_description: inference
+ transforms:
+ crop: null
+ estimate_sensitivity_maps: true
+ scaling_key: masked_kspace
diff --git a/projects/calgary_campinas/configs/base.yaml b/projects/calgary_campinas/configs/base_rim.yaml
similarity index 100%
rename from projects/calgary_campinas/configs/base.yaml
rename to projects/calgary_campinas/configs/base_rim.yaml
diff --git a/projects/calgary_campinas/configs/base_unet.yaml b/projects/calgary_campinas/configs/base_unet.yaml
new file mode 100644
index 00000000..ce76b62b
--- /dev/null
+++ b/projects/calgary_campinas/configs/base_unet.yaml
@@ -0,0 +1,106 @@
+physics:
+ forward_operator: fft2(centered=False)
+ backward_operator: ifft2(centered=False)
+training:
+ datasets:
+    # Two datasets; the only difference is the shape, so the data can be collated into larger batches
+ - name: CalgaryCampinas
+ lists:
+ - ../lists/train/12x218x170_train.lst
+ transforms:
+ crop: null
+ estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS
+ scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum
+ image_center_crop: false
+ masking:
+ name: CalgaryCampinas
+ accelerations: [5, 10]
+ crop_outer_slices: true
+ - name: CalgaryCampinas
+ lists:
+ - ../lists/train/12x218x180_train.lst
+ transforms:
+ crop: null
+ estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS
+ scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum
+ image_center_crop: false
+ masking:
+ name: CalgaryCampinas
+ accelerations: [5, 10]
+ crop_outer_slices: true
+ batch_size: 32 # This is the batch size per GPU!
+ optimizer: Adam
+ lr: 0.0002
+ weight_decay: 0.0
+ lr_step_size: 50000
+ lr_gamma: 0.2
+ lr_warmup_iter: 3000
+ num_iterations: 500000
+ gradient_steps: 1
+ gradient_clipping: 0.0
+ gradient_debug: false
+ checkpointer:
+ checkpoint_steps: 500
+ validation_steps: 2000
+ loss:
+ crop: null
+ losses:
+ - function: l1_loss
+ multiplier: 1.0
+ - function: ssim_loss
+ multiplier: 1.0
+ - function: l2_loss
+ multiplier: 1.0
+validation:
+ datasets:
+    # The same dataset twice, but with different acceleration factors
+ - name: CalgaryCampinas
+ transforms:
+ crop: null
+ estimate_sensitivity_maps: true
+ scaling_key: masked_kspace
+ masking:
+ name: CalgaryCampinas
+ accelerations: [5]
+ crop_outer_slices: true
+ text_description: 5x # Description for logging
+ - name: CalgaryCampinas
+ transforms:
+ crop: null
+ estimate_sensitivity_maps: true
+ scaling_key: masked_kspace
+ masking:
+ name: CalgaryCampinas
+ accelerations: [10]
+ crop_outer_slices: true
+ text_description: 10x # Description for logging
+ crop: null # This sets the cropping for the DoIterationOutput
+ metrics: # These are obtained from direct.functionals
+ - calgary_campinas_psnr
+ - calgary_campinas_ssim
+ - calgary_campinas_vif
+model:
+ model_name: unet.unet_2d.Unet2d
+ num_filters: 64
+ image_initialization: sense
+additional_models:
+ sensitivity_model:
+ model_name: unet.unet_2d.UnetModel2d
+ in_channels: 2
+ out_channels: 2
+ num_filters: 8
+ num_pool_layers: 4
+ dropout_probability: 0.0
+logging:
+ tensorboard:
+ num_images: 4
+inference:
+ batch_size: 8
+ dataset:
+ name: CalgaryCampinas
+ crop_outer_slices: true
+ text_description: inference
+ transforms:
+ crop: null
+ estimate_sensitivity_maps: true
+ scaling_key: masked_kspace
diff --git a/projects/calgary_campinas/configs/base_varnet.yaml b/projects/calgary_campinas/configs/base_varnet.yaml
new file mode 100644
index 00000000..f325bd17
--- /dev/null
+++ b/projects/calgary_campinas/configs/base_varnet.yaml
@@ -0,0 +1,104 @@
+physics:
+ forward_operator: fft2(centered=False)
+ backward_operator: ifft2(centered=False)
+training:
+ datasets:
+    # Two datasets; the only difference is the shape, so the data can be collated into larger batches
+ - name: CalgaryCampinas
+ lists:
+ - ../lists/train/12x218x170_train.lst
+ transforms:
+ crop: null
+ estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS
+ scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum
+ image_center_crop: false
+ masking:
+ name: CalgaryCampinas
+ accelerations: [5, 10]
+ crop_outer_slices: true
+ - name: CalgaryCampinas
+ lists:
+ - ../lists/train/12x218x180_train.lst
+ transforms:
+ crop: null
+ estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS
+ scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum
+ image_center_crop: false
+ masking:
+ name: CalgaryCampinas
+ accelerations: [5, 10]
+ crop_outer_slices: true
+ batch_size: 16 # This is the batch size per GPU!
+ optimizer: Adam
+ lr: 0.0005
+ weight_decay: 0.0
+ lr_step_size: 50000
+ lr_gamma: 0.2
+ lr_warmup_iter: 1500
+ num_iterations: 500000
+ gradient_steps: 1
+ gradient_clipping: 0.0
+ gradient_debug: false
+ checkpointer:
+ checkpoint_steps: 500
+ validation_steps: 500
+ loss:
+ crop: null
+ losses:
+ - function: ssim_loss
+ multiplier: 1.0
+ - function: l1_loss
+ multiplier: 1.0
+validation:
+ datasets:
+    # The same dataset twice, but with different acceleration factors
+ - name: CalgaryCampinas
+ transforms:
+ crop: null
+ estimate_sensitivity_maps: true
+ scaling_key: masked_kspace
+ masking:
+ name: CalgaryCampinas
+ accelerations: [5]
+ crop_outer_slices: true
+ text_description: 5x # Description for logging
+ - name: CalgaryCampinas
+ transforms:
+ crop: null
+ estimate_sensitivity_maps: true
+ scaling_key: masked_kspace
+ masking:
+ name: CalgaryCampinas
+ accelerations: [10]
+ crop_outer_slices: true
+ text_description: 10x # Description for logging
+ crop: null # This sets the cropping for the DoIterationOutput
+ metrics: # These are obtained from direct.functionals
+ - calgary_campinas_psnr
+ - calgary_campinas_ssim
+ - calgary_campinas_vif
+model:
+ model_name: varnet.varnet.EndToEndVarNet
+ num_layers: 12
+ regularizer_num_filters: 32
+additional_models:
+ sensitivity_model:
+ model_name: unet.unet_2d.UnetModel2d
+ in_channels: 2
+ out_channels: 2
+ num_filters: 8
+ num_pool_layers: 4
+ dropout_probability: 0.0
+logging:
+ tensorboard:
+ num_images: 4
+inference:
+ batch_size: 8
+ dataset:
+ name: CalgaryCampinas
+ crop_outer_slices: true
+ text_description: inference
+ transforms:
+ crop: null
+ estimate_sensitivity_maps: true
+ scaling_key: masked_kspace
diff --git a/projects/calgary_campinas/configs/base_xpdnet.yaml b/projects/calgary_campinas/configs/base_xpdnet.yaml
new file mode 100644
index 00000000..171d48e6
--- /dev/null
+++ b/projects/calgary_campinas/configs/base_xpdnet.yaml
@@ -0,0 +1,110 @@
+physics:
+ forward_operator: fft2(centered=False)
+ backward_operator: ifft2(centered=False)
+training:
+ datasets:
+    # Two datasets; the only difference is the shape, so the data can be collated into larger batches
+ - name: CalgaryCampinas
+ lists:
+ - ../lists/train/12x218x170_train.lst
+ transforms:
+ crop: null
+ estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS
+ scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum
+ image_center_crop: false
+ masking:
+ name: CalgaryCampinas
+ accelerations: [5, 10]
+ crop_outer_slices: true
+ - name: CalgaryCampinas
+ lists:
+ - ../lists/train/12x218x180_train.lst
+ transforms:
+ crop: null
+ estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS
+ scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum
+ image_center_crop: false
+ masking:
+ name: CalgaryCampinas
+ accelerations: [5, 10]
+ crop_outer_slices: true
+ batch_size: 8 # This is the batch size per GPU!
+ optimizer: Adam
+ lr: 0.0002
+ weight_decay: 0.0
+ lr_step_size: 50000
+ lr_gamma: 0.2
+ lr_warmup_iter: 1000
+ num_iterations: 500000
+ gradient_steps: 1
+ gradient_clipping: 0.0
+ gradient_debug: false
+ checkpointer:
+ checkpoint_steps: 500
+ validation_steps: 500
+ loss:
+ crop: null
+ losses:
+ - function: l1_loss
+ multiplier: 1.0
+ - function: ssim_loss
+ multiplier: 1.0
+validation:
+ datasets:
+    # The same dataset twice, but with different acceleration factors
+ - name: CalgaryCampinas
+ transforms:
+ crop: null
+ estimate_sensitivity_maps: true
+ scaling_key: masked_kspace
+ masking:
+ name: CalgaryCampinas
+ accelerations: [5]
+ crop_outer_slices: true
+ text_description: 5x # Description for logging
+ - name: CalgaryCampinas
+ transforms:
+ crop: null
+ estimate_sensitivity_maps: true
+ scaling_key: masked_kspace
+ masking:
+ name: CalgaryCampinas
+ accelerations: [10]
+ crop_outer_slices: true
+ text_description: 10x # Description for logging
+ crop: null # This sets the cropping for the DoIterationOutput
+ metrics: # These are obtained from direct.functionals
+ - calgary_campinas_psnr
+ - calgary_campinas_ssim
+ - calgary_campinas_vif
+model:
+ model_name: xpdnet.xpdnet.XPDNet
+ num_primal: 5
+ num_iter: 20
+ mwcnn_hidden_channels: 32
+ mwcnn_num_scales: 4
+ mwcnn_bias: True
+ mwcnn_batchnorm: False
+ normalize: False
+
+additional_models:
+ sensitivity_model:
+ model_name: unet.unet_2d.UnetModel2d
+ in_channels: 2
+ out_channels: 2
+ num_filters: 8
+ num_pool_layers: 4
+ dropout_probability: 0.0
+logging:
+ tensorboard:
+ num_images: 4
+inference:
+ batch_size: 8
+ dataset:
+ name: CalgaryCampinas
+ crop_outer_slices: true
+ text_description: inference
+ transforms:
+ crop: null
+ estimate_sensitivity_maps: true
+ scaling_key: masked_kspace
diff --git a/tools/train_rim.py b/tools/train_model.py
similarity index 100%
rename from tools/train_rim.py
rename to tools/train_model.py