Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement new models #228

Merged
merged 28 commits into from
Oct 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
168a1d4
isort fix
georgeyiasemis Oct 18, 2022
7a7e23c
Added models
georgeyiasemis Oct 18, 2022
4c6e65a
Minor fixes, new tests
georgeyiasemis Oct 20, 2022
c4f1754
Tests on cpu
georgeyiasemis Oct 20, 2022
3b56ed8
CG algo fixed, ConjGradNet individual module
georgeyiasemis Oct 21, 2022
d3e2493
Update tests
georgeyiasemis Oct 21, 2022
096c1d5
Update config.py
georgeyiasemis Oct 21, 2022
04ae614
Rename param
georgeyiasemis Oct 21, 2022
1d7a9be
Complex division transforms
georgeyiasemis Oct 22, 2022
78790da
Complex dot product to transforms
georgeyiasemis Oct 23, 2022
eaf82a6
Documentation fix
georgeyiasemis Oct 23, 2022
3f299cb
Refactored and added tests
georgeyiasemis Oct 23, 2022
1b42b0a
Update doc history
georgeyiasemis Oct 23, 2022
a778f9b
Adding conjgradnet CC baseline to zoo
georgeyiasemis Oct 23, 2022
2fbe07a
Baseline cfg
georgeyiasemis Oct 23, 2022
39e8a8c
Pylint codacy fixes
georgeyiasemis Oct 24, 2022
213414e
Docs, pylint
georgeyiasemis Oct 24, 2022
db4b170
Fix CRLF
georgeyiasemis Oct 25, 2022
f461740
Add scale percentile as a param
georgeyiasemis Oct 25, 2022
f1f72c6
Itedualnet baseline
georgeyiasemis Oct 26, 2022
4c0105b
Minor fix
georgeyiasemis Oct 28, 2022
8439910
Docs
georgeyiasemis Oct 28, 2022
e8cd399
Minor docs fix
georgeyiasemis Oct 28, 2022
dcc30c6
Docs
georgeyiasemis Oct 28, 2022
bb3df1d
No whiteline
georgeyiasemis Oct 28, 2022
302f120
PR fixes: Enum classes, COMPLEX_SIZE introduced
georgeyiasemis Oct 31, 2022
c9af57f
Remove deprecated eig fun
georgeyiasemis Oct 31, 2022
c02fdc9
Minor fix
georgeyiasemis Oct 31, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions direct/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# coding=utf-8
# Copyright (c) DIRECT Contributors

# Size of the trailing axis used to represent a complex value as two real
# channels (real part, imaginary part) throughout the codebase.
COMPLEX_SIZE = 2
1 change: 1 addition & 0 deletions direct/data/datasets_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class TransformsConfig(BaseConfig):
image_recon_type: str = "rss"
pad_coils: Optional[int] = None
scaling_key: Optional[str] = "masked_kspace"
scale_percentile: Optional[float] = 0.99
use_seed: bool = True


