Implement recon models #123

Merged: 72 commits, Oct 18, 2021

Commits
7202996
Added transforms for varnet & Code fixes
georgeyiasemis Sep 3, 2021
bfa201e
Added NormUnet2d model
georgeyiasemis Sep 3, 2021
7b22fbf
Added VarNet implementation
georgeyiasemis Sep 3, 2021
95492c4
Minor code quality fixes
georgeyiasemis Sep 6, 2021
1fa094e
Merge pull request #117 from directgroup/implement-varnet
georgeyiasemis Sep 6, 2021
1ead985
VarNet engine refinements
georgeyiasemis Sep 7, 2021
e0713e0
VarNet engine refinements
georgeyiasemis Sep 7, 2021
9efd45a
Added Unet model
georgeyiasemis Sep 7, 2021
a4f600b
Minor fixes
georgeyiasemis Sep 9, 2021
8396472
Fixed typo
georgeyiasemis Sep 20, 2021
8ad63d5
Added transforms for varnet & Code fixes
georgeyiasemis Sep 3, 2021
9eac5dc
Added NormUnet2d model
georgeyiasemis Sep 3, 2021
2cc2401
Added VarNet implementation
georgeyiasemis Sep 3, 2021
5688fd8
Minor code quality fixes
georgeyiasemis Sep 6, 2021
f04c7bf
VarNet engine refinements
georgeyiasemis Sep 7, 2021
c955f84
Fixed typo
georgeyiasemis Sep 20, 2021
b341c90
Added initial xpdnet implementation
georgeyiasemis Sep 20, 2021
70459b8
MWCNN from scratch
georgeyiasemis Sep 22, 2021
5c07761
Black fixes
georgeyiasemis Sep 22, 2021
7e5bf37
Black fixes
georgeyiasemis Sep 22, 2021
a77083a
XPDNet refinements
georgeyiasemis Sep 24, 2021
9c163ea
Added LPDNet implementation
georgeyiasemis Sep 24, 2021
37bf359
Fix Varnet to work with input channels more than 2
georgeyiasemis Sep 24, 2021
4334d2b
XPDNet refinements
georgeyiasemis Sep 24, 2021
89d265f
XPDNet refinements
georgeyiasemis Sep 24, 2021
9be039b
Added DIDN implementation
georgeyiasemis Sep 26, 2021
e31e854
Minor _backward_operator fix
georgeyiasemis Sep 26, 2021
3f68f81
Implemented more kspace models
georgeyiasemis Sep 26, 2021
b9b7dc4
XPDnet works with more kspace models
georgeyiasemis Sep 26, 2021
a4e3e5e
Minor code improvements
georgeyiasemis Sep 27, 2021
0a8c07e
Added JointICNet implementation
georgeyiasemis Sep 28, 2021
59c2137
JointICNet changes
georgeyiasemis Sep 29, 2021
c28db14
Minor scaling fix
georgeyiasemis Sep 30, 2021
da9642c
MultiCoil minor changes
georgeyiasemis Oct 1, 2021
f83a47c
Added kikinet implementation
georgeyiasemis Oct 1, 2021
c4a250e
Minor change
georgeyiasemis Oct 1, 2021
521c3b7
Added modified airs model
georgeyiasemis Oct 1, 2021
05d5dc2
Add baseline scripts
georgeyiasemis Oct 3, 2021
c37ebfe
Merge branch 'implement-cross-domain-nets' into implement-varnet
georgeyiasemis Oct 4, 2021
9279eb8
Merge pull request #121 from directgroup/implement-varnet
georgeyiasemis Oct 4, 2021
597cdf9
Merge branch 'implement-recon-models' into implement-unet-baseline
georgeyiasemis Oct 5, 2021
36f4fa0
Merge pull request #122 from directgroup/implement-unet-baseline
georgeyiasemis Oct 5, 2021
dfa0729
Minor changes
Oct 5, 2021
4b43bad
Conv2dGRU is now in different directory
Oct 5, 2021
3dd2004
RIM fixed to work with Conv2dGRU moving
Oct 5, 2021
bd89f92
Missing attribute fix
Oct 5, 2021
028565c
Added unet baseline config
Oct 5, 2021
4aaf108
Transforms documentation update
Oct 5, 2021
a13f565
Fix minor issue
Oct 6, 2021
694ae41
Fix groupnorm implementation in NormUnet2d
Oct 6, 2021
edb95be
Fix autograd bug
Oct 6, 2021
b9c6a3c
LPD configuration
Oct 6, 2021
1424355
Fix bug in modulus_if_complex transform
Oct 8, 2021
26eef4a
Fix to show val metrics in Tensorboard
Oct 8, 2021
3d6808e
VarNet config fix
Oct 11, 2021
8b612a5
Minor fixes in lpdnet
Oct 11, 2021
5448450
Added transform tests
Oct 11, 2021
ed5ddda
Minor code fixed and add tests
Oct 12, 2021
7629af6
Remove unused import
Oct 12, 2021
6411b33
Update README.md
georgeyiasemis Oct 12, 2021
23a40f7
Minor fix
georgeyiasemis Oct 12, 2021
122a3e0
Minor fix
georgeyiasemis Oct 12, 2021
3167544
Minor code fixes
georgeyiasemis Oct 14, 2021
5d3d058
Add nn tests
georgeyiasemis Oct 14, 2021
55cd4e9
Fix & suppress pylint false positive errors
georgeyiasemis Oct 15, 2021
ac60ada
Numpy docstrings, removed whitespaces, code formatting
georgeyiasemis Oct 16, 2021
44ccfc3
Numpy docstrings, removed whitespaces, code formatting
georgeyiasemis Oct 16, 2021
771b01d
Numpy docstrings, removed whitespaces, code formatting
georgeyiasemis Oct 16, 2021
281ed60
Black fix
georgeyiasemis Oct 17, 2021
713cb8d
Merge pull request #124 from directgroup/main
georgeyiasemis Oct 17, 2021
e7cb662
Remove blanks
georgeyiasemis Oct 18, 2021
3fa6edf
Merge branch 'implement-recon-models' of github.com:directgroup/direc…
georgeyiasemis Oct 18, 2021
8 changes: 4 additions & 4 deletions README.md
@@ -3,7 +3,7 @@
[![black](https://github.com/directgroup/direct/actions/workflows/black.yml/badge.svg)](https://github.com/directgroup/direct/actions/workflows/black.yml)

# DIRECT
DIRECT is the Deep Image REConstruction Toolkit that implements state-of-the-art inverse problem solvers. It includes
DIRECT is the Deep Image REConstruction Toolkit that implements state-of-the-art inverse problem solvers. It stores
inverse problem solvers such as the Learned Primal Dual algorithm and Recurrent Inference Machine, which were part of the winning solution in Facebook & NYUs FastMRI challenge in 2019 and the Calgary-Campinas MRI reconstruction challenge at MIDL 2020.

<div align="center">
@@ -29,9 +29,9 @@ If you use DIRECT in your own research, or want to refer to baseline results pub

```BibTeX
@misc{DIRECTTOOLKIT,
author = {Jonas Teuwen, Nikita Moriakov, Dimitrios Karkalousos, Matthan Caan, George Yiasemis},
title = {DIRECT},
author = {Teuwen, Jonas and Yiasemis, George and Moriakov, Nikita and Karkalousos, Dimitrios and Caan, Matthan},
title = {DIRECT: Deep Image REConstruction Toolkit},
howpublished = {\url{https://github.com/directgroup/direct}},
year = {2020}
year = {2021}
}
```
2 changes: 1 addition & 1 deletion direct/data/h5_data.py
@@ -124,7 +124,7 @@ def parse_filenames_data(self, filenames, extra_h5s=None, filter_slice=None):
if len(filenames) < 5 or idx % (len(filenames) // 5) == 0 or len(filenames) == (idx + 1):
self.logger.info(f"Parsing: {(idx + 1) / len(filenames) * 100:.2f}%.")
try:
kspace = h5py.File(filename, "r")["kspace"]
kspace = np.array(h5py.File(filename, "r")["kspace"])
self.verify_extra_h5_integrity(filename, kspace.shape, extra_h5s=extra_h5s)

except OSError as exc:
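The one-line change above swaps a lazy h5py dataset handle for an in-memory numpy array, so the shape check no longer depends on the file staying open. A minimal sketch of the difference (file name and usage are illustrative, not from the PR):

```python
import h5py
import numpy as np

with h5py.File("sample.h5", "r") as f:
    lazy = f["kspace"]             # h5py.Dataset: data stays on disk
    eager = np.array(f["kspace"])  # copies the data into memory

print(eager.shape)   # works: eager is a plain numpy array
# print(lazy.shape)  # would raise: the dataset is invalid once the file is closed
```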
File renamed without changes.
69 changes: 68 additions & 1 deletion direct/data/tests/test_transforms.py
@@ -107,6 +107,24 @@ def test_modulus(shape):
assert np.allclose(out_torch, out_numpy)


@pytest.mark.parametrize(
"shape",
[
[3, 3],
[4, 6],
[10, 8, 4],
[3, 4, 3, 5],
],
)
@pytest.mark.parametrize("complex", [True, False])
def test_modulus_if_complex(shape, complex):
if complex:
shape += [
2,
]
test_modulus(shape)


@pytest.mark.parametrize(
"shape, dims",
[
@@ -166,7 +184,6 @@ def test_center_crop(shape, target_shape):
[[8, 4], [4, 4]],
],
)
# @pytest.mark.parametrize("named", [True, False])
def test_complex_center_crop(shape, target_shape):
shape = shape + [2]
input = create_input(shape)
@@ -262,3 +279,53 @@ def test_ifftshift(shape):
out_torch = transforms.ifftshift(torch_tensor).numpy()
out_numpy = np.fft.ifftshift(data)
assert np.allclose(out_torch, out_numpy)


@pytest.mark.parametrize(
"shape, dim",
[
[[3, 4, 5], 0],
[[3, 3, 4, 5], 1],
[[3, 6, 4, 5], 0],
[[3, 3, 6, 4, 5], 1],
],
)
def test_expand_operator(shape, dim):
shape = shape + [
2,
]
data = create_input(shape) # noqa
shape = shape[:dim] + shape[dim + 1 :]
sens = create_input(shape) # noqa

out_torch = tensor_to_complex_numpy(transforms.expand_operator(data, sens, dim))

input_numpy = np.expand_dims(tensor_to_complex_numpy(data), dim)
input_sens_numpy = tensor_to_complex_numpy(sens)
out_numpy = input_sens_numpy * input_numpy

assert np.allclose(out_torch, out_numpy)


@pytest.mark.parametrize(
"shape, dim",
[
[[3, 4, 5], 0],
[[3, 3, 4, 5], 1],
[[3, 6, 4, 5], 0],
[[3, 3, 6, 4, 5], 1],
],
)
def test_reduce_operator(shape, dim):
shape = shape + [
2,
]
coil_data = create_input(shape) # noqa
sens = create_input(shape) # noqa
out_torch = tensor_to_complex_numpy(transforms.reduce_operator(coil_data, sens, dim))

input_numpy = tensor_to_complex_numpy(coil_data)
input_sens_numpy = tensor_to_complex_numpy(sens)
out_numpy = (input_sens_numpy.conj() * input_numpy).sum(dim)

assert np.allclose(out_torch, out_numpy)
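The two operators exercised by these tests form an adjoint pair: `reduce_operator` is the adjoint of `expand_operator`. A standalone numpy sketch (not part of the test suite) of the identity ⟨E(x), y⟩ = ⟨x, R(y)⟩:

```python
# Sketch: reduce_operator is the adjoint of expand_operator (plain numpy check).
import numpy as np

rng = np.random.default_rng(0)
num_coils, h, w = 3, 8, 8
x = rng.standard_normal((h, w)) + 1j * rng.standard_normal((h, w))          # image
S = rng.standard_normal((num_coils, h, w)) + 1j * rng.standard_normal((num_coils, h, w))
y = rng.standard_normal((num_coils, h, w)) + 1j * rng.standard_normal((num_coils, h, w))

Ex = S * x                       # expand: per-coil images
Ry = (S.conj() * y).sum(axis=0)  # reduce: coil combination

# <E(x), y> == <x, R(y)> up to floating-point error.
assert np.allclose(np.vdot(Ex, y), np.vdot(x, Ry))
```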
108 changes: 83 additions & 25 deletions direct/data/transforms.py
@@ -10,10 +10,9 @@
import numpy as np
import torch
import torch.fft
from packaging import version

from direct.data.bbox import crop_to_bbox
from direct.utils import ensure_list, is_power_of_two
from direct.utils import ensure_list, is_power_of_two, is_complex_data
from direct.utils.asserts import assert_complex, assert_same_shape


@@ -222,7 +221,6 @@ def safe_divide(input_tensor: torch.Tensor, other_tensor: torch.Tensor) -> torch
torch.tensor([0.0], dtype=input_tensor.dtype).to(input_tensor.device),
input_tensor / other_tensor,
)

return data


@@ -266,8 +264,7 @@ def align_as(input_tensor: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
input_shape = list(input_tensor.shape)
other_shape = torch.tensor(other.shape, dtype=int)
out_shape = torch.ones(len(other.shape), dtype=int)
# TODO(gy): Fix to ensure complex_last when [2,..., 2] or [..., N,..., N,...] in other.shape,
# "-input_shape.count(dim):" is a hack and might cause problems.

for dim in np.sort(np.unique(input_tensor.shape)):
ind = torch.where(other_shape == dim)[0][-input_shape.count(dim) :]
out_shape[ind] = dim
@@ -292,7 +289,6 @@ def modulus(data: torch.Tensor) -> torch.Tensor:
complex_axis = -1 if data.size(-1) == 2 else 1

return (data ** 2).sum(complex_axis).sqrt() # noqa
# return torch.view_as_complex(data).abs()


def modulus_if_complex(data: torch.Tensor) -> torch.Tensor:
@@ -307,11 +303,9 @@ def modulus_if_complex(data: torch.Tensor) -> torch.Tensor:
-------
torch.Tensor
"""
# TODO: This can be merged with modulus if the tensor is real.
try:
if is_complex_data(data, complex_last=False):
return modulus(data)
except ValueError:
return data
return data


def roll(
@@ -436,7 +430,7 @@ def _complex_matrix_multiplication(input_tensor, other_tensor, mult_func):

Parameters
----------
x : torch.Tensor
input_tensor : torch.Tensor
other_tensor : torch.Tensor
mult_func : Callable
Multiplication function e.g. torch.bmm or torch.mm
@@ -466,6 +460,7 @@ def complex_mm(input_tensor, other_tensor):
----------
input_tensor : torch.Tensor
other_tensor : torch.Tensor

Returns
-------
torch.Tensor
@@ -501,10 +496,6 @@ def conjugate(data: torch.Tensor) -> torch.Tensor:
-------
torch.Tensor
"""
# assert_complex(data, complex_last=True)
# data = torch.view_as_real(
# torch.view_as_complex(data).conj()
# )
assert_complex(data, complex_last=True)
data = data.clone() # Clone is required as the data in the next line is changed in-place.
data[..., 1] = data[..., 1] * -1.0
@@ -574,11 +565,12 @@ def tensor_to_complex_numpy(data: torch.Tensor) -> np.ndarray:
return data[..., 0] + 1j * data[..., 1]


def root_sum_of_squares(data: torch.Tensor, dim: int = 0) -> torch.Tensor:
r"""
def root_sum_of_squares(data: torch.Tensor, dim: int = 0, complex_dim: int = -1) -> torch.Tensor:
"""
Compute the root sum of squares (RSS) transform along a given dimension of the input tensor.

$$x_{\textrm{rss}} = \sqrt{\sum_{i \in \textrm{coil}} |x_i|^2}$$
.. math::
x_{\textrm{rss}} = \sqrt{\sum_{i \in \textrm{coil}} |x_i|^2}

Parameters
----------
@@ -588,16 +580,16 @@ def root_sum_of_squares(data: torch.Tensor, dim: int = 0) -> torch.Tensor:
dim : int
Coil dimension. Default is 0 as the first dimension is always the coil dimension.

complex_dim : int
Complex channel dimension. Default is -1. If data not complex this is ignored.
Returns
-------
torch.Tensor : RSS of the input tensor.
"""
try:
assert_complex(data, complex_last=True)
complex_index = -1
return torch.sqrt((data ** 2).sum(complex_index).sum(dim))
except ValueError:
return torch.sqrt((data ** 2).sum(dim))
if is_complex_data(data):
return torch.sqrt((data ** 2).sum(complex_dim).sum(dim))

return torch.sqrt((data ** 2).sum(dim))


def center_crop(data: torch.Tensor, shape: Tuple[int, int]) -> torch.Tensor:
@@ -760,5 +752,71 @@ def complex_random_crop(

if len(output) == 1:
return output[0]

return output


def reduce_operator(
coil_data: torch.Tensor,
sensitivity_map: torch.Tensor,
dim: int = 0,
) -> torch.Tensor:
"""
Given zero-filled reconstructions from multiple coils :math: \{x_i\}_{i=1}^{N_c} and coil sensitivity maps
:math: \{S_i\}_{i=1}^{N_c} it returns
.. math::
R(x_1, .., x_{N_c}, S_1, .., S_{N_c}) = \sum_{i=1}^{N_c} {S_i}^{*} \times x_i.

From paper End-to-End Variational Networks for Accelerated MRI Reconstruction.

Parameters
----------
coil_data : torch.Tensor
Zero-filled reconstructions from coils. Should be a complex tensor (with complex dim of size 2).
sensitivity_map: torch.Tensor
Coil sensitivity maps. Should be complex tensor (with complex dim of size 2).
dim: int
Coil dimension. Default: 0.

Returns
-------
torch.Tensor:
Combined individual coil images.
"""

assert_complex(coil_data, complex_last=True)
assert_complex(sensitivity_map, complex_last=True)
Comment on lines +786 to +787 (Contributor):

    Is this assert_complex still needed? Can't our internal representation switch completely to complex tensors and use .is_complex() or so (that could be done inside assert_complex, of course)?

Reply (Contributor, Author):

    That would mean reworking the whole codebase and all models, so let's keep it as a future change.


return complex_multiplication(conjugate(sensitivity_map), coil_data).sum(dim)


def expand_operator(
data: torch.Tensor,
sensitivity_map: torch.Tensor,
dim: int = 0,
) -> torch.Tensor:
"""
Given a reconstructed image x and coil sensitivity maps :math: \{S_i\}_{i=1}^{N_c}, it returns
.. math::
\Epsilon(x) = (S_1 \times x, .., S_{N_c} \times x) = (x_1, .., x_{N_c}).

From paper End-to-End Variational Networks for Accelerated MRI Reconstruction.

Parameters
----------
data : torch.Tensor
Image data. Should be a complex tensor (with complex dim of size 2).
sensitivity_map: torch.Tensor
Coil sensitivity maps. Should be complex tensor (with complex dim of size 2).
dim: int
Coil dimension. Default: 0.

Returns
-------
torch.Tensor:
Zero-filled reconstructions from each coil.
"""

assert_complex(data, complex_last=True)
assert_complex(sensitivity_map, complex_last=True)

return complex_multiplication(sensitivity_map, data.unsqueeze(dim))
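A quick usage sketch of the new coil operators added here, with illustrative shapes and the complex-last convention (trailing dimension of size 2) that `assert_complex` enforces:

```python
import torch
from direct.data import transforms

num_coils, height, width = 4, 32, 32
image = torch.randn(height, width, 2)            # single complex image
sens = torch.randn(num_coils, height, width, 2)  # coil sensitivity maps

# E(x) = (S_1 x, ..., S_N x): project one image onto each coil.
coil_images = transforms.expand_operator(image, sens, dim=0)
assert coil_images.shape == (num_coils, height, width, 2)

# R(x_1, ..., x_N) = sum_i conj(S_i) x_i: combine the coil images again.
combined = transforms.reduce_operator(coil_images, sens, dim=0)
assert combined.shape == (height, width, 2)

# root_sum_of_squares now exposes complex_dim explicitly (default -1).
rss = transforms.root_sum_of_squares(coil_images, dim=0, complex_dim=-1)
assert rss.shape == (height, width)
```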
9 changes: 1 addition & 8 deletions direct/engine.py
@@ -11,7 +11,7 @@
import warnings
from abc import ABC, abstractmethod
from collections import namedtuple
from typing import Callable, Dict, List, Optional, TypedDict, Union
from typing import Callable, Dict, List, Optional, Union

import numpy as np
import torch
@@ -278,13 +278,6 @@ def training_loop(
fail_counter = 0
for data, iter_idx in zip(data_loader, range(start_iter, total_iter)):

# 2D data is batched and contains keys:
# "filename_slice", "slice_no"
# "sampling_mask" of shape: (batch, 1, height, width, 1)
# "sensitivity_map" of shape: (batch, coil, height, width, complex=2)
# "target" of shape: (batch, height, width)
# "masked_kspace" of shape: (batch, coil, height, width, complex=2)

Comment on lines -281 to -287 (Contributor):

    It feels like we should write this down somewhere. Maybe in the documentation?

Reply (Contributor, Author):

    The engine documentation?

if iter_idx == 0:
self.log_first_training_example_and_model(data)

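The deleted comment was the only place these shapes were written down (the review exchange above raises exactly this point); for reference, a sketch of the batched 2D sample layout it described:

```python
# Layout of a batched 2D training sample, per the removed inline comment:
batch = {
    "filename_slice": ...,   # source file name for each slice
    "slice_no": ...,         # slice index
    "sampling_mask": ...,    # (batch, 1, height, width, 1)
    "sensitivity_map": ...,  # (batch, coil, height, width, complex=2)
    "target": ...,           # (batch, height, width)
    "masked_kspace": ...,    # (batch, coil, height, width, complex=2)
}
```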
11 changes: 5 additions & 6 deletions direct/environment.py
@@ -1,8 +1,6 @@
# coding=utf-8
# Copyright (c) DIRECT Contributors

# pylint: disable = E1101

import argparse
import logging
import os
@@ -17,6 +15,7 @@
from torch.utils import collect_env

import direct.utils.logging
from direct.utils.logging import setup
from direct.config.defaults import DefaultConfig, InferenceConfig, TrainingConfig, ValidationConfig
from direct.utils import communication, count_parameters, str_to_class

@@ -74,7 +73,7 @@ def setup_logging(machine_rank, output_directory, run_name, cfg_filename, cfg, d
# Setup logging
log_file = output_directory / f"log_{machine_rank}_{communication.get_local_rank()}.txt"

direct.utils.logging.setup(
setup(
use_stdout=communication.is_main_process() or debug,
filename=log_file,
log_level=("INFO" if not debug else "DEBUG"),
@@ -245,13 +244,13 @@ def setup_common_environment(
dataset_cfg_from_file = extract_names(cfg_from_file[key].datasets)
for idx, (dataset_name, dataset_config) in enumerate(dataset_cfg_from_file):
cfg_from_file_new[key].datasets[idx] = dataset_config
cfg[key].datasets.append(load_dataset_config(dataset_name))
cfg[key].datasets.append(load_dataset_config(dataset_name)) # pylint: disable = E1136
else:
dataset_name, dataset_config = extract_names(cfg_from_file[key].dataset)
cfg_from_file_new[key].dataset = dataset_config
cfg[key].dataset = load_dataset_config(dataset_name)
cfg[key].dataset = load_dataset_config(dataset_name) # pylint: disable = E1136

cfg[key] = OmegaConf.merge(cfg[key], cfg_from_file_new[key])
cfg[key] = OmegaConf.merge(cfg[key], cfg_from_file_new[key]) # pylint: disable = E1136, E1137
# sys.exit()
# Make configuration read only.
# TODO(jt): Does not work when indexing config lists.
6 changes: 3 additions & 3 deletions direct/launch.py
@@ -139,10 +139,10 @@ def _distributed_worker(
if communication._LOCAL_PROCESS_GROUP is not None:
raise RuntimeError
num_machines = world_size // num_gpus_per_machine
for i in range(num_machines):
ranks_on_i = list(range(i * num_gpus_per_machine, (i + 1) * num_gpus_per_machine))
for idx in range(num_machines):
ranks_on_i = list(range(idx * num_gpus_per_machine, (idx + 1) * num_gpus_per_machine))
pg = dist.new_group(ranks_on_i)
if i == machine_rank:
if idx == machine_rank:
communication._LOCAL_PROCESS_GROUP = pg

main_func(*args)
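The renamed loop partitions the global ranks into one process group per machine; a standalone sketch of the index arithmetic, assuming 2 machines with 4 GPUs each:

```python
world_size = 8           # total number of processes
num_gpus_per_machine = 4
num_machines = world_size // num_gpus_per_machine

for idx in range(num_machines):
    ranks = list(range(idx * num_gpus_per_machine, (idx + 1) * num_gpus_per_machine))
    print(f"machine {idx}: ranks {ranks}")

# machine 0: ranks [0, 1, 2, 3]
# machine 1: ranks [4, 5, 6, 7]
```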
2 changes: 2 additions & 0 deletions direct/nn/conv/__init__.py
@@ -0,0 +1,2 @@
# coding=utf-8
# Copyright (c) DIRECT Contributors