Implement recon models (#123)
Added new models

Reconstruction Models added:
- Unet for reconstruction
- EndToEndVarNet
- KIKINet
- LPDNet
- XPDNet
- JointICNet
- AIRS model

Additional models added:
- MWCNN
- DIDN
- NormUnet
- Crossdomain
- ConvGRU is now independent of RIM
georgeyiasemis authored Oct 18, 2021
1 parent 961989b commit 312bea6
Showing 87 changed files with 7,844 additions and 382 deletions.
8 changes: 4 additions & 4 deletions README.md
@@ -3,7 +3,7 @@
[![black](https://github.com/directgroup/direct/actions/workflows/black.yml/badge.svg)](https://github.com/directgroup/direct/actions/workflows/black.yml)

# DIRECT
-DIRECT is the Deep Image REConstruction Toolkit that implements state-of-the-art inverse problem solvers. It includes
+DIRECT is the Deep Image REConstruction Toolkit that implements state-of-the-art inverse problem solvers. It stores
inverse problem solvers such as the Learned Primal Dual algorithm and Recurrent Inference Machine, which were part of the winning solution in Facebook & NYUs FastMRI challenge in 2019 and the Calgary-Campinas MRI reconstruction challenge at MIDL 2020.

<div align="center">
@@ -29,9 +29,9 @@ If you use DIRECT in your own research, or want to refer to baseline results pub

```BibTeX
@misc{DIRECTTOOLKIT,
-author = {Jonas Teuwen, Nikita Moriakov, Dimitrios Karkalousos, Matthan Caan, George Yiasemis},
-title = {DIRECT},
+author = {Teuwen, Jonas and Yiasemis, George and Moriakov, Nikita and Karkalousos, Dimitrios and Caan, Matthan},
+title = {DIRECT: Deep Image REConstruction Toolkit},
howpublished = {\url{https://github.com/directgroup/direct}},
-year = {2020}
+year = {2021}
}
```
2 changes: 1 addition & 1 deletion direct/data/h5_data.py
@@ -124,7 +124,7 @@ def parse_filenames_data(self, filenames, extra_h5s=None, filter_slice=None):
            if len(filenames) < 5 or idx % (len(filenames) // 5) == 0 or len(filenames) == (idx + 1):
                self.logger.info(f"Parsing: {(idx + 1) / len(filenames) * 100:.2f}%.")
            try:
-                kspace = h5py.File(filename, "r")["kspace"]
+                kspace = np.array(h5py.File(filename, "r")["kspace"])
                self.verify_extra_h5_integrity(filename, kspace.shape, extra_h5s=extra_h5s)

            except OSError as exc:
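A note on the one-line change above: wrapping the h5py dataset in `np.array` forces an eager read into memory, so the subsequent shape check no longer depends on the file handle staying open. A minimal sketch of the difference, with a hypothetical file name:

```python
import h5py
import numpy as np

# Hypothetical file name, same "kspace" key as above.
with h5py.File("example.h5", "r") as f:
    lazy = f["kspace"]       # h5py.Dataset: a lazy view, nothing read yet
    eager = np.array(lazy)   # copies the whole dataset into memory

# After the file is closed, `eager` is still usable; `lazy` is not.
print(eager.shape)
```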
File renamed without changes.
69 changes: 68 additions & 1 deletion direct/data/tests/test_transforms.py
@@ -107,6 +107,24 @@ def test_modulus(shape):
    assert np.allclose(out_torch, out_numpy)


@pytest.mark.parametrize(
    "shape",
    [
        [3, 3],
        [4, 6],
        [10, 8, 4],
        [3, 4, 3, 5],
    ],
)
@pytest.mark.parametrize("complex", [True, False])
def test_modulus_if_complex(shape, complex):
    if complex:
        shape += [
            2,
        ]
    test_modulus(shape)


@pytest.mark.parametrize(
    "shape, dims",
    [
@@ -166,7 +184,6 @@ def test_center_crop(shape, target_shape):
        [[8, 4], [4, 4]],
    ],
)
-# @pytest.mark.parametrize("named", [True, False])
def test_complex_center_crop(shape, target_shape):
    shape = shape + [2]
    input = create_input(shape)
@@ -262,3 +279,53 @@ def test_ifftshift(shape):
    out_torch = transforms.ifftshift(torch_tensor).numpy()
    out_numpy = np.fft.ifftshift(data)
    assert np.allclose(out_torch, out_numpy)


@pytest.mark.parametrize(
    "shape, dim",
    [
        [[3, 4, 5], 0],
        [[3, 3, 4, 5], 1],
        [[3, 6, 4, 5], 0],
        [[3, 3, 6, 4, 5], 1],
    ],
)
def test_expand_operator(shape, dim):
    shape = shape + [
        2,
    ]
    data = create_input(shape)  # noqa
    shape = shape[:dim] + shape[dim + 1 :]
    sens = create_input(shape)  # noqa

    out_torch = tensor_to_complex_numpy(transforms.expand_operator(data, sens, dim))

    input_numpy = np.expand_dims(tensor_to_complex_numpy(data), dim)
    input_sens_numpy = tensor_to_complex_numpy(sens)
    out_numpy = input_sens_numpy * input_numpy

    assert np.allclose(out_torch, out_numpy)


@pytest.mark.parametrize(
    "shape, dim",
    [
        [[3, 4, 5], 0],
        [[3, 3, 4, 5], 1],
        [[3, 6, 4, 5], 0],
        [[3, 3, 6, 4, 5], 1],
    ],
)
def test_reduce_operator(shape, dim):
    shape = shape + [
        2,
    ]
    coil_data = create_input(shape)  # noqa
    sens = create_input(shape)  # noqa
    out_torch = tensor_to_complex_numpy(transforms.reduce_operator(coil_data, sens, dim))

    input_numpy = tensor_to_complex_numpy(coil_data)
    input_sens_numpy = tensor_to_complex_numpy(sens)
    out_numpy = (input_sens_numpy.conj() * input_numpy).sum(dim)

    assert np.allclose(out_torch, out_numpy)
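A side note on the test style above: stacked `@pytest.mark.parametrize` decorators run the cross-product of their cases, so `test_modulus_if_complex` executes once per shape and complex-flag combination. A minimal standalone sketch:

```python
import pytest

@pytest.mark.parametrize("shape", [[3, 3], [4, 6]])
@pytest.mark.parametrize("complex", [True, False])
def test_cross_product(shape, complex):
    # Collected as 2 shapes x 2 flags = 4 test cases.
    assert isinstance(shape, list) and isinstance(complex, bool)
```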
108 changes: 83 additions & 25 deletions direct/data/transforms.py
@@ -10,10 +10,9 @@
import numpy as np
import torch
import torch.fft
-from packaging import version

from direct.data.bbox import crop_to_bbox
-from direct.utils import ensure_list, is_power_of_two
+from direct.utils import ensure_list, is_power_of_two, is_complex_data
from direct.utils.asserts import assert_complex, assert_same_shape


@@ -222,7 +221,6 @@ def safe_divide(input_tensor: torch.Tensor, other_tensor: torch.Tensor) -> torch
        torch.tensor([0.0], dtype=input_tensor.dtype).to(input_tensor.device),
        input_tensor / other_tensor,
    )

    return data


@@ -266,8 +264,7 @@ def align_as(input_tensor: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
    input_shape = list(input_tensor.shape)
    other_shape = torch.tensor(other.shape, dtype=int)
    out_shape = torch.ones(len(other.shape), dtype=int)
-    # TODO(gy): Fix to ensure complex_last when [2,..., 2] or [..., N,..., N,...] in other.shape,
-    # "-input_shape.count(dim):" is a hack and might cause problems.

    for dim in np.sort(np.unique(input_tensor.shape)):
        ind = torch.where(other_shape == dim)[0][-input_shape.count(dim) :]
        out_shape[ind] = dim
@@ -292,7 +289,6 @@ def modulus(data: torch.Tensor) -> torch.Tensor:
    complex_axis = -1 if data.size(-1) == 2 else 1

    return (data ** 2).sum(complex_axis).sqrt()  # noqa
-    # return torch.view_as_complex(data).abs()


def modulus_if_complex(data: torch.Tensor) -> torch.Tensor:
@@ -307,11 +303,9 @@ def modulus_if_complex(data: torch.Tensor) -> torch.Tensor:
    -------
    torch.Tensor
    """
-    # TODO: This can be merged with modulus if the tensor is real.
-    try:
+    if is_complex_data(data, complex_last=False):
        return modulus(data)
-    except ValueError:
-        return data
+    return data


def roll(
@@ -436,7 +430,7 @@ def _complex_matrix_multiplication(input_tensor, other_tensor, mult_func):
    Parameters
    ----------
-    x : torch.Tensor
+    input_tensor : torch.Tensor
    other_tensor : torch.Tensor
    mult_func : Callable
        Multiplication function e.g. torch.bmm or torch.mm
@@ -466,6 +460,7 @@ def complex_mm(input_tensor, other_tensor):
    ----------
    input_tensor : torch.Tensor
    other_tensor : torch.Tensor
    Returns
    -------
    torch.Tensor
@@ -501,10 +496,6 @@ def conjugate(data: torch.Tensor) -> torch.Tensor:
    -------
    torch.Tensor
    """
-    # assert_complex(data, complex_last=True)
-    # data = torch.view_as_real(
-    #     torch.view_as_complex(data).conj()
-    # )
    assert_complex(data, complex_last=True)
    data = data.clone()  # Clone is required as the data in the next line is changed in-place.
    data[..., 1] = data[..., 1] * -1.0
@@ -574,11 +565,12 @@ def tensor_to_complex_numpy(data: torch.Tensor) -> np.ndarray:
    return data[..., 0] + 1j * data[..., 1]


-def root_sum_of_squares(data: torch.Tensor, dim: int = 0) -> torch.Tensor:
-    r"""
+def root_sum_of_squares(data: torch.Tensor, dim: int = 0, complex_dim: int = -1) -> torch.Tensor:
+    """
    Compute the root sum of squares (RSS) transform along a given dimension of the input tensor.
-    $$x_{\textrm{rss}} = \sqrt{\sum_{i \in \textrm{coil}} |x_i|^2}$$
+    .. math::
+        x_{\textrm{rss}} = \sqrt{\sum_{i \in \textrm{coil}} |x_i|^2}
    Parameters
    ----------
@@ -588,16 +580,16 @@ def root_sum_of_squares(data: torch.Tensor, dim: int = 0) -> torch.Tensor:
    dim : int
        Coil dimension. Default is 0 as the first dimension is always the coil dimension.
+    complex_dim : int
+        Complex channel dimension. Default is -1. If data not complex this is ignored.
    Returns
    -------
    torch.Tensor : RSS of the input tensor.
    """
-    try:
-        assert_complex(data, complex_last=True)
-        complex_index = -1
-        return torch.sqrt((data ** 2).sum(complex_index).sum(dim))
-    except ValueError:
-        return torch.sqrt((data ** 2).sum(dim))
+    if is_complex_data(data):
+        return torch.sqrt((data ** 2).sum(complex_dim).sum(dim))
+
+    return torch.sqrt((data ** 2).sum(dim))


def center_crop(data: torch.Tensor, shape: Tuple[int, int]) -> torch.Tensor:
@@ -760,5 +752,71 @@ def complex_random_crop(

if len(output) == 1:
return output[0]

return output


def reduce_operator(
    coil_data: torch.Tensor,
    sensitivity_map: torch.Tensor,
    dim: int = 0,
) -> torch.Tensor:
    """
    Given zero-filled reconstructions from multiple coils :math: \{x_i\}_{i=1}^{N_c} and coil sensitivity maps
    :math: \{S_i\}_{i=1}^{N_c} it returns
    .. math::
        R(x_1, .., x_{N_c}, S_1, .., S_{N_c}) = \sum_{i=1}^{N_c} {S_i}^{*} \times x_i.
    From paper End-to-End Variational Networks for Accelerated MRI Reconstruction.
    Parameters
    ----------
    coil_data : torch.Tensor
        Zero-filled reconstructions from coils. Should be a complex tensor (with complex dim of size 2).
    sensitivity_map: torch.Tensor
        Coil sensitivity maps. Should be complex tensor (with complex dim of size 2).
    dim: int
        Coil dimension. Default: 0.
    Returns
    -------
    torch.Tensor:
        Combined individual coil images.
    """

    assert_complex(coil_data, complex_last=True)
    assert_complex(sensitivity_map, complex_last=True)

    return complex_multiplication(conjugate(sensitivity_map), coil_data).sum(dim)


def expand_operator(
    data: torch.Tensor,
    sensitivity_map: torch.Tensor,
    dim: int = 0,
) -> torch.Tensor:
    """
    Given a reconstructed image x and coil sensitivity maps :math: \{S_i\}_{i=1}^{N_c}, it returns
    .. math::
        \Epsilon(x) = (S_1 \times x, .., S_{N_c} \times x) = (x_1, .., x_{N_c}).
    From paper End-to-End Variational Networks for Accelerated MRI Reconstruction.
    Parameters
    ----------
    data : torch.Tensor
        Image data. Should be a complex tensor (with complex dim of size 2).
    sensitivity_map: torch.Tensor
        Coil sensitivity maps. Should be complex tensor (with complex dim of size 2).
    dim: int
        Coil dimension. Default: 0.
    Returns
    -------
    torch.Tensor:
        Zero-filled reconstructions from each coil.
    """

    assert_complex(data, complex_last=True)
    assert_complex(sensitivity_map, complex_last=True)

    return complex_multiplication(sensitivity_map, data.unsqueeze(dim))
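To make the coil-combination math in `reduce_operator` and `expand_operator` concrete, here is a minimal NumPy sketch using native complex arrays (the toolkit instead stores real and imaginary parts in a trailing dimension of size 2); the shapes and the map normalization are illustrative assumptions, not part of this commit:

```python
import numpy as np

# Assumed toy sizes: 4 coils over a 6 x 5 image.
num_coils, height, width = 4, 6, 5
rng = np.random.default_rng(0)
image = rng.normal(size=(height, width)) + 1j * rng.normal(size=(height, width))
sens = rng.normal(size=(num_coils, height, width)) + 1j * rng.normal(size=(num_coils, height, width))

# Normalize the maps so that sum_i |S_i|^2 = 1 at every pixel.
sens /= np.sqrt((np.abs(sens) ** 2).sum(axis=0))

# Expand: per-coil images x_i = S_i * x (coil dimension 0).
coil_images = sens * image[np.newaxis]

# Reduce: conjugate-weighted sum over coils, sum_i conj(S_i) * x_i.
combined = (sens.conj() * coil_images).sum(axis=0)

# With normalized maps, reduce(expand(x)) recovers x.
assert np.allclose(combined, image)

# Root-sum-of-squares combination, as in root_sum_of_squares above.
rss = np.sqrt((np.abs(coil_images) ** 2).sum(axis=0))
```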
9 changes: 1 addition & 8 deletions direct/engine.py
@@ -11,7 +11,7 @@
import warnings
from abc import ABC, abstractmethod
from collections import namedtuple
-from typing import Callable, Dict, List, Optional, TypedDict, Union
+from typing import Callable, Dict, List, Optional, Union

import numpy as np
import torch
@@ -278,13 +278,6 @@ def training_loop(
        fail_counter = 0
        for data, iter_idx in zip(data_loader, range(start_iter, total_iter)):

-            # 2D data is batched and contains keys:
-            # "filename_slice", "slice_no"
-            # "sampling_mask" of shape: (batch, 1, height, width, 1)
-            # "sensitivity_map" of shape: (batch, coil, height, width, complex=2)
-            # "target" of shape: (batch, height, width)
-            # "masked_kspace" of shape: (batch, coil, height, width, complex=2)

            if iter_idx == 0:
                self.log_first_training_example_and_model(data)

11 changes: 5 additions & 6 deletions direct/environment.py
@@ -1,8 +1,6 @@
# coding=utf-8
# Copyright (c) DIRECT Contributors

-# pylint: disable = E1101

import argparse
import logging
import os
@@ -17,6 +15,7 @@
from torch.utils import collect_env

import direct.utils.logging
+from direct.utils.logging import setup
from direct.config.defaults import DefaultConfig, InferenceConfig, TrainingConfig, ValidationConfig
from direct.utils import communication, count_parameters, str_to_class

@@ -74,7 +73,7 @@ def setup_logging(machine_rank, output_directory, run_name, cfg_filename, cfg, d
    # Setup logging
    log_file = output_directory / f"log_{machine_rank}_{communication.get_local_rank()}.txt"

-    direct.utils.logging.setup(
+    setup(
        use_stdout=communication.is_main_process() or debug,
        filename=log_file,
        log_level=("INFO" if not debug else "DEBUG"),
@@ -245,13 +244,13 @@ def setup_common_environment(
            dataset_cfg_from_file = extract_names(cfg_from_file[key].datasets)
            for idx, (dataset_name, dataset_config) in enumerate(dataset_cfg_from_file):
                cfg_from_file_new[key].datasets[idx] = dataset_config
-                cfg[key].datasets.append(load_dataset_config(dataset_name))
+                cfg[key].datasets.append(load_dataset_config(dataset_name))  # pylint: disable = E1136
        else:
            dataset_name, dataset_config = extract_names(cfg_from_file[key].dataset)
            cfg_from_file_new[key].dataset = dataset_config
-            cfg[key].dataset = load_dataset_config(dataset_name)
+            cfg[key].dataset = load_dataset_config(dataset_name)  # pylint: disable = E1136

-        cfg[key] = OmegaConf.merge(cfg[key], cfg_from_file_new[key])
+        cfg[key] = OmegaConf.merge(cfg[key], cfg_from_file_new[key])  # pylint: disable = E1136, E1137
    # sys.exit()
    # Make configuration read only.
    # TODO(jt): Does not work when indexing config lists.
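For context on the `OmegaConf.merge` call above: values from the config passed second take precedence over the first. A minimal sketch with made-up keys:

```python
from omegaconf import OmegaConf

base = OmegaConf.create({"training": {"lr": 1e-3, "num_iterations": 1000}})
override = OmegaConf.create({"training": {"lr": 5e-4}})

merged = OmegaConf.merge(base, override)
print(merged.training.lr)              # 0.0005, taken from the override
print(merged.training.num_iterations)  # 1000, kept from the base config
```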
6 changes: 3 additions & 3 deletions direct/launch.py
@@ -139,10 +139,10 @@ def _distributed_worker(
    if communication._LOCAL_PROCESS_GROUP is not None:
        raise RuntimeError
    num_machines = world_size // num_gpus_per_machine
-    for i in range(num_machines):
-        ranks_on_i = list(range(i * num_gpus_per_machine, (i + 1) * num_gpus_per_machine))
+    for idx in range(num_machines):
+        ranks_on_i = list(range(idx * num_gpus_per_machine, (idx + 1) * num_gpus_per_machine))
        pg = dist.new_group(ranks_on_i)
-        if i == machine_rank:
+        if idx == machine_rank:
            communication._LOCAL_PROCESS_GROUP = pg

    main_func(*args)
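The renamed loop above partitions consecutive global ranks into one process group per machine. A quick sketch of the rank arithmetic, with assumed sizes:

```python
# Assumed sizes for illustration: 2 machines with 4 GPUs each.
world_size, num_gpus_per_machine = 8, 4
num_machines = world_size // num_gpus_per_machine  # 2

for idx in range(num_machines):
    ranks = list(range(idx * num_gpus_per_machine, (idx + 1) * num_gpus_per_machine))
    print(f"machine {idx}: global ranks {ranks}")
# machine 0: global ranks [0, 1, 2, 3]
# machine 1: global ranks [4, 5, 6, 7]
```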
2 changes: 2 additions & 0 deletions direct/nn/conv/__init__.py
@@ -0,0 +1,2 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors