Skip to content

Commit

Permalink
New loss functions, refactored engines to only implement forward_meth…
Browse files Browse the repository at this point in the history
…od, quality fixes (#226) (Closes #225)

* New loss functions (`NMSE`, `NRMSE`, `NMAE`, `SobelGradL1Loss`, `SobelGradL2Loss`)
* `mri_models` performs `_do_iteration method`, child engines perform `forward_function` which returns output_image and/or output_kspace
* Changes/Additions in `mri_transforms`
    * Padding computed as a Tensor with `ComputePadding` transform (this is helpful when cropping image and tranforming to kspace)
    * `ApplyPadding` transform
    * ComputeImage transform choices of mod output or not
    * `RenameKeys` transform
    * `Normalize` split to `ComputeScalingFactor` and `Normalize`
 * Some quality changes
 * Some documentation changes
  • Loading branch information
georgeyiasemis authored Oct 18, 2022
1 parent 1a8ddf5 commit 520edd4
Show file tree
Hide file tree
Showing 37 changed files with 1,446 additions and 926 deletions.
17 changes: 9 additions & 8 deletions direct/common/subsample.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def __init__(
)

@staticmethod
def center_mask_func(num_cols, num_low_freqs):
def center_mask_func(num_cols: int, num_low_freqs: int) -> np.ndarray:

# create the mask
mask = np.zeros(num_cols, dtype=bool)
Expand Down Expand Up @@ -415,7 +415,7 @@ def mask_func(
mask_negative = np.flip(mask_negative)

mask = np.fft.fftshift(np.concatenate((mask_positive, mask_negative)))
mask = mask | acs_mask
mask = np.logical_or(mask, acs_mask)

return torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask))

Expand Down Expand Up @@ -781,20 +781,21 @@ class VariableDensityPoissonMaskFunc(BaseMaskFunc):
def __init__(
self,
accelerations: Union[List[Number], Tuple[Number, ...]],
center_scales: Union[List[float], Tuple[float, ...]],
center_fractions: Union[List[float], Tuple[float, ...]],
crop_corner: Optional[bool] = False,
max_attempts: Optional[int] = 10,
tol: Optional[float] = 0.2,
slopes: Optional[Union[List[float], Tuple[float, ...]]] = None,
**kwargs,
):
"""Inits :class:`VariableDensityPoissonMaskFunc`.
Parameters
----------
accelerations: list or tuple of positive numbers
Amount of under-sampling.
center_scales: list or tuple of floats
Must have the same lenght as `accelerations`. Amount of center fully-sampling.
center_fractions: list or tuple of floats
Must have the same length as `accelerations`. Amount of center fully-sampling.
For center_scale='r', then a centered disk area with radius equal to
:math:`R = \sqrt{{n_r}^2 + {n_c}^2} \times r` will be fully sampled, where :math:`n_r` and :math:`n_c`
denote the input shape.
Expand All @@ -810,7 +811,7 @@ def __init__(
"""
super().__init__(
accelerations=accelerations,
center_fractions=center_scales,
center_fractions=center_fractions,
uniform_range=False,
)
self.crop_corner = crop_corner
Expand Down Expand Up @@ -864,9 +865,9 @@ def mask_func(
if return_acs:
return torch.from_numpy(
self.centered_disk_mask((num_rows, num_cols), center_fraction)[np.newaxis, ..., np.newaxis]
)
).bool()
mask = self.poisson(num_rows, num_cols, center_fraction, acceleration, cython_seed)
return torch.from_numpy(mask[np.newaxis, ..., np.newaxis])
return torch.from_numpy(mask[np.newaxis, ..., np.newaxis]).bool()

def poisson(
self,
Expand Down
3 changes: 3 additions & 0 deletions direct/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,9 @@ def __getitem__(self, idx: int) -> Dict[str, Any]:
num_z = kspace.shape[1]
kspace[:, int(np.ceil(num_z * self.sampling_rate_slice_encode)) :, :] = 0.0 + 0.0 * 1j

sample["padding_left"] = 0
sample["padding_right"] = np.all(np.abs(kspace).sum(-1) == 0, axis=0).nonzero()[0][0]

# Downstream code expects the coils to be at the first axis.
sample["kspace"] = np.ascontiguousarray(kspace.transpose(2, 0, 1))

Expand Down
19 changes: 12 additions & 7 deletions direct/data/datasets_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

"""Classes holding the typed configurations for the datasets."""

from dataclasses import dataclass, field
from dataclasses import dataclass
from typing import List, Optional, Tuple

from omegaconf import MISSING
Expand All @@ -14,15 +14,20 @@

@dataclass
class TransformsConfig(BaseConfig):
crop: Optional[Tuple[int, int]] = field(default_factory=lambda: (320, 320))
crop_type: str = "uniform"
estimate_sensitivity_maps: bool = False
masking: MaskingConfig = MaskingConfig()
crop: Optional[Tuple[int, int]] = None
crop_type: Optional[str] = "uniform"
image_center_crop: bool = False
padding_eps: float = 0.001
estimate_sensitivity_maps: bool = True
estimate_body_coil_image: bool = False
sensitivity_maps_gaussian: Optional[float] = 0.7
image_center_crop: bool = True
delete_acs_mask: bool = True
delete_kspace: bool = True
image_recon_type: str = "rss"
pad_coils: Optional[int] = None
scaling_key: Optional[str] = None
masking: MaskingConfig = MaskingConfig()
scaling_key: Optional[str] = "masked_kspace"
use_seed: bool = True


@dataclass
Expand Down
2 changes: 2 additions & 0 deletions direct/data/h5_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ def __init__(
self.logger.info("Attempting to load %s filenames.", len(filenames_filter))
filenames = filenames_filter

filenames = [pathlib.Path(_) for _ in filenames]

if regex_filter:
filenames = [_ for _ in filenames if re.match(regex_filter, str(_))]

Expand Down
Loading

0 comments on commit 520edd4

Please sign in to comment.