Expand Down
15 changes: 10 additions & 5 deletions direct/data/mri_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -946,11 +946,11 @@ def complex_whiten(self, complex_image: torch.Tensor) -> Tuple[torch.Tensor, tor
eig_input = torch.Tensor([[real_real, real_imag], [real_imag, imag_imag]])

# Remove correlation by rotating around covariance eigenvectors.
eig_values, eig_vecs = torch.eig(eig_input, eigenvectors=True)
eig_values, eig_vecs = torch.linalg.eig(eig_input)

# Scale by eigenvalues for unit variance.
std = (eig_values[:, 0] + self.epsilon).sqrt()
whitened_image = torch.matmul(centered_complex_image, eig_vecs) / std
std = (eig_values.real + self.epsilon).sqrt()
whitened_image = torch.matmul(centered_complex_image, eig_vecs.real) / std

return mean, std, whitened_image

Expand Down Expand Up @@ -1041,6 +1041,7 @@ def build_mri_transforms(
image_recon_type: str = "rss",
pad_coils: Optional[int] = None,
scaling_key: str = "masked_kspace",
scale_percentile: Optional[float] = 0.99,
use_seed: bool = True,
) -> object:
"""Build transforms for MRI.
Expand Down Expand Up @@ -1089,6 +1090,8 @@ def build_mri_transforms(
Number of coils to pad data to.
scaling_key : str
Key in sample to scale scalable items in sample. Default: "masked_kspace".
scale_percentile : float, optional
Data will be rescaled with the given percentile. If None, the division is done by the maximum. Default: 0.99
georgeyiasemis marked this conversation as resolved.
Show resolved Hide resolved
use_seed : bool
If true, a pseudo-random number based on the filename is computed so that every slice of the volume gets
the same mask every time. Default: True.
Expand Down Expand Up @@ -1149,8 +1152,10 @@ def build_mri_transforms(
mri_transforms.append(EstimateBodyCoilImage(mask_func, backward_operator=backward_operator, use_seed=use_seed))

mri_transforms += [
ComputeScalingFactor(normalize_key=scaling_key, percentile=0.99, scaling_factor_key="scaling_factor"),
Normalize(),
ComputeScalingFactor(
normalize_key=scaling_key, percentile=scale_percentile, scaling_factor_key="scaling_factor"
),
Normalize(scaling_factor_key="scaling_factor"),
PadCoilDimension(pad_coils=pad_coils, key="masked_kspace"),
PadCoilDimension(pad_coils=pad_coils, key="sensitivity_map"),
]
Expand Down
58 changes: 57 additions & 1 deletion direct/data/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,10 +401,66 @@ def complex_multiplication(input_tensor: torch.Tensor, other_tensor: torch.Tenso
],
dim=complex_index,
)

return multiplication


def complex_dot_product(a: torch.Tensor, b: torch.Tensor, dim: List[int]) -> torch.Tensor:
    r"""Computes the dot product of the complex tensors :math:`a` and :math:`b`: :math:`a^{*}b = <a, b>`.

    Parameters
    ----------
    a : torch.Tensor
        Input :math:`a`.
    b : torch.Tensor
        Input :math:`b`.
    dim : List[int]
        Dimensions which will be suppressed. Useful when inputs are batched.

    Returns
    -------
    complex_dot_product : torch.Tensor
        Dot product of :math:`a` and :math:`b`.
    """
    # <a, b> = conj(a) * b, element-wise, then reduced over the requested dims.
    elementwise_product = complex_multiplication(conjugate(a), b)
    return elementwise_product.sum(dim)


def complex_division(input_tensor: torch.Tensor, other_tensor: torch.Tensor) -> torch.Tensor:
    """Divides two complex-valued tensors. Assumes input tensors are complex (last axis has dimension 2).

    Parameters
    ----------
    input_tensor: torch.Tensor
        Input data
    other_tensor: torch.Tensor
        Input data

    Returns
    -------
    torch.Tensor
    """
    assert_complex(input_tensor, complex_last=True)
    assert_complex(other_tensor, complex_last=True)

    # (x + yi) / (u + vi) = [(xu + yv) + (yu - xv)i] / (u^2 + v^2)
    numerator_real = input_tensor[..., 0] * other_tensor[..., 0] + input_tensor[..., 1] * other_tensor[..., 1]
    numerator_imag = input_tensor[..., 1] * other_tensor[..., 0] - input_tensor[..., 0] * other_tensor[..., 1]
    denominator = other_tensor[..., 0] ** 2 + other_tensor[..., 1] ** 2

    # safe_divide guards against a zero modulus in the denominator.
    real_part = safe_divide(numerator_real, denominator)
    imaginary_part = safe_divide(numerator_imag, denominator)

    # Stacking on a new trailing axis is equivalent to concatenating the
    # unsqueezed real/imaginary parts along dim=-1.
    return torch.stack([real_part, imaginary_part], dim=-1)


def _complex_matrix_multiplication(
input_tensor: torch.Tensor, other_tensor: torch.Tensor, mult_func: Callable
) -> torch.Tensor:
Expand Down
8 changes: 1 addition & 7 deletions direct/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,7 @@
from direct.data.samplers import ConcatDatasetBatchSampler
from direct.exceptions import ProcessKilledException, TrainingException
from direct.types import PathOrString
from direct.utils import (
communication,
normalize_image,
prefix_dict_keys,
reduce_list_of_dicts,
str_to_class,
)
from direct.utils import communication, normalize_image, prefix_dict_keys, reduce_list_of_dicts, str_to_class
from direct.utils.events import CommonMetricPrinter, EventStorage, JSONWriter, TensorboardWriter, get_event_storage
from direct.utils.io import write_json

Expand Down
2 changes: 2 additions & 0 deletions direct/nn/conjgradnet/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# coding=utf-8
# Copyright (c) DIRECT Contributors
35 changes: 35 additions & 0 deletions direct/nn/conjgradnet/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# coding=utf-8
# Copyright (c) DIRECT Contributors

from dataclasses import dataclass
from typing import Optional

from direct.config.defaults import ModelConfig
from direct.nn.types import ActivationType, ModelName
from direct.nn.conjgradnet.conjgrad import CGUpdateType
from direct.nn.conjgradnet.conjgradnet import ConjGradNetInitType


@dataclass
class ConjGradNetConfig(ModelConfig):
    """Configuration for the ConjGradNet model.

    Combines conjugate-gradient (CG) solver settings with hyperparameters for
    the selectable denoiser architecture (``denoiser_architecture``).
    """

    # Number of unrolled steps of the network. Presumably iterations of the
    # outer optimization loop — confirm against the ConjGradNet implementation.
    num_steps: int = 8
    # Image initialization strategy (enum member used as its string value).
    image_init: str = ConjGradNetInitType.zeros
    no_parameter_sharing: bool = True
    # Conjugate-gradient solver tolerance and maximum iterations.
    cg_tol: float = 1e-7
    cg_iters: int = 10
    # CG direction-update formula (e.g. Fletcher–Reeves), see CGUpdateType.
    cg_param_update_type: str = CGUpdateType.FR
    # Which denoiser to build; the matching *_ prefixed fields below apply.
    denoiser_architecture: str = ModelName.resnet
    resnet_hidden_channels: int = 128
    resnet_num_blocks: int = 15
    # NOTE(review): "resenet_*" looks like a typo for "resnet_*" — renaming
    # would break existing config files keyed on these names; confirm and fix
    # together with the consuming code.
    resenet_batchnorm: bool = True
    resenet_scale: Optional[float] = 0.1
    unet_num_filters: Optional[int] = 32
    unet_num_pool_layers: Optional[int] = 4
    unet_dropout: Optional[float] = 0.0
    didn_hidden_channels: Optional[int] = 16
    didn_num_dubs: Optional[int] = 6
    didn_num_convs_recon: Optional[int] = 9
    conv_hidden_channels: Optional[int] = 64
    conv_n_convs: Optional[int] = 15
    conv_activation: Optional[str] = ActivationType.relu
    conv_batchnorm: Optional[bool] = False
Loading