Skip to content

Commit

Permalink
Merge pull request #229 from NKI-AI/main
Browse files Browse the repository at this point in the history
Implement new models (#228)
  • Loading branch information
georgeyiasemis authored Oct 31, 2022
2 parents af564f4 + e7ea67e commit 211748a
Show file tree
Hide file tree
Showing 43 changed files with 2,301 additions and 75 deletions.
4 changes: 4 additions & 0 deletions direct/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# coding=utf-8
# Copyright (c) DIRECT Contributors

COMPLEX_SIZE = 2
1 change: 1 addition & 0 deletions direct/data/datasets_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class TransformsConfig(BaseConfig):
image_recon_type: str = "rss"
pad_coils: Optional[int] = None
scaling_key: Optional[str] = "masked_kspace"
scale_percentile: Optional[float] = 0.99
use_seed: bool = True


Expand Down
15 changes: 10 additions & 5 deletions direct/data/mri_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -946,11 +946,11 @@ def complex_whiten(self, complex_image: torch.Tensor) -> Tuple[torch.Tensor, tor
eig_input = torch.Tensor([[real_real, real_imag], [real_imag, imag_imag]])

# Remove correlation by rotating around covariance eigenvectors.
eig_values, eig_vecs = torch.eig(eig_input, eigenvectors=True)
eig_values, eig_vecs = torch.linalg.eig(eig_input)

# Scale by eigenvalues for unit variance.
std = (eig_values[:, 0] + self.epsilon).sqrt()
whitened_image = torch.matmul(centered_complex_image, eig_vecs) / std
std = (eig_values.real + self.epsilon).sqrt()
whitened_image = torch.matmul(centered_complex_image, eig_vecs.real) / std

return mean, std, whitened_image

Expand Down Expand Up @@ -1041,6 +1041,7 @@ def build_mri_transforms(
image_recon_type: str = "rss",
pad_coils: Optional[int] = None,
scaling_key: str = "masked_kspace",
scale_percentile: Optional[float] = 0.99,
use_seed: bool = True,
) -> object:
"""Build transforms for MRI.
Expand Down Expand Up @@ -1089,6 +1090,8 @@ def build_mri_transforms(
Number of coils to pad data to.
scaling_key : str
Key in sample to scale scalable items in sample. Default: "masked_kspace".
scale_percentile : float, optional
Data will be rescaled with the given percentile. If None, the division is done by the maximum. Default: 0.99
use_seed : bool
If true, a pseudo-random number based on the filename is computed so that every slice of the volume get
the same mask every time. Default: True.
Expand Down Expand Up @@ -1149,8 +1152,10 @@ def build_mri_transforms(
mri_transforms.append(EstimateBodyCoilImage(mask_func, backward_operator=backward_operator, use_seed=use_seed))

mri_transforms += [
ComputeScalingFactor(normalize_key=scaling_key, percentile=0.99, scaling_factor_key="scaling_factor"),
Normalize(),
ComputeScalingFactor(
normalize_key=scaling_key, percentile=scale_percentile, scaling_factor_key="scaling_factor"
),
Normalize(scaling_factor_key="scaling_factor"),
PadCoilDimension(pad_coils=pad_coils, key="masked_kspace"),
PadCoilDimension(pad_coils=pad_coils, key="sensitivity_map"),
]
Expand Down
58 changes: 57 additions & 1 deletion direct/data/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,10 +401,66 @@ def complex_multiplication(input_tensor: torch.Tensor, other_tensor: torch.Tenso
],
dim=complex_index,
)

return multiplication


def complex_dot_product(a: torch.Tensor, b: torch.Tensor, dim: List[int]) -> torch.Tensor:
r"""Computes the dot product of the complex tensors :math:`a` and :math:`b`: :math:`a^{*}b = <a, b>`.
Parameters
----------
a : torch.Tensor
Input :math:`a`.
b : torch.Tensor
Input :math:`b`.
dim : List[int]
Dimensions which will be suppressed. Useful when inputs are batched.
Returns
-------
complex_dot_product : torch.Tensor
Dot product of :math:`a` and :math:`b`.
"""
return complex_multiplication(conjugate(a), b).sum(dim)


def complex_division(input_tensor: torch.Tensor, other_tensor: torch.Tensor) -> torch.Tensor:
"""Divides two complex-valued tensors. Assumes input tensors are complex (last axis has dimension 2).
Parameters
----------
input_tensor: torch.Tensor
Input data
other_tensor: torch.Tensor
Input data
Returns
-------
torch.Tensor
"""
assert_complex(input_tensor, complex_last=True)
assert_complex(other_tensor, complex_last=True)

complex_index = -1

denominator = other_tensor[..., 0] ** 2 + other_tensor[..., 1] ** 2
real_part = safe_divide(
input_tensor[..., 0] * other_tensor[..., 0] + input_tensor[..., 1] * other_tensor[..., 1], denominator
)
imaginary_part = safe_divide(
input_tensor[..., 1] * other_tensor[..., 0] - input_tensor[..., 0] * other_tensor[..., 1], denominator
)

division = torch.cat(
[
real_part.unsqueeze(dim=complex_index),
imaginary_part.unsqueeze(dim=complex_index),
],
dim=complex_index,
)
return division


def _complex_matrix_multiplication(
input_tensor: torch.Tensor, other_tensor: torch.Tensor, mult_func: Callable
) -> torch.Tensor:
Expand Down
2 changes: 2 additions & 0 deletions direct/nn/conjgradnet/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# coding=utf-8
# Copyright (c) DIRECT Contributors
35 changes: 35 additions & 0 deletions direct/nn/conjgradnet/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# coding=utf-8
# Copyright (c) DIRECT Contributors

from dataclasses import dataclass
from typing import Optional

from direct.config.defaults import ModelConfig
from direct.nn.types import ActivationType, ModelName
from direct.nn.conjgradnet.conjgrad import CGUpdateType
from direct.nn.conjgradnet.conjgradnet import ConjGradNetInitType


@dataclass
class ConjGradNetConfig(ModelConfig):
num_steps: int = 8
image_init: str = ConjGradNetInitType.zeros
no_parameter_sharing: bool = True
cg_tol: float = 1e-7
cg_iters: int = 10
cg_param_update_type: str = CGUpdateType.FR
denoiser_architecture: str = ModelName.resnet
resnet_hidden_channels: int = 128
resnet_num_blocks: int = 15
resenet_batchnorm: bool = True
resenet_scale: Optional[float] = 0.1
unet_num_filters: Optional[int] = 32
unet_num_pool_layers: Optional[int] = 4
unet_dropout: Optional[float] = 0.0
didn_hidden_channels: Optional[int] = 16
didn_num_dubs: Optional[int] = 6
didn_num_convs_recon: Optional[int] = 9
conv_hidden_channels: Optional[int] = 64
conv_n_convs: Optional[int] = 15
conv_activation: Optional[str] = ActivationType.relu
conv_batchnorm: Optional[bool] = False
Loading

0 comments on commit 211748a

Please sign in to comment.