diff --git a/direct/common/subsample_config.py b/direct/common/subsample_config.py
index 89cccbb3..2640c6d4 100644
--- a/direct/common/subsample_config.py
+++ b/direct/common/subsample_config.py
@@ -1,23 +1,23 @@
-# Copyright (c) DIRECT Contributors
-
-from __future__ import annotations
-
-from dataclasses import dataclass
-from typing import Optional
-
-from omegaconf import MISSING
-
-from direct.config.defaults import BaseConfig
-from direct.types import MaskFuncMode
-
-
-@dataclass
-class MaskingConfig(BaseConfig):
- name: str = MISSING
- accelerations: tuple[float, ...] = (5.0,)
- center_fractions: Optional[tuple[float, ...]] = (0.1,)
- uniform_range: bool = False
- mode: MaskFuncMode = MaskFuncMode.STATIC
-
- val_accelerations: tuple[float, ...] = (5.0, 10.0)
- val_center_fractions: Optional[tuple[float, ...]] = (0.1, 0.05)
+# Copyright (c) DIRECT Contributors
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Optional
+
+from omegaconf import MISSING
+
+from direct.config import BaseConfig
+from direct.types import MaskFuncMode
+
+
+@dataclass
+class MaskingConfig(BaseConfig):
+ name: str = MISSING
+ accelerations: tuple[float, ...] = (5.0,)
+ center_fractions: Optional[tuple[float, ...]] = (0.1,)
+ uniform_range: bool = False
+ mode: MaskFuncMode = MaskFuncMode.STATIC
+
+ val_accelerations: tuple[float, ...] = (5.0, 10.0)
+ val_center_fractions: Optional[tuple[float, ...]] = (0.1, 0.05)
diff --git a/direct/data/mri_transforms.py b/direct/data/mri_transforms.py
index ee915dec..68dc4487 100644
--- a/direct/data/mri_transforms.py
+++ b/direct/data/mri_transforms.py
@@ -1,2793 +1,2795 @@
-# Copyright (c) DIRECT Contributors
-
-"""The `direct.data.mri_transforms` module contains mri transformations utilized to transform or augment k-space data,
-used for DIRECT's training pipeline. They can be also used individually by importing them into python scripts."""
-
-from __future__ import annotations
-
-import contextlib
-import copy
-import functools
-import logging
-import random
-import warnings
-from typing import Any, Callable, Dict, Iterable, Optional, Sequence, Union
-
-import numpy as np
-import torch
-
-from direct.algorithms.mri_algorithms import EspiritCalibration
-from direct.data import transforms as T
-from direct.exceptions import ItemNotFoundException
-from direct.registration.elastic_deformation import RandomElasticDeformationModule
-from direct.registration.registration import DemonsFilterType, DisplacementModule, DisplacementTransformType
-from direct.ssl.ssl import (
- GaussianMaskSplitterModule,
- HalfMaskSplitterModule,
- HalfSplitType,
- MaskSplitterType,
- SSLTransformMaskPrefixes,
- UniformMaskSplitterModule,
-)
-from direct.types import DirectEnum, IntegerListOrTupleString, KspaceKey, TransformKey
-from direct.utils import DirectModule, DirectTransform
-from direct.utils.asserts import assert_complex
-
-logger = logging.getLogger(__name__)
-
-
-@contextlib.contextmanager
-def temp_seed(rng, seed):
- state = rng.get_state()
- rng.seed(seed)
- try:
- yield
- finally:
- rng.set_state(state)
-
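-# A minimal usage sketch for `temp_seed` (illustrative, not part of the pipeline):
-# draw reproducible numbers from a NumPy RandomState without disturbing its
-# global sequence.
-#
-#     rng = np.random.RandomState()
-#     with temp_seed(rng, seed=tuple(map(ord, "file.h5"))):
-#         fixed_draw = rng.rand(3)  # identical values for the same seed
-#     rng.rand()  # resumes from the state saved before the `with` block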
-
-class Compose(DirectTransform):
- """Compose several transformations together, for instance ClipAndScale and a flip.
-
- Code based on torchvision (https://github.com/pytorch/vision), but forked from there as torchvision has some
- additional dependencies.
- """
-
- def __init__(self, transforms: Iterable[Callable]) -> None:
- """Inits :class:`Compose`.
-
- Parameters
- ----------
- transforms: Iterable[Callable]
- List of transforms.
- """
- super().__init__()
- self.transforms = transforms
-
- def __call__(self, sample: dict[str, Any]) -> dict[str, Any]:
- """Calls :class:`Compose`.
-
- Parameters
- ----------
- sample: dict[str, Any]
- Dict sample.
-
- Returns
- -------
- dict[str, Any]
- Dict sample transformed by `transforms`.
- """
- for transform in self.transforms:
- sample = transform(sample)
-
- return sample
-
- def __repr__(self):
- """Representation of :class:`Compose`."""
- repr_string = self.__class__.__name__ + "("
- for transform in self.transforms:
- repr_string += "\n"
- repr_string += f" {transform},"
- repr_string = repr_string[:-1] + "\n)"
- return repr_string
-
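-# Usage sketch for `Compose`, assuming the k-space layout used throughout this
-# module, (coil, height, width, complex=2); `RandomFlip` and `RandomRotation`
-# are defined below.
-#
-#     transform = Compose([RandomFlip(p=1.0), RandomRotation(p=1.0)])
-#     sample = transform({"kspace": torch.rand(8, 320, 320, 2)})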
-
-class RandomRotation(DirectTransform):
- r"""Random :math:`k`-space rotation.
-
- Performs a random rotation with probability :math:`p`. Rotation degrees must be multiples of 90.
- """
-
- def __init__(
- self,
- degrees: Sequence[int] = (-90, 90),
- p: float = 0.5,
- keys_to_rotate: tuple[TransformKey, ...] = (TransformKey.KSPACE,),
- ) -> None:
- r"""Inits :class:`RandomRotation`.
-
- Parameters
- ----------
- degrees: sequence of ints
- Degrees of rotation. Must be a multiple of 90. If len(degrees) > 1, then a degree will be chosen at random.
- Default: (-90, 90).
- p: float
- Probability of rotation. Default: 0.5
- keys_to_rotate : tuple of TransformKeys
- Keys to rotate. Default: "kspace".
- """
- super().__init__()
-
- assert all(degree % 90 == 0 for degree in degrees)
-
- self.degrees = degrees
- self.p = p
- self.keys_to_rotate = keys_to_rotate
-
- def __call__(self, sample: dict[str, Any]) -> dict[str, Any]:
- """Calls :class:`RandomRotation`.
-
- Parameters
- ----------
- sample: dict[str, Any]
- Dict sample.
-
- Returns
- -------
- dict[str, Any]
- Sample with rotated values of `keys_to_rotate`.
- """
- if random.SystemRandom().random() <= self.p:
- degree = random.SystemRandom().choice(self.degrees)
- k = degree // 90
- for key in self.keys_to_rotate:
- if key in sample:
- value = T.view_as_complex(sample[key].clone())
- sample[key] = T.view_as_real(torch.rot90(value, k=k, dims=(-2, -1)))
-
- # If rotated by an odd multiple of 90 degrees, height and width swap, so the reconstruction size changes too
- reconstruction_size = sample.get("reconstruction_size", None)
- if reconstruction_size and (k % 2) == 1:
- sample["reconstruction_size"] = (
- reconstruction_size[:-3] + reconstruction_size[-3:-1][::-1] + reconstruction_size[-1:]
- )
-
- return sample
-
-
-class RandomFlipType(DirectEnum):
- HORIZONTAL = "horizontal"
- VERTICAL = "vertical"
- RANDOM = "random"
- BOTH = "both"
-
-
-class RandomFlip(DirectTransform):
- r"""Random k-space flip transform.
-
- Performs a random flip with probability :math:`p`. Flip can be horizontal, vertical, or a random choice of the two.
- """
-
- def __init__(
- self,
- flip: RandomFlipType = RandomFlipType.RANDOM,
- p: float = 0.5,
- keys_to_flip: tuple[TransformKey, ...] = (TransformKey.KSPACE,),
- ) -> None:
- r"""Inits :class:`RandomFlip`.
-
- Parameters
- ----------
- flip : RandomFlipType
- Horizontal, vertical, or random choice of the two. Default: RandomFlipType.RANDOM.
- p : float
- Probability of flip. Default: 0.5
- keys_to_flip : tuple of TransformKeys
- Keys to flip. Default: "kspace".
- """
- super().__init__()
-
- self.flip = flip
- self.p = p
- self.keys_to_flip = keys_to_flip
-
- def __call__(self, sample: dict[str, Any]) -> dict[str, Any]:
- """Calls :class:`RandomFlip`.
-
- Parameters
- ----------
- sample: dict[str, Any]
- Dict sample.
-
- Returns
- -------
- dict[str, Any]
- Sample with flipped values of `keys_to_flip`.
- """
- if random.SystemRandom().random() <= self.p:
- dims = (
- (-2,)
- if self.flip == "horizontal"
- else (
- (-1,)
- if self.flip == "vertical"
- else (-2, -1) if self.flip == "both" else (random.SystemRandom().choice([-2, -1]),)
- )
- )
-
- for key in self.keys_to_flip:
- if key in sample:
- value = T.view_as_complex(sample[key].clone())
- value = torch.flip(value, dims=dims)
- sample[key] = T.view_as_real(value)
-
- return sample
-
-
-class RandomReverse(DirectTransform):
- r"""Random reverse of the order along a given dimension of a PyTorch tensor."""
-
- def __init__(
- self,
- dim: int = 1,
- p: float = 0.5,
- keys_to_reverse: tuple[TransformKey, ...] = (TransformKey.KSPACE,),
- ) -> None:
- r"""Inits :class:`RandomReverse`.
-
- Parameters
- ----------
- dim : int
- Dimension along which to reverse the order. Typically, this is the time or slice dimension. Default: 1.
- p : float
- Probability of reversal. Default: 0.5
- keys_to_reverse : tuple of TransformKeys
- Keys to reverse. Default: "kspace".
- """
- super().__init__()
-
- self.dim = dim
- self.p = p
- self.keys_to_reverse = keys_to_reverse
-
- def __call__(self, sample: dict[str, Any]) -> dict[str, Any]:
- """Calls :class:`RandomReverse`.
-
- Parameters
- ----------
- sample: dict[str, Any]
- Dict sample.
-
- Returns
- -------
- dict[str, Any]
- Sample with reversed values of `keys_to_reverse`.
- """
- if random.SystemRandom().random() <= self.p:
- dim = self.dim
- for key in self.keys_to_reverse:
- if key in sample:
- tensor = sample[key].clone()
-
- if dim < 0:
- dim += tensor.dim()
-
- tensor = T.view_as_complex(tensor)
-
- index = [slice(None)] * tensor.dim()
- index[dim] = torch.arange(tensor.size(dim) - 1, -1, -1, dtype=torch.long)
-
- tensor = tensor[tuple(index)]
-
- sample[key] = T.view_as_real(tensor)
-
- return sample
-
-
-class CreateSamplingMask(DirectTransform):
- """Data Transformer for training MRI reconstruction models.
-
- Creates sampling mask.
- """
-
- def __init__(
- self,
- mask_func: Callable,
- shape: Optional[tuple[int, ...]] = None,
- use_seed: bool = True,
- return_acs: bool = False,
- ) -> None:
- """Inits :class:`CreateSamplingMask`.
-
- Parameters
- ----------
- mask_func: Callable
- A function which creates a sampling mask of the appropriate shape.
- shape: tuple, optional
- Sampling mask shape. Default: None.
- use_seed: bool
- If true, a pseudo-random number based on the filename is computed so that every slice of the volume gets
- the same mask every time. Default: True.
- return_acs: bool
- If True, it will generate an ACS mask. Default: False.
- """
- super().__init__()
- self.mask_func = mask_func
- self.shape = shape
- self.use_seed = use_seed
- self.return_acs = return_acs
-
- def __call__(self, sample: dict[str, Any]) -> dict[str, Any]:
- """Calls :class:`CreateSamplingMask`.
-
- Parameters
- ----------
- sample: dict[str, Any]
- Dict sample.
-
- Returns
- -------
- dict[str, Any]
- Sample with `sampling_mask` key.
- """
- if not self.shape:
- shape = sample["kspace"].shape[1:]
- elif any(_ is None for _ in self.shape): # Allow None as values.
- kspace_shape = list(sample["kspace"].shape[1:-1])
- shape = tuple(_ if _ else kspace_shape[idx] for idx, _ in enumerate(self.shape)) + (2,)
- else:
- shape = self.shape + (2,)
-
- seed = None if not self.use_seed else tuple(map(ord, str(sample["filename"])))
-
- sampling_mask = self.mask_func(shape=shape, seed=seed, return_acs=False)
-
- if sampling_mask.ndim == 5:
- acceleration = [
- np.prod(sampling_mask[0, _].shape) / sampling_mask[0, _].sum() for _ in range(sampling_mask.shape[1])
- ]
- sample["acceleration"] = torch.tensor(acceleration, dtype=torch.float32).unsqueeze(0)
- else:
- sample["acceleration"] = (np.prod(sampling_mask.shape) / sampling_mask.sum()).unsqueeze(0)
-
- if "padding" in sample:
- sampling_mask = T.apply_padding(sampling_mask, sample["padding"])
-
- # Shape 3D: (1, 1, height, width, 1), 2D: (1, height, width, 1)
- sample["sampling_mask"] = sampling_mask
-
- if self.return_acs:
- sample["acs_mask"] = self.mask_func(shape=shape, seed=seed, return_acs=True)
- if sampling_mask.ndim == 5:
- center_fraction = [
- sample["acs_mask"][0, _].sum() / np.prod(sample["acs_mask"][0, _].shape)
- for _ in range(sample["acs_mask"].shape[1])
- ]
- sample["center_fraction"] = torch.tensor(center_fraction, dtype=torch.float32).unsqueeze(0)
- else:
- sample["center_fraction"] = (sample["acs_mask"].sum() / np.prod(sample["acs_mask"].shape)).unsqueeze(0)
- return sample
-
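-# Sketch of the `mask_func` contract expected above: any callable accepting
-# (shape, seed, return_acs) and returning a mask broadcastable to the k-space
-# works. The toy function below is illustrative only, not a DIRECT mask function.
-#
-#     def toy_mask_func(shape, seed=None, return_acs=False):
-#         rng = np.random.RandomState(seed)
-#         mask = rng.rand(1, *shape[:-1], 1) < 0.25  # 2D: (1, height, width, 1)
-#         return torch.from_numpy(mask)
-#
-#     transform = CreateSamplingMask(toy_mask_func, use_seed=False)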
-
-class ApplyMaskModule(DirectModule):
- """Data Transformer for training MRI reconstruction models.
-
- Masks the input k-space (with key `input_kspace_key`) using a sampling mask with key `sampling_mask_key` onto
- a new masked k-space with key `target_kspace_key`.
- """
-
- def __init__(
- self,
- sampling_mask_key: str = "sampling_mask",
- input_kspace_key: KspaceKey = KspaceKey.KSPACE,
- target_kspace_key: KspaceKey = KspaceKey.MASKED_KSPACE,
- ) -> None:
- """Inits :class:`ApplyMaskModule`.
-
- Parameters
- ----------
- sampling_mask_key: str
- Default: "sampling_mask".
- input_kspace_key: KspaceKey
- Default: KspaceKey.KSPACE.
- target_kspace_key: KspaceKey
- Default KspaceKey.MASKED_KSPACE.
- """
- super().__init__()
- self.logger = logging.getLogger(type(self).__name__)
-
- self.sampling_mask_key = sampling_mask_key
- self.input_kspace_key = input_kspace_key
- self.target_kspace_key = target_kspace_key
-
- def forward(self, sample: dict[str, Any]) -> dict[str, Any]:
- """Forward pass of :class:`ApplyMaskModule`.
-
- Applies mask with key `sampling_mask_key` onto kspace `input_kspace_key`. Result is stored as a tensor with
- key `target_kspace_key`.
-
- Parameters
- ----------
- sample: dict[str, Any]
- Dict sample containing keys `sampling_mask_key` and `input_kspace_key`.
-
- Returns
- -------
- dict[str, Any]
- Sample with (new) key `target_kspace_key`.
- """
- if self.input_kspace_key not in sample:
- raise ValueError(f"Key {self.input_kspace_key} corresponding to `input_kspace_key` not found in sample.")
- input_kspace = sample[self.input_kspace_key]
-
- if self.sampling_mask_key not in sample:
- raise ValueError(f"Key {self.sampling_mask_key} corresponding to `sampling_mask_key` not found in sample.")
- sampling_mask = sample[self.sampling_mask_key]
-
- target_kspace, _ = T.apply_mask(input_kspace, sampling_mask)
- sample[self.target_kspace_key] = target_kspace
- return sample
-
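-# The masking itself reduces to a broadcast multiplication; roughly what
-# `T.apply_mask` does for the default keys (sketch, ignoring the returned mask):
-#
-#     masked_kspace = sample["kspace"] * sample["sampling_mask"] + 0.0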
-
-class CropKspace(DirectTransform):
- """Data Transformer for training MRI reconstruction models.
-
- Crops the k-space by:
- * It first projects the k-space to the image-domain via the backward operator,
- * It crops the back-projected k-space to specified shape or key,
- * It transforms the cropped back-projected k-space to the k-space domain via the forward operator.
- """
-
- def __init__(
- self,
- crop: Union[str, tuple[int, ...], list[int]],
- forward_operator: Callable = T.fft2,
- backward_operator: Callable = T.ifft2,
- image_space_center_crop: bool = False,
- random_crop_sampler_type: Optional[str] = "uniform",
- random_crop_sampler_use_seed: Optional[bool] = True,
- random_crop_sampler_gaussian_sigma: Optional[list[float]] = None,
- ) -> None:
- """Inits :class:`CropKspace`.
-
- Parameters
- ----------
- crop: tuple of ints or str
- Shape to crop the input to or a string pointing to a crop key (e.g. `reconstruction_size`).
- forward_operator: Callable
- The forward operator, e.g. some form of FFT (centered or uncentered).
- Default: :class:`direct.data.transforms.fft2`.
- backward_operator: Callable
- The backward operator, e.g. some form of inverse FFT (centered or uncentered).
- Default: :class:`direct.data.transforms.ifft2`.
- image_space_center_crop: bool
- If set, the crop in the data will be taken in the center.
- random_crop_sampler_type: Optional[str]
- If "uniform" the random cropping will be done by uniformly sampling `crop`, as opposed to `gaussian` which
- will sample from a gaussian distribution. If `image_space_center_crop` is True, then this is ignored.
- Default: "uniform".
- random_crop_sampler_use_seed: bool
- If true, a pseudo-random number based on the filename is computed so that every slice of the volume
- is cropped the same way. Default: True.
- random_crop_sampler_gaussian_sigma: Optional[list[float]]
- Standard deviation of the gaussian when `random_crop_sampler_type` is `gaussian`.
- If `image_space_center_crop` is True, then this is ignored. Default: None.
- """
- super().__init__()
- self.logger = logging.getLogger(type(self).__name__)
-
- self.image_space_center_crop = image_space_center_crop
-
- if not (isinstance(crop, (Iterable, str))):
- raise ValueError(
- f"Invalid input for `crop`. Received {crop}. Can be a list of tuple of integers or a string."
- )
- self.crop = crop
-
- if image_space_center_crop:
- self.crop_func = T.complex_center_crop
- else:
- self.crop_func = functools.partial(
- T.complex_random_crop,
- sampler=random_crop_sampler_type,
- sigma=random_crop_sampler_gaussian_sigma,
- )
- self.random_crop_sampler_use_seed = random_crop_sampler_use_seed
-
- self.forward_operator = forward_operator
- self.backward_operator = backward_operator
-
- def __call__(self, sample: dict[str, Any]) -> dict[str, Any]:
- """Calls :class:`CropKspace`.
-
- Parameters
- ----------
- sample: dict[str, Any]
- Dict sample containing key `kspace`.
-
- Returns
- -------
- dict[str, Any]
- Cropped and masked sample.
- """
-
- kspace = sample["kspace"] # shape (coil, [slice/time], height, width, complex=2)
-
- dim = self.spatial_dims.TWO_D if kspace.ndim == 4 else self.spatial_dims.THREE_D
-
- backprojected_kspace = self.backward_operator(kspace, dim=dim) # shape (coil, height, width, complex=2)
-
- if isinstance(self.crop, IntegerListOrTupleString):
- crop_shape = IntegerListOrTupleString(self.crop)
- elif isinstance(self.crop, str):
- assert self.crop in sample, f"Key {self.crop} not found in sample."
- crop_shape = sample[self.crop][:-1]
- else:
- if kspace.ndim == 5 and len(self.crop) == 2:
- crop_shape = (kspace.shape[1],) + tuple(self.crop)
- else:
- crop_shape = tuple(self.crop)
-
- cropper_args = {
- "data_list": [backprojected_kspace],
- "crop_shape": crop_shape,
- "contiguous": False,
- }
- if not self.image_space_center_crop:
- cropper_args["seed"] = (
- None if not self.random_crop_sampler_use_seed else tuple(map(ord, str(sample["filename"])))
- )
- cropped_backprojected_kspace = self.crop_func(**cropper_args)
-
- if "sampling_mask" in sample:
- sample["sampling_mask"] = T.complex_center_crop(
- sample["sampling_mask"], (1,) + tuple(crop_shape)[1:] if kspace.ndim == 5 else crop_shape
- )
- sample["acs_mask"] = T.complex_center_crop(
- sample["acs_mask"], (1,) + tuple(crop_shape)[1:] if kspace.ndim == 5 else crop_shape
- )
-
- # Compute new k-space for the cropped_backprojected_kspace
- # shape (coil, [slice/time], new_height, new_width, complex=2)
- sample["kspace"] = self.forward_operator(cropped_backprojected_kspace, dim=dim) # The cropped kspace
-
- return sample
-
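-# Shape bookkeeping sketch for `CropKspace` with a fixed center crop on 2D
-# multi-coil input:
-#
-#     crop = CropKspace(crop=(160, 160), image_space_center_crop=True)
-#     out = crop({"kspace": torch.rand(8, 320, 320, 2)})
-#     out["kspace"].shape  # torch.Size([8, 160, 160, 2])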
-
-class RescaleMode(DirectEnum):
- AREA = "area"
- BICUBIC = "bicubic"
- BILINEAR = "bilinear"
- NEAREST = "nearest"
- NEAREST_EXACT = "nearest-exact"
- TRILINEAR = "trilinear"
-
-
-class RescaleKspace(DirectTransform):
- """Rescale k-space (downsample/upsample) module.
-
- Rescales the k-space:
- * It first projects the k-space to the image-domain via the backward operator,
- * It rescales the back-projected k-space to specified shape,
- * It transforms the rescaled back-projected k-space to the k-space domain via the forward operator.
-
- Parameters
- ----------
- shape : tuple or list of ints
- Shape to rescale the input to. Must correspond to (height, width).
- forward_operator : Callable
- The forward operator, e.g. some form of FFT (centered or uncentered).
- Default: :class:`direct.data.transforms.fft2`.
- backward_operator : Callable
- The backward operator, e.g. some form of inverse FFT (centered or uncentered).
- Default: :class:`direct.data.transforms.ifft2`.
- rescale_mode : RescaleMode
- Mode to be used for rescaling. Can be RescaleMode.AREA, RescaleMode.BICUBIC, RescaleMode.BILINEAR,
- RescaleMode.NEAREST, RescaleMode.NEAREST_EXACT, or RescaleMode.TRILINEAR. Note that not all modes are
- supported for 2D or 3D data. Default: RescaleMode.NEAREST.
- kspace_key : KspaceKey
- K-space key. Default: KspaceKey.KSPACE.
- rescale_2d_if_3d : bool, optional
- If True and input k-space data is 3D, rescaling will be done only on the height and width dimensions.
- Default: None.
-
- Note
- ----
- If the input k-space data is 3D, rescaling will be done only on the height and width dimensions if
- `rescale_2d_if_3d` is set to True.
- """
-
- def __init__(
- self,
- shape: Union[tuple[int, int], list[int]],
- forward_operator: Callable = T.fft2,
- backward_operator: Callable = T.ifft2,
- rescale_mode: RescaleMode = RescaleMode.NEAREST,
- kspace_key: KspaceKey = KspaceKey.KSPACE,
- rescale_2d_if_3d: Optional[bool] = None,
- ) -> None:
- """Inits :class:`RescaleKspace`.
-
- Parameters
- ----------
- shape : tuple or list of ints
- Shape to rescale the input to. Must correspond to (height, width).
- forward_operator : Callable
- The forward operator, e.g. some form of FFT (centered or uncentered).
- Default: :class:`direct.data.transforms.fft2`.
- backward_operator : Callable
- The backward operator, e.g. some form of inverse FFT (centered or uncentered).
- Default: :class:`direct.data.transforms.ifft2`.
- rescale_mode : RescaleMode
- Mode to be used for rescaling. Can be RescaleMode.AREA, RescaleMode.BICUBIC, RescaleMode.BILINEAR,
- RescaleMode.NEAREST, RescaleMode.NEAREST_EXACT, or RescaleMode.TRILINEAR. Note that not all modes are
- supported for 2D or 3D data. Default: RescaleMode.NEAREST.
- kspace_key : KspaceKey
- K-space key. Default: KspaceKey.KSPACE.
- rescale_2d_if_3d : bool, optional
- If True and input k-space data is 3D, rescaling will be done only on the height and width dimensions,
- by combining the slice/time dimension with the batch dimension.
- Default: None.
- """
- super().__init__()
- self.logger = logging.getLogger(type(self).__name__)
-
- if len(shape) not in [2, 3]:
- raise ValueError(
- f"Shape should be a list or tuple of two integers if 2D or three integers if 3D. "
- f"Received: {shape}."
- )
- self.shape = shape
- self.forward_operator = forward_operator
- self.backward_operator = backward_operator
- self.rescale_mode = rescale_mode
- self.kspace_key = kspace_key
-
- self.rescale_2d_if_3d = rescale_2d_if_3d
- if rescale_2d_if_3d and len(shape) == 3:
- raise ValueError("Shape cannot have a length of 3 when rescale_2d_if_3d is set to True.")
-
- def __call__(self, sample: dict[str, Any]) -> dict[str, Any]:
- """Calls :class:`RescaleKspace`.
-
- Parameters
- ----------
- sample: dict[str, Any]
- Dict sample containing key `kspace`.
-
- Returns
- -------
- dict[str, Any]
- Sample with rescaled k-space.
- """
- kspace = sample[self.kspace_key] # shape (coil, [slice/time], height, width, complex=2)
-
- dim = self.spatial_dims.TWO_D if kspace.ndim == 4 else self.spatial_dims.THREE_D
-
- backprojected_kspace = self.backward_operator(kspace, dim=dim)
-
- if kspace.ndim == 5 and self.rescale_2d_if_3d:
- backprojected_kspace = backprojected_kspace.permute(1, 0, 2, 3, 4)
-
- if (kspace.ndim == 4) or (kspace.ndim == 5 and not self.rescale_2d_if_3d):
- backprojected_kspace = backprojected_kspace.unsqueeze(0)
-
- rescaled_backprojected_kspace = T.complex_image_resize(backprojected_kspace, self.shape, self.rescale_mode)
-
- if (kspace.ndim == 4) or (kspace.ndim == 5 and not self.rescale_2d_if_3d):
- rescaled_backprojected_kspace = rescaled_backprojected_kspace.squeeze(0)
-
- if kspace.ndim == 5 and self.rescale_2d_if_3d:
- rescaled_backprojected_kspace = rescaled_backprojected_kspace.permute(1, 0, 2, 3, 4)
-
- # Compute new k-space from rescaled_backprojected_kspace
- # shape (coil, [slice/time if rescale_2d_if_3d else new_slc_or_time], new_height, new_width, complex=2)
- sample[self.kspace_key] = self.forward_operator(rescaled_backprojected_kspace, dim=dim) # The rescaled kspace
-
- return sample
-
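-# Analogous shape sketch for `RescaleKspace` on 2D data:
-#
-#     rescale = RescaleKspace(shape=(128, 128))
-#     out = rescale({"kspace": torch.rand(8, 320, 320, 2)})
-#     out["kspace"].shape  # torch.Size([8, 128, 128, 2])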
-
-class PadKspace(DirectTransform):
- """Pad k-space with zeros to desired shape module.
-
- Pads the k-space by:
- * It first projects the k-space to the image-domain via the backward operator,
- * It pads the back-projected k-space to specified shape,
- * It transforms the padded back-projected k-space to the k-space domain via the forward operator.
-
- Parameters
- ----------
- pad_shape : tuple or list of ints
- Shape to zero-pad the input to. Must correspond to (height, width) or (slice/time, height, width).
- forward_operator : Callable
- The forward operator, e.g. some form of FFT (centered or uncentered).
- Default: :class:`direct.data.transforms.fft2`.
- backward_operator : Callable
- The backward operator, e.g. some form of inverse FFT (centered or uncentered).
- Default: :class:`direct.data.transforms.ifft2`.
- kspace_key : KspaceKey
- K-space key. Default: KspaceKey.KSPACE.
- """
-
- def __init__(
- self,
- pad_shape: Union[tuple[int, ...], list[int]],
- forward_operator: Callable = T.fft2,
- backward_operator: Callable = T.ifft2,
- kspace_key: KspaceKey = KspaceKey.KSPACE,
- ) -> None:
- """Inits :class:`RescaleKspace`.
-
- Parameters
- ----------
- pad_shape : tuple or list of ints
- Shape to zero-pad the input to. Must correspond to (height, width) or (slice/time, height, width).
- forward_operator : Callable
- The forward operator, e.g. some form of FFT (centered or uncentered).
- Default: :class:`direct.data.transforms.fft2`.
- backward_operator : Callable
- The backward operator, e.g. some form of inverse FFT (centered or uncentered).
- Default: :class:`direct.data.transforms.ifft2`.
- kspace_key : KspaceKey
- K-space key. Default: KspaceKey.KSPACE.
- """
- super().__init__()
- self.logger = logging.getLogger(type(self).__name__)
-
- if len(pad_shape) not in [2, 3]:
- raise ValueError(f"Shape should be a list or tuple of two or three integers. Received: {pad_shape}.")
-
- self.shape = pad_shape
- self.forward_operator = forward_operator
- self.backward_operator = backward_operator
- self.kspace_key = kspace_key
-
- def __call__(self, sample: dict[str, Any]) -> dict[str, Any]:
- """Calls :class:`PadKspace`.
-
- Parameters
- ----------
- sample: dict[str, Any]
- Dict sample containing key `kspace`.
-
- Returns
- -------
- dict[str, Any]
- Sample with padded k-space.
- """
- kspace = sample[self.kspace_key] # shape (coil, [slice or time], height, width, complex=2)
- shape = kspace.shape
-
- sample["original_size"] = shape[1:-1]
-
- dim = self.spatial_dims.TWO_D if kspace.ndim == 4 else self.spatial_dims.THREE_D
-
- backprojected_kspace = self.backward_operator(kspace, dim=dim)
- backprojected_kspace = T.view_as_complex(backprojected_kspace)
-
- padded_backprojected_kspace = T.pad_tensor(backprojected_kspace, self.shape)
- padded_backprojected_kspace = T.view_as_real(padded_backprojected_kspace)
-
- # shape (coil, [slice or time], height, width, complex=2)
- sample[self.kspace_key] = self.forward_operator(padded_backprojected_kspace, dim=dim) # The padded kspace
-
- return sample
-
-
-class ComputeZeroPadding(DirectTransform):
- r"""Computes zero padding present in multi-coil kspace input.
-
- Zero-padding is computed from multi-coil kspace with no signal contribution, i.e. its magnitude
- is very close to zero:
-
- .. math ::
-
- \text{padding} = \sum_{i=1}^{n_c} |y_i| < \frac{\epsilon}{n_x \cdot n_y}
- \sum_{j=1}^{n_x \cdot n_y} \big\{\sum_{i=1}^{n_c} |y_i|\big\}_j.
- """
-
- def __init__(
- self,
- kspace_key: KspaceKey = KspaceKey.KSPACE,
- padding_key: str = "padding",
- eps: Optional[float] = 0.0001,
- ) -> None:
- """Inits :class:`ComputeZeroPadding`.
-
- Parameters
- ----------
- kspace_key: KspaceKey
- K-space key. Default: KspaceKey.KSPACE.
- padding_key: str
- Target key. Default: "padding".
- eps: float
- Epsilon to multiply the mean signal magnitude by. If set very high, little or no padding will be detected. Default: 0.0001.
- """
- super().__init__()
- self.kspace_key = kspace_key
- self.padding_key = padding_key
- self.eps = eps
-
- def __call__(self, sample: dict[str, Any], coil_dim: int = 0) -> dict[str, Any]:
- """Updates sample with a key `padding_key` with value a binary tensor.
-
- Non-zero entries indicate samples in kspace with key `kspace_key` which have minor contribution, i.e. padding.
-
- Parameters
- ----------
- sample : dict[str, Any]
- Dict sample containing key `kspace_key`.
- coil_dim : int
- Coil dimension. Default: 0.
-
- Returns
- -------
- sample : dict[str, Any]
- Dict sample containing key `padding_key`.
- """
- if self.eps is None:
- return sample
- shape = sample[self.kspace_key].shape
-
- kspace = T.modulus(sample[self.kspace_key].clone()).sum(coil_dim)
-
- if len(shape) == 5: # Check if 3D data
- # Assumes that slice dim is 0
- kspace = kspace.sum(0)
-
- padding = (kspace < (torch.mean(kspace) * self.eps)).to(kspace.device)
-
- if len(shape) == 5:
- padding = padding.unsqueeze(0)
-
- padding = padding.unsqueeze(coil_dim).unsqueeze(-1)
- sample[self.padding_key] = padding
-
- return sample
-
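-# The formula above in tensor form (sketch for 2D data): flag k-space locations
-# whose coil-summed magnitude falls below `eps` times the mean magnitude.
-#
-#     summed = T.modulus(kspace).sum(0)       # (height, width)
-#     padding = summed < summed.mean() * eps  # binary padding mask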
-
-class ApplyZeroPadding(DirectTransform):
- """Applies zero padding present in multi-coil kspace input."""
-
- def __init__(self, kspace_key: KspaceKey = KspaceKey.KSPACE, padding_key: str = "padding") -> None:
- """Inits :class:`ApplyZeroPadding`.
-
- Parameters
- ----------
- kspace_key: KspaceKey
- K-space key. Default: KspaceKey.KSPACE.
- padding_key: str
- Target key. Default: "padding".
- """
- super().__init__()
- self.kspace_key = kspace_key
- self.padding_key = padding_key
-
- def __call__(self, sample: dict[str, Any], coil_dim: int = 0) -> dict[str, Any]:
- """Applies zero padding on `kspace_key` with value a binary tensor.
-
- Parameters
- ----------
- sample : dict[str, Any]
- Dict sample containing key `kspace_key`.
- coil_dim : int
- Coil dimension. Default: 0.
-
- Returns
- -------
- sample : dict[str, Any]
- Dict sample with zero padding applied to `kspace_key`.
- """
-
- sample[self.kspace_key] = T.apply_padding(sample[self.kspace_key], sample[self.padding_key])
-
- return sample
-
-
-class ReconstructionType(DirectEnum):
- """Reconstruction method for :class:`ComputeImage` transform."""
-
- IFFT = "ifft"
- RSS = "rss"
- COMPLEX = "complex"
- COMPLEX_MOD = "complex_mod"
- SENSE = "sense"
- SENSE_MOD = "sense_mod"
-
-
-class ComputeImageModule(DirectModule):
- """Compute Image transform."""
-
- def __init__(
- self,
- kspace_key: KspaceKey,
- target_key: str,
- backward_operator: Callable,
- type_reconstruction: ReconstructionType = ReconstructionType.RSS,
- ) -> None:
- """Inits :class:`ComputeImageModule`.
-
- Parameters
- ----------
- kspace_key: KspaceKey
- K-space key.
- target_key: str
- Target key.
- backward_operator: callable
- The backward operator, e.g. some form of inverse FFT (centered or uncentered).
- type_reconstruction: ReconstructionType
- Type of reconstruction. Can be ReconstructionType.RSS, ReconstructionType.COMPLEX,
- ReconstructionType.COMPLEX_MOD, ReconstructionType.SENSE, ReconstructionType.SENSE_MOD or
- ReconstructionType.IFFT. Default: ReconstructionType.RSS.
- """
- super().__init__()
- self.backward_operator = backward_operator
- self.kspace_key = kspace_key
- self.target_key = target_key
- self.type_reconstruction = type_reconstruction
-
- def forward(self, sample: dict[str, Any]) -> dict[str, Any]:
- """Forward pass of :class:`ComputeImageModule`.
-
- Parameters
- ----------
- sample: dict[str, Any]
- Contains key kspace_key with value a torch.Tensor of shape (coil, \*spatial_dims, complex=2).
-
- Returns
- -------
- sample: dict
- Contains key target_key with value a torch.Tensor of shape (\*spatial_dims) if `type_reconstruction` is
- ReconstructionType.RSS, ReconstructionType.COMPLEX_MOD, ReconstructionType.SENSE_MOD,
- and of shape (\*spatial_dims, complex_dim=2) otherwise.
- """
- kspace_data = sample[self.kspace_key]
- dim = self.spatial_dims.TWO_D if kspace_data.ndim == 5 else self.spatial_dims.THREE_D
- # Get complex-valued data solution
- image = self.backward_operator(kspace_data, dim=dim)
- if self.type_reconstruction == ReconstructionType.IFFT:
- sample[self.target_key] = image
- elif self.type_reconstruction in [
- ReconstructionType.COMPLEX,
- ReconstructionType.COMPLEX_MOD,
- ]:
- sample[self.target_key] = image.sum(self.coil_dim)
- elif self.type_reconstruction == ReconstructionType.RSS:
- sample[self.target_key] = T.root_sum_of_squares(image, dim=self.coil_dim)
- else:
- if "sensitivity_map" not in sample:
- raise ItemNotFoundException(
- "sensitivity map",
- "Sensitivity map is required for SENSE reconstruction.",
- )
- sample[self.target_key] = T.complex_multiplication(T.conjugate(sample["sensitivity_map"]), image).sum(
- self.coil_dim
- )
- if self.type_reconstruction in [
- ReconstructionType.COMPLEX_MOD,
- ReconstructionType.SENSE_MOD,
- ]:
- sample[self.target_key] = T.modulus(sample[self.target_key], self.complex_dim)
- return sample
-
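-# Root-sum-of-squares in explicit form (sketch): for complex-last data, RSS over
-# the coil dimension is
-#
-#     rss = ((image ** 2).sum(-1).sum(coil_dim)) ** 0.5
-#
-# matching `T.root_sum_of_squares(image, dim=coil_dim)` for complex input.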
-
-class EstimateBodyCoilImage(DirectTransform):
- """Estimates body coil image."""
-
- def __init__(self, mask_func: Callable, backward_operator: Callable, use_seed: bool = True) -> None:
- """Inits :class:`EstimateBodyCoilImage'.
-
- Parameters
- ----------
- mask_func: Callable
- A function which creates a sampling mask of the appropriate shape.
- backward_operator: callable
- The backward operator, e.g. some form of inverse FFT (centered or uncentered).
- use_seed: bool
- If true, a pseudo-random number based on the filename is computed so that every slice of the volume gets
- the same mask every time. Default: True.
- """
- super().__init__()
- self.mask_func = mask_func
- self.use_seed = use_seed
- self.backward_operator = backward_operator
-
- def __call__(self, sample: dict[str, Any], coil_dim: int = 0) -> dict[str, Any]:
- """Calls :class:`EstimateBodyCoilImage`.
-
- Parameters
- ----------
- sample: dict[str, Any]
- Contains key kspace_key with value a torch.Tensor of shape (coil, ..., complex=2).
- coil_dim: int
- Coil dimension. Default: 0.
-
- Returns
- -------
- sample: dict[str, Any]
- Contains key `body_coil_image`.
- """
- kspace = sample["kspace"]
-
- # We need to create an ACS mask based on the shape of this kspace, as it can be cropped.
- seed = None if not self.use_seed else tuple(map(ord, str(sample["filename"])))
- kspace_shape = tuple(sample["kspace"].shape[-3:])
- acs_mask = self.mask_func(shape=kspace_shape, seed=seed, return_acs=True)
-
- kspace = acs_mask * kspace + 0.0
- dim = self.spatial_dims.TWO_D if kspace.ndim == 4 else self.spatial_dims.THREE_D
- acs_image = self.backward_operator(kspace, dim=dim)
-
- sample["body_coil_image"] = T.root_sum_of_squares(acs_image, dim=coil_dim)
- return sample
-
-
-class SensitivityMapType(DirectEnum):
- ESPIRIT = "espirit"
- RSS_ESTIMATE = "rss_estimate"
- UNIT = "unit"
-
-
-class EstimateSensitivityMapModule(DirectModule):
- """Data Transformer for training MRI reconstruction models.
-
- Estimates sensitivity maps given masked k-space data using one of three methods:
-
- * Unit: unit sensitivity map in case of single coil acquisition.
- * RSS-estimate: sensitivity maps estimated by using the root-sum-of-squares of the autocalibration-signal.
- * ESPIRIT: sensitivity maps estimated with the ESPIRIT method [1]_. Note that this is currently not
- implemented for 3D data, and attempting to use it in such cases will result in a NotImplementedError.
-
- References
- ----------
-
- .. [1] Uecker M, Lai P, Murphy MJ, Virtue P, Elad M, Pauly JM, Vasanawala SS, Lustig M. ESPIRiT--an eigenvalue
- approach to autocalibrating parallel MRI: where SENSE meets GRAPPA. Magn Reson Med. 2014 Mar;71(3):990-1001.
- doi: 10.1002/mrm.24751. PMID: 23649942; PMCID: PMC4142121.
- """
-
- def __init__(
- self,
- kspace_key: KspaceKey = KspaceKey.ACS_KSPACE,
- backward_operator: Callable = T.ifft2,
- type_of_map: Optional[SensitivityMapType] = SensitivityMapType.RSS_ESTIMATE,
- gaussian_sigma: Optional[float] = None,
- espirit_threshold: Optional[float] = 0.05,
- espirit_kernel_size: Optional[int] = 6,
- espirit_crop: Optional[float] = 0.95,
- espirit_max_iters: Optional[int] = 30,
- ) -> None:
- """Inits :class:`EstimateSensitivityMapModule`.
-
- Parameters
- ----------
- kspace_key: KspaceKey
- K-space key to compute the ACS image from. If `kspace_key` is not `KspaceKey.ACS_KSPACE`,
- the ACS mask should be provided in the sample. Default: KspaceKey.ACS_KSPACE.
- backward_operator: callable
- The backward operator, e.g. some form of inverse FFT (centered or uncentered).
- type_of_map: SensitivityMapType, optional
- Type of map to estimate. Can be SensitivityMapType.RSS_ESTIMATE, SensitivityMapType.UNIT or
- SensitivityMapType.ESPIRIT. Default: SensitivityMapType.RSS_ESTIMATE.
- gaussian_sigma: float, optional
- If non-zero, the ACS image will be computed from k-space weighted by a Gaussian mask with this standard
- deviation along the width dimension. Default: None.
- espirit_threshold: float, optional
- Threshold for the calibration matrix when `type_of_map` is set to `SensitivityMapType.ESPIRIT`.
- Default: 0.05.
- espirit_kernel_size: int, optional
- Kernel size for the calibration matrix when `type_of_map` is set to `SensitivityMapType.ESPIRIT`.
- Default: 6.
- espirit_crop: float, optional
- Output eigenvalue cropping threshold when `type_of_map` is set to `SensitivityMapType.ESPIRIT`.
- Default: 0.95.
- espirit_max_iters: int, optional
- Power method iterations when `type_of_map` is set to `SensitivityMapType.ESPIRIT`. Default: 30.
- """
- super().__init__()
- self.backward_operator = backward_operator
- self.kspace_key = kspace_key
- self.type_of_map = type_of_map
-
- # RSS estimate attributes
- self.gaussian_sigma = gaussian_sigma
- # Espirit attributes
- if type_of_map == SensitivityMapType.ESPIRIT:
- self.espirit_calibrator = EspiritCalibration(
- backward_operator,
- espirit_threshold,
- espirit_kernel_size,
- espirit_crop,
- espirit_max_iters,
- kspace_key,
- )
- self.espirit_threshold = espirit_threshold
- self.espirit_kernel_size = espirit_kernel_size
- self.espirit_crop = espirit_crop
- self.espirit_max_iters = espirit_max_iters
-
- def estimate_acs_image(self, sample: dict[str, Any], width_dim: int = -2) -> torch.Tensor:
- """Estimates the autocalibration (ACS) image by sampling the k-space using the ACS mask.
-
- Parameters
- ----------
- sample: dict[str, Any]
- Sample dictionary.
- width_dim: int
- Dimension corresponding to width. Default: -2.
-
- Returns
- -------
- acs_image: torch.Tensor
- Estimate of the ACS image.
- """
- kspace_data = sample[self.kspace_key]
-
- if self.kspace_key != KspaceKey.ACS_KSPACE:
- if TransformKey.ACS_MASK not in sample:
- raise ValueError("ACS mask is required for estimating ACS image from k-space but not found.")
- kspace_data = kspace_data * sample[TransformKey.ACS_MASK]
-
- if self.gaussian_sigma == 0 or not self.gaussian_sigma:
- kspace_acs = kspace_data + 0.0 # + 0.0 removes the sign of zeros.
- else:
- gaussian_mask = torch.linspace(-1, 1, kspace_data.size(width_dim), dtype=kspace_data.dtype)
- gaussian_mask = torch.exp(-((gaussian_mask / self.gaussian_sigma) ** 2))
- gaussian_mask_shape = torch.ones(len(kspace_data.shape)).int()
- gaussian_mask_shape[width_dim] = kspace_data.size(width_dim)
- gaussian_mask = gaussian_mask.reshape(tuple(gaussian_mask_shape))
- kspace_acs = kspace_data * gaussian_mask + 0.0
-
- # Get complex-valued data solution
- # Shape (batch, coil, [slice/time], height, width, complex=2)
- dim = self.spatial_dims.TWO_D if kspace_data.ndim == 5 else self.spatial_dims.THREE_D
- acs_image = self.backward_operator(kspace_acs, dim=dim)
-
- return acs_image
-
- def forward(self, sample: dict[str, Any]) -> dict[str, Any]:
- """Calculates sensitivity maps for the input sample.
-
- Parameters
- ----------
- sample: dict[str, Any]
- Must contain key matching kspace_key with value a (complex) torch.Tensor
- of shape (coil, height, width, complex=2).
-
- Returns
- -------
- sample: dict[str, Any]
- Sample with key "sensitivity_map" with value the estimated sensitivity map.
- """
- kspace = sample[self.kspace_key] # shape (batch, coil, [slice/time], height, width, complex=2)
-
- if kspace.shape[self.coil_dim] == 1:
- warnings.warn(
- "Estimation of sensitivity map of Single-coil data. This warning will be displayed only once."
- )
- if "sensitivity_map" in sample:
- warnings.warn(
- "`sensitivity_map` is given, but will be overwritten. This warning will be displayed only once."
- )
-
- if self.type_of_map == SensitivityMapType.UNIT:
- sensitivity_map = torch.zeros(kspace.shape).float()
- # Assumes complex channel is last
- assert_complex(kspace, complex_last=True)
- sensitivity_map[..., 0] = 1.0
- # Shape (batch, coil, [slice/time], height, width, complex=2)
- sensitivity_map = sensitivity_map.to(kspace.device)
-
- elif self.type_of_map == SensitivityMapType.RSS_ESTIMATE:
- # Shape (batch, coil, [slice/time], height, width, complex=2)
- acs_image = self.estimate_acs_image(sample)
- # Shape (batch, [slice/time], height, width)
- acs_image_rss = T.root_sum_of_squares(acs_image, dim=self.coil_dim)
- # Shape (batch, 1, [slice/time], height, width, 1)
- acs_image_rss = acs_image_rss.unsqueeze(self.coil_dim).unsqueeze(self.complex_dim)
- # Shape (batch, coil, [slice/time], height, width, complex=2)
- sensitivity_map = T.safe_divide(acs_image, acs_image_rss)
- else:
- if sample[self.kspace_key].ndim > 5:
- raise NotImplementedError(
- "EstimateSensitivityMapModule is not yet implemented for "
- "Espirit sensitivity map estimation for 3D data."
- )
- sensitivity_map = self.espirit_calibrator(sample)
-
- sensitivity_map_norm = torch.sqrt(
- (sensitivity_map**2).sum(self.complex_dim).sum(self.coil_dim)
- ) # shape (batch, [slice/time], height, width)
- sensitivity_map_norm = sensitivity_map_norm.unsqueeze(self.coil_dim).unsqueeze(self.complex_dim)
-
- sample[TransformKey.SENSITIVITY_MAP] = T.safe_divide(sensitivity_map, sensitivity_map_norm)
- return sample
-
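-# The RSS estimate in brief (sketch): each coil sensitivity is its ACS image
-# divided by the coil-combined RSS image, then normalized to unit norm.
-#
-#     rss = T.root_sum_of_squares(acs_image, dim=coil_dim)
-#     sens = T.safe_divide(acs_image, rss.unsqueeze(coil_dim).unsqueeze(-1))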
-
-class AddBooleanKeysModule(DirectModule):
- """Adds keys with boolean values to sample."""
-
- def __init__(self, keys: list[str], values: list[bool]) -> None:
- """Inits :class:`AddBooleanKeysModule`.
-
- Parameters
- ----------
- keys : list[str]
- A list of keys to be added.
- values : list[bool]
- A list of values corresponding to the keys.
- """
- super().__init__()
- self.keys = keys
- self.values = values
-
- def forward(self, sample: dict[str, Any]) -> dict[str, Any]:
- """Adds boolean keys to the input sample dictionary.
-
- Parameters
- ----------
- sample : dict[str, Any]
- The input sample dictionary.
-
- Returns
- -------
- dict[str, Any]
- The modified sample with added boolean keys.
- """
- for key, value in zip(self.keys, self.values):
- sample[key] = value
-
- return sample
-
-
-class CopyKeysModule(DirectModule):
- """Copy keys to a new name from the sample if present."""
-
- def __init__(self, keys: list[str], new_keys: list[str]) -> None:
- """Inits :class:`CopyKeysModule`.
-
- Parameters
- ----------
- keys: list[str]
- Key(s) to copy.
- new_keys: list[str]
- Key(s) to create.
- """
- super().__init__()
- self.keys = keys
- self.new_keys = new_keys
-
- def forward(self, sample: dict[str, Any]) -> dict[str, Any]:
- """Forward pass of :class:`CopyKeysModule`.
-
- Parameters
- ----------
- sample: dict[str, Any]
- Dictionary to look for keys and copy them with a new name.
-
- Returns
- -------
- dict[str, Any]
- Dictionary with copied specified keys.
- """
- for key, new_key in zip(self.keys, self.new_keys):
- if key in sample:
- if isinstance(sample[key], np.ndarray):
- sample[new_key] = sample[key].copy() # Copy NumPy array
- elif isinstance(sample[key], torch.Tensor):
- sample[new_key] = sample[key].detach().clone() # Copy Torch tensor
- else:
- sample[new_key] = copy.deepcopy(sample[key])
- return sample
-
-
-class CompressCoilModule(DirectModule):
- """Compresses k-space coils using SVD."""
-
- def __init__(self, kspace_key: KspaceKey, num_coils: int) -> None:
- """Inits :class:`CompressCoilModule`.
-
- Parameters
- ----------
- kspace_key : KspaceKey
- K-space key.
- num_coils : int
- Number of coils to compress.
- """
- super().__init__()
- self.kspace_key = kspace_key
- self.num_coils = num_coils
-
- def forward(self, sample: dict[str, Any]) -> dict[str, Any]:
- """Performs coil compression to input k-space.
-
- Parameters
- ----------
- sample : dict[str, Any]
- Dict sample containing key `kspace_key`. Assumes coil dimension is first axis.
-
- Returns
- -------
- sample : dict[str, Any]
- Dict sample with `kspace_key` compressed to num_coils.
- """
- k_space = sample[self.kspace_key].clone() # shape (batch, coil, [slice/time], height, width, complex=2)
-
- if k_space.shape[1] <= self.num_coils:
- return sample
-
- ndim = k_space.ndim
-
- k_space = torch.view_as_complex(k_space)
-
- if ndim == 6: # If 3D sample reshape slice into batch dimension as sensitivities are computed 2D
- num_slice_or_time = k_space.shape[2]
- k_space = k_space.permute(0, 2, 1, 3, 4)
- k_space = k_space.reshape(k_space.shape[0] * num_slice_or_time, *k_space.shape[2:])
-
- shape = k_space.shape
-
- # Reshape the k-space data to combine spatial dimensions
- k_space_reshaped = k_space.reshape(shape[0], shape[1], -1)
-
- # Compute the coil combination matrix using Singular Value Decomposition (SVD)
- U, _, _ = torch.linalg.svd(k_space_reshaped, full_matrices=False)
-
- # Select the top ncoils_new singular vectors from the decomposition
- U_new = U[:, :, : self.num_coils]
-
- # Perform coil compression
- compressed_k_space = torch.matmul(U_new.transpose(1, 2), k_space_reshaped)
-
- # Reshape the compressed k-space back to its original shape
- compressed_k_space = compressed_k_space.reshape(shape[0], self.num_coils, *shape[2:])
-
- if ndim == 6:
- compressed_k_space = compressed_k_space.reshape(
- shape[0] // num_slice_or_time, num_slice_or_time, self.num_coils, *shape[2:]
- ).permute(0, 2, 1, 3, 4)
-
- compressed_k_space = torch.view_as_real(compressed_k_space)
- sample[self.kspace_key] = compressed_k_space # shape (batch, new coil, [slice/time], height, width, complex=2)
-
- return sample
-
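-# Toy run of `CompressCoilModule` (sketch): 16 coils compressed to 4 via SVD,
-# preserving the dominant signal subspace.
-#
-#     kspace = torch.randn(1, 16, 64, 64, 2)  # (batch, coil, height, width, complex)
-#     compress = CompressCoilModule(KspaceKey.KSPACE, num_coils=4)
-#     compress({"kspace": kspace})["kspace"].shape  # torch.Size([1, 4, 64, 64, 2])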
-
-class DeleteKeysModule(DirectModule):
- """Remove keys from the sample if present."""
-
- def __init__(self, keys: list[str]) -> None:
- """Inits :class:`DeleteKeys`.
-
- Parameters
- ----------
- keys: list[str]
- Key(s) to delete.
- """
- super().__init__()
- self.keys = keys
-
- def forward(self, sample: dict[str, Any]) -> dict[str, Any]:
- """Forward pass of :class:`DeleteKeys`.
-
- Parameters
- ----------
- sample: dict[str, Any]
- Dictionary to look for keys and remove them.
-
- Returns
- -------
- dict[str, Any]
- Dictionary with deleted specified keys.
- """
- for key in self.keys:
- if key in sample:
- del sample[key]
-
- return sample
-
-
-class IndexSelectionMode(DirectEnum):
- RANDOM = "random"
- CUSTOM = "custom"
- RANGE = "range"
-
-
-class IndexSelectionModule(DirectModule):
- """Randomly selects indices from the sample.
-
- Parameters
- ----------
- key: TransformKey
- Key to select indices from.
- mode: IndexSelectionMode
- Mode of index selection. Can be IndexSelectionMode.RANDOM, IndexSelectionMode.CUSTOM or
- IndexSelectionMode.RANGE. Default: IndexSelectionMode.CUSTOM.
- indices: list[int], optional
- List of indices to select if mode is IndexSelectionMode.CUSTOM, or range bounds if mode is
- IndexSelectionMode.RANGE. Default: None.
- num_indices: int, optional
- Number of indices to select if mode is IndexSelectionMode.RANDOM. Default: None.
- out_key: TransformKey, optional
- Key to store the selected indices. If None, the indices are stored in the same key.
- Default: None.
- index_dim: int
- Dimension along which to select indices. Default: 0.
- use_seed: bool
- If true, a pseudo-random number based on the filename is computed so that every slice of the volume gets
- the same selection every time. Default: True.
- """
-
- def __init__(
- self,
- key: TransformKey,
- mode: IndexSelectionMode = IndexSelectionMode.CUSTOM,
- indices: Optional[list[int]] = None,
- num_indices: Optional[int] = None,
- out_key: Optional[TransformKey] = None,
- index_dim: int = 0,
- use_seed: bool = True,
- ) -> None:
- """Inits :class:`IndexSelection`.
-
- Parameters
- ----------
- key: TransformKey
- Key to select indices from.
- mode: IndexSelectionMode
- Mode of index selection. Can be IndexSelectionMode.RANDOM, IndexSelectionMode.CUSTOM or
- IndexSelectionMode.RANGE. Default: IndexSelectionMode.CUSTOM.
- indices: list[int], optional
- List of indices to select if mode is IndexSelectionMode.CUSTOM or range if mode is
- IndexSelectionMode.RANGE. Default: None.
- num_indices: int
- Number of indices to select if mode is IndexSelectionMode.RANDOM. Default: None.
- out_key: TransformKey, optional
- Key to store the selected indices. If None, the indices are stored in the same key.
- Default: None.
- index_dim: int
- Dimension along which to select indices. Default: 0.
- use_seed: bool
- If true, a pseudo-random number based on the filename is computed so that every slice of the volume gets
- the same selection every time. Default: True.
- """
- super().__init__()
- self.key = key
- self.out_key = out_key if out_key is not None else key
- self.mode = mode
- self.indices = indices
- self.num_indices = num_indices
- self.index_dim = index_dim
- self.use_seed = use_seed
- self.rng = np.random.RandomState()
-
- def forward(self, sample: dict[str, Any]) -> dict[str, Any]:
- """Forward pass of :class:`IndexSelection`.
-
- Parameters
- ----------
- sample: dict[str, Any]
- Dictionary to look for key and select indices from.
-
- Returns
- -------
- dict[str, Any]
- Dictionary with randomly selected indices.
- """
- if self.key not in sample:
- return sample
-
- if self.mode == IndexSelectionMode.RANDOM:
- seed = None if not self.use_seed else tuple(map(ord, str(sample["filename"])))
- with temp_seed(self.rng, seed):
- num_to_keep = max(min(self.num_indices, sample[self.key].shape[self.index_dim]), 1)
- start = self.rng.randint(0, sample[self.key].shape[self.index_dim] - num_to_keep)
- keep_indices = torch.arange(start, start + num_to_keep, device=sample[self.key].device)
- else:
- if self.mode == IndexSelectionMode.CUSTOM:
- keep_indices = torch.tensor(
- [idx for idx in self.indices if np.abs(idx) < sample[self.key].shape[self.index_dim]],
- device=sample[self.key].device,
- )
- else:
- keep_indices = torch.arange(self.indices[0], self.indices[1], device=sample[self.key].device)
- num_to_keep = len(keep_indices)
-
- sample[self.out_key] = sample[self.key].index_select(self.index_dim, keep_indices)
-
- if num_to_keep == 1:
- sample[self.out_key] = sample[self.out_key].squeeze(self.index_dim)
-
- return sample
-
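-# Range-selection sketch: keep frames 2..5 of a dynamic k-space under a new key
-# (assumes `TransformKey.KSPACE` and `TransformKey.MASKED_KSPACE` members, as
-# used elsewhere in this module).
-#
-#     select = IndexSelectionModule(
-#         TransformKey.KSPACE,
-#         mode=IndexSelectionMode.RANGE,
-#         indices=[2, 6],
-#         out_key=TransformKey.MASKED_KSPACE,
-#         index_dim=1,
-#     )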
-
-class DropIndexModule(DirectModule):
- """Drop indices from the sample.
-
- Parameters
- ----------
- keys: list[TransformKey]
- Key(s) to drop indices from.
- index: int
- Index to drop.
- index_dim: int, list[int]
- Dimension(s) along which to drop indices. If a list, must have the same length as `keys`. Default: 1.
- store_deleted_keys: list[TransformKey], optional
- Key(s) to store the deleted indices. If None, the deleted indices are not stored. If the length does not
- match `keys`, the remaining keys are set to None. Default: None.
- """
-
- def __init__(
- self,
- keys: list[TransformKey],
- index: int,
- index_dim: int | list[int] = 1,
- store_deleted_keys: Optional[list[TransformKey]] = None,
- ) -> None:
- """Inits :class:`DropIndexModule`.
-
- Parameters
- ----------
- keys: list[TransformKey]
- Key(s) to drop indices from.
- index: int
- Index to drop.
- index_dim: int, list[int]
- Dimension(s) along which to drop indices. If a list, must have the same length as `keys`. Default: 1.
- store_deleted_keys: list[TransformKey], optional
- Key(s) to store the deleted indices. If None, the deleted indices are not stored. If the length does not
- match `keys`, the remaining keys are set to None. Default: None.
- """
- super().__init__()
- self.keys = keys
- self.index = index
- self.index_dim = [index_dim] * len(keys) if isinstance(index_dim, int) else index_dim
- self.store_deleted_keys = store_deleted_keys
- if self.store_deleted_keys is not None and len(keys) > len(self.store_deleted_keys):
- self.store_deleted_keys = store_deleted_keys + [None] * (len(keys) - len(self.store_deleted_keys))
-
- def forward(self, sample: dict[str, Any]) -> dict[str, Any]:
- """Forward pass of :class:`DropIndexModule`.
-
- Parameters
- ----------
- sample: dict[str, Any]
- Dictionary to look for key and drop indices from.
-
- Returns
- -------
- dict[str, Any]
- Dictionary with dropped index.
- """
-
- for i, key in enumerate(self.keys):
- if key not in sample:
- continue
- # This might be helpful, for instance, in case a single mask is used for all time frames
- if sample[key].shape[self.index_dim[i]] == 1:
- continue
- if self.store_deleted_keys is not None:
- deleted_key = self.store_deleted_keys[i]
- if deleted_key:
- sample[deleted_key] = sample[key].index_select(
- self.index_dim[i],
- torch.tensor(
- [idx for idx in range(sample[key].shape[self.index_dim[i]]) if idx == self.index],
- device=sample[key].device,
- ),
- )
- sample[key] = sample[key].index_select(
- self.index_dim[i],
- torch.tensor(
- [idx for idx in range(sample[key].shape[self.index_dim[i]]) if idx != self.index],
- device=sample[key].device,
- ),
- )
-
- return sample
-
-
-class SqueezeKeyModule(DirectModule):
- """Squeeze the specified key(s) in the sample.
-
- Parameters
- ----------
- keys: list[TransformKey]
- Key(s) to squeeze.
- dim: int
- Dimension to squeeze.
- """
-
- def __init__(self, keys: list[TransformKey], dim: int) -> None:
- """Inits :class:`SqueezeKeyModule`.
-
- Parameters
- ----------
- keys: list[TransformKey]
- Key(s) to squeeze.
- dim: int
- Dimension to squeeze.
- """
- super().__init__()
- self.keys = keys
- self.dim = dim
-
- def forward(self, sample: dict[str, Any]) -> dict[str, Any]:
- """Forward pass of :class:`SqueezeKeyModule`.
-
- Parameters
- ----------
- sample: dict[str, Any]
- Dictionary to look for keys to squeeze.
-
- Returns
- -------
- dict[str, Any]
- Dictionary with squeezed specified keys.
- """
- for key in self.keys:
- if key in sample:
- sample[key] = sample[key].squeeze(self.dim)
- return sample
-
-
-class RenameKeysModule(DirectModule):
- """Rename keys from the sample if present."""
-
- def __init__(self, old_keys: list[str], new_keys: list[str]) -> None:
- """Inits :class:`RenameKeys`.
-
- Parameters
- ----------
- old_keys: list[str]
- Key(s) to rename.
- new_keys: list[str]
- Key(s) to replace old keys.
- """
- super().__init__()
- self.old_keys = old_keys
- self.new_keys = new_keys
-
- def forward(self, sample: dict[str, Any]) -> dict[str, Any]:
- """Forward pass of :class:`RenameKeys`.
-
- Parameters
- ----------
- sample: dict[str, Any]
- Dictionary to look for keys and rename them.
-
- Returns
- -------
- dict[str, Any]
- Dictionary with renamed specified keys.
- """
- for old_key, new_key in zip(self.old_keys, self.new_keys):
- if old_key in sample:
- sample[new_key] = sample.pop(old_key)
-
- return sample
-
-
-class PadCoilDimensionModule(DirectModule):
- """Pad the coils by zeros to a given number of coils.
-
- Useful if you want to collate volumes with different coil dimension.
- """
-
- def __init__(
- self,
- pad_coils: Optional[int] = None,
- key: str = "masked_kspace",
- coil_dim: int = 1,
- ) -> None:
- """Inits :class:`PadCoilDimensionModule`.
-
- Parameters
- ----------
- pad_coils: int, optional
- Number of coils to pad to. Default: None.
- key: str
- Key to pad in sample. Default: "masked_kspace".
- coil_dim: int
- Coil dimension along which the pad will be done. Default: 1.
- """
- super().__init__()
- self.num_coils = pad_coils
- self.key = key
- self.coil_dim = coil_dim
-
- def forward(self, sample: dict[str, Any]) -> dict[str, Any]:
- """Forward pass of :class:`PadCoilDimensionModule`.
-
- Parameters
- ----------
- sample: dict[str, Any]
- Dictionary with key `self.key`.
-
- Returns
- -------
- sample: dict[str, Any]
- Dictionary with padded coils of sample[self.key] if self.num_coils is not None.
- """
- if not self.num_coils:
- return sample
-
- if self.key not in sample:
- return sample
-
- data = sample[self.key]
-
- curr_num_coils = data.shape[self.coil_dim]
- if curr_num_coils > self.num_coils:
- raise ValueError(
- f"Tried to pad to {self.num_coils} coils, but already have {curr_num_coils} for "
- f"{sample['filename']}."
- )
- if curr_num_coils == self.num_coils:
- return sample
-
- shape = data.shape
- num_coils = shape[self.coil_dim]
- padding_data_shape = list(shape)
- padding_data_shape[self.coil_dim] = max(self.num_coils - num_coils, 0)
- zeros = torch.zeros(padding_data_shape, dtype=data.dtype, device=data.device)
- sample[self.key] = torch.cat([zeros, data], dim=self.coil_dim)
-
- return sample
-
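-# The padding step in isolation (sketch, assuming coil_dim == 0; `num_pad` is a
-# hypothetical count of missing coils):
-#
-#     zeros = torch.zeros(num_pad, *data.shape[1:], dtype=data.dtype)
-#     padded = torch.cat([zeros, data], dim=0)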
-
-class ComputeScalingFactorModule(DirectModule):
- """Calculates scaling factor.
-
- Scaling factor is for the input data based on either to the percentile or to the maximum of `normalize_key`.
- """
-
- def __init__(
- self,
- normalize_key: Union[None, TransformKey] = TransformKey.MASKED_KSPACE,
- percentile: Union[None, float] = 0.99,
- scaling_factor_key: TransformKey = TransformKey.SCALING_FACTOR,
- ) -> None:
- """Inits :class:`ComputeScalingFactorModule`.
-
- Parameters
- ----------
- normalize_key : TransformKey or None
- Key name to compute the data for. If the maximum has to be computed on the ACS, ensure the reconstruction
- on the ACS is available (typically `body_coil_image`). Default: "masked_kspace".
- percentile : float or None
- Rescale data with the given percentile. If None, the division is done by the maximum. Default: 0.99.
- scaling_factor_key : TransformKey
- Name of how the scaling factor will be stored. Default: "scaling_factor".
- """
- super().__init__()
- self.normalize_key = normalize_key
- self.percentile = percentile
- self.scaling_factor_key = scaling_factor_key
-
- def forward(self, sample: dict[str, Any]) -> dict[str, Any]:
- """Forward pass of :class:`ComputeScalingFactorModule`.
-
- Parameters
- ----------
- sample: dict[str, Any]
- Sample with key `normalize_key` to compute scaling_factor.
-
- Returns
- -------
- sample: dict[str, Any]
- Sample with key `scaling_factor_key`.
- """
- if self.normalize_key == "scaling_factor": # This is a real-valued given number
- scaling_factor = sample["scaling_factor"]
- elif not self.normalize_key:
- kspace = sample["masked_kspace"]
- scaling_factor = torch.tensor([1.0] * kspace.size(0), device=kspace.device, dtype=kspace.dtype)
- else:
- data = sample[self.normalize_key]
- scaling_factor: Union[list, torch.Tensor] = []
- # Compute the maximum and scale the input
- if self.percentile:
- for _ in range(data.size(0)):
- # Used in case the k-space is padded (e.g. for batches)
- non_padded_coil_data = data[_][data[_].sum(dim=tuple(range(1, data[_].ndim))).bool()]
- tview = -1.0 * T.modulus(non_padded_coil_data).view(-1)
- s, _ = torch.kthvalue(tview, int((1 - self.percentile) * tview.size()[0]) + 1)
- scaling_factor += [-1.0 * s]
- scaling_factor = torch.tensor(scaling_factor, dtype=data.dtype, device=data.device)
- else:
- scaling_factor = T.modulus(data).amax(dim=list(range(data.ndim))[1:-1])
- sample[self.scaling_factor_key] = scaling_factor
- return sample
-
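-
-# Editorial sketch (not part of the original DIRECT source): computing a per-item
-# scaling factor from the 99th percentile of the k-space modulus. Shapes are
-# illustrative assumptions.
-def _example_compute_scaling_factor() -> None:
-    """Compute one scaling factor per batch entry of a masked k-space."""
-    module = ComputeScalingFactorModule(percentile=0.99)
-    sample = {"masked_kspace": torch.rand(2, 8, 64, 64, 2)}  # (batch, coil, height, width, complex)
-    sample = module(sample)
-    # One factor per batch entry, later broadcast against the data by NormalizeModule.
-    assert sample["scaling_factor"].shape == (2,)
-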
-
-class NormalizeModule(DirectModule):
- """Normalize the input data."""
-
- def __init__(
- self,
- scaling_factor_key: TransformKey = TransformKey.SCALING_FACTOR,
- keys_to_normalize: Optional[list[TransformKey]] = None,
- ) -> None:
- """Inits :class:`NormalizeModule`.
-
- Parameters
- ----------
- scaling_factor_key : TransformKey
- Name of scaling factor key expected in sample. Default: 'scaling_factor'.
- keys_to_normalize : list of TransformKeys, optional
- Keys to normalize. If None, a default list of scalable keys such as "masked_kspace", "target"
- and "kspace" is used. Default: None.
- """
- super().__init__()
- self.scaling_factor_key = scaling_factor_key
-
- self.keys_to_normalize = (
- [
- "masked_kspace",
- "target",
- "kspace",
- "body_coil_image", # sensitivity_map does not require normalization.
- "initial_image",
- "initial_kspace",
- ]
- if keys_to_normalize is None
- else keys_to_normalize
- )
-
- def forward(self, sample: dict[str, Any]) -> dict[str, Any]:
- """Forward pass of :class:`NormalizeModule`.
-
- Parameters
- ----------
- sample: dict[str, Any]
- Sample to normalize.
-
- Returns
- -------
- sample: dict[str, Any]
- Sample with normalized values if their respective key is in `keys_to_normalize` and key
- `scaling_factor_key` exists in sample.
- """
- scaling_factor = sample.get(self.scaling_factor_key, None)
- # Normalize data
- if scaling_factor is not None:
- for key in sample.keys():
- if key not in self.keys_to_normalize:
- continue
- sample[key] = T.safe_divide(
- sample[key],
- scaling_factor.reshape(-1, *[1 for _ in range(sample[key].ndim - 1)]),
- )
-
- sample["scaling_diff"] = 0.0
- return sample
-
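-
-# Editorial sketch (not part of the original DIRECT source): dividing scalable keys
-# by a precomputed scaling factor. The tensors below are illustrative assumptions.
-def _example_normalize() -> None:
-    """Normalize a masked k-space with a per-batch-entry scaling factor."""
-    sample = {
-        "masked_kspace": torch.rand(2, 8, 64, 64, 2),
-        "scaling_factor": torch.tensor([2.0, 4.0]),  # e.g. from ComputeScalingFactorModule
-    }
-    sample = NormalizeModule()(sample)  # divides "masked_kspace" per batch entry
-    assert sample["scaling_diff"] == 0.0
-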
-
-class WhitenDataModule(DirectModule):
- """Whitens complex data Module."""
-
- def __init__(self, epsilon: float = 1e-10, key: str = "complex_image") -> None:
- """Inits :class:`WhitenDataModule`.
-
- Parameters
- ----------
- epsilon: float
- Default: 1e-10.
- key: str
- Key to whiten. Default: "complex_image".
- """
- super().__init__()
- self.epsilon = epsilon
- self.key = key
-
- def complex_whiten(self, complex_image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- """Whiten complex image.
-
- Parameters
- ----------
- complex_image: torch.Tensor
- Complex image tensor to whiten.
-
- Returns
- -------
- mean, std, whitened_image: tuple[torch.Tensor, torch.Tensor, torch.Tensor]
- """
- # From: https://github.com/facebookresearch/fastMRI
- # blob/da1528585061dfbe2e91ebbe99a5d4841a5c3f43/banding_removal/fastmri/data/transforms.py#L464 # noqa
- real = complex_image[..., 0]
- imag = complex_image[..., 1]
-
- # Center around mean.
- mean = complex_image.mean()
- centered_complex_image = complex_image - mean
-
- # Determine covariance between real and imaginary.
- n_elements = real.nelement()
- real_real = (real.mul(real).sum() - real.mean().mul(real.mean())) / n_elements
- real_imag = (real.mul(imag).sum() - real.mean().mul(imag.mean())) / n_elements
- imag_imag = (imag.mul(imag).sum() - imag.mean().mul(imag.mean())) / n_elements
- eig_input = torch.Tensor([[real_real, real_imag], [real_imag, imag_imag]])
-
- # Remove correlation by rotating around covariance eigenvectors.
- eig_values, eig_vecs = torch.linalg.eig(eig_input)
-
- # Scale by eigenvalues for unit variance.
- std = (eig_values.real + self.epsilon).sqrt()
- whitened_image = torch.matmul(centered_complex_image, eig_vecs.real) / std
-
- return mean, std, whitened_image
-
- def forward(self, sample: dict[str, Any]) -> dict[str, Any]:
- """Forward pass of :class:`WhitenDataModule`.
-
- Parameters
- ----------
- sample: dict[str, Any]
- Sample with key `key`.
-
- Returns
- -------
- sample: dict[str, Any]
- Sample with value of `key` whitened.
- """
- _, _, whitened_image = self.complex_whiten(sample[self.key])
- sample[self.key] = whitened_image
- return sample
-
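-
-# Editorial sketch (not part of the original DIRECT source): whitening removes the
-# mean and decorrelates the real and imaginary channels of a complex image.
-def _example_whiten_data() -> None:
-    """Whiten a complex-valued image stored under the default key."""
-    module = WhitenDataModule(key="complex_image")
-    sample = {"complex_image": torch.rand(64, 64, 2)}  # (height, width, complex)
-    sample = module(sample)  # rotated onto covariance eigenvectors, scaled to unit variance
-    assert sample["complex_image"].shape == (64, 64, 2)
-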
-
-class AddTargetAcceleration(DirectTransform):
- """This will replace the acceleration factor in the sample with the target acceleration factor."""
-
- def __init__(self, target_acceleration: float):
- super().__init__()
- self.target_acceleration = target_acceleration
-
- def __call__(self, sample: dict[str, Any]):
- sample["acceleration"][:] = self.target_acceleration
- return sample
-
-
-class ModuleWrapper:
- class SubWrapper:
- def __init__(self, transform: Callable, toggle_dims: bool) -> None:
- self.toggle_dims = toggle_dims
- self._transform = transform
-
- def __call__(self, sample: dict[str, Any]) -> dict[str, Any]:
- if self.toggle_dims:
- for k, v in sample.items():
- if isinstance(v, (torch.Tensor, np.ndarray)):
- sample[k] = v[None]
- else:
- sample[k] = [v]
-
- sample = self._transform.forward(sample)
-
- if self.toggle_dims:
- for k, v in sample.items():
- if isinstance(v, (torch.Tensor, np.ndarray)):
- sample[k] = v.squeeze(0)
- else:
- sample[k] = v[0]
-
- return sample
-
- def __repr__(self) -> str:
- return self._transform.__repr__()
-
- def __init__(self, module: Callable, toggle_dims: bool) -> None:
- self._module = module
- self.toggle_dims = toggle_dims
-
- def __call__(self, *args, **kwargs) -> SubWrapper:
- return self.SubWrapper(self._module(*args, **kwargs), toggle_dims=self.toggle_dims)
-
-
-ApplyMask = ModuleWrapper(ApplyMaskModule, toggle_dims=False)
-ComputeImage = ModuleWrapper(ComputeImageModule, toggle_dims=True)
-EstimateSensitivityMap = ModuleWrapper(EstimateSensitivityMapModule, toggle_dims=True)
-CopyKeys = ModuleWrapper(CopyKeysModule, toggle_dims=False)
-DeleteKeys = ModuleWrapper(DeleteKeysModule, toggle_dims=False)
-RenameKeys = ModuleWrapper(RenameKeysModule, toggle_dims=False)
-IndexSelection = ModuleWrapper(IndexSelectionModule, toggle_dims=False)
-DropIndex = ModuleWrapper(DropIndexModule, toggle_dims=False)
-SqueezeKey = ModuleWrapper(SqueezeKeyModule, toggle_dims=False)
-CompressCoil = ModuleWrapper(CompressCoilModule, toggle_dims=True)
-PadCoilDimension = ModuleWrapper(PadCoilDimensionModule, toggle_dims=True)
-ComputeScalingFactor = ModuleWrapper(ComputeScalingFactorModule, toggle_dims=True)
-Normalize = ModuleWrapper(NormalizeModule, toggle_dims=False)
-WhitenData = ModuleWrapper(WhitenDataModule, toggle_dims=False)
-GaussianMaskSplitter = ModuleWrapper(GaussianMaskSplitterModule, toggle_dims=True)
-UniformMaskSplitter = ModuleWrapper(UniformMaskSplitterModule, toggle_dims=True)
-Displacement = ModuleWrapper(DisplacementModule, toggle_dims=True)
-RandomElasticDeformation = ModuleWrapper(RandomElasticDeformationModule, toggle_dims=True)
-
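-
-# Editorial sketch (not part of the original DIRECT source): the wrappers above
-# adapt batch-oriented modules to single samples. With toggle_dims=True a
-# temporary batch axis is added before forward() and squeezed away afterwards.
-def _example_module_wrapper() -> None:
-    """Apply the wrapped (per-sample) PadCoilDimension to an unbatched sample."""
-    pad = PadCoilDimension(pad_coils=16, key="masked_kspace")
-    sample = {"masked_kspace": torch.rand(12, 320, 320, 2), "filename": "example.h5"}
-    sample = pad(sample)  # internally shaped (1, 12, 320, 320, 2) during forward()
-    assert sample["masked_kspace"].shape == (16, 320, 320, 2)
-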
-
-class ToTensor(DirectTransform):
- """Transforms all np.array-like values in sample to torch.tensors."""
-
- def __call__(self, sample: dict[str, Any]) -> dict[str, Any]:
- """Calls :class:`ToTensor`.
-
- Parameters
- ----------
- sample: dict[str, Any]
- Contains key 'kspace' with value a np.array of shape (coil, height, width) (2D)
- or (coil, slice, height, width) (3D)
-
- Returns
- -------
- sample: dict[str, Any]
- Contains key 'kspace' with value a torch.Tensor of shape (coil, height, width) (2D)
- or (coil, slice, height, width) (3D)
- """
-
- ndim = sample["kspace"].ndim - 1
-
- if ndim not in [2, 3]:
- raise ValueError(f"Can only cast 2D and 3D data (+coil) to tensor. Got {ndim}.")
-
- # Shape: 2D: (coil, height, width, complex=2), 3D: (coil, slice, height, width, complex=2)
- sample["kspace"] = T.to_tensor(sample["kspace"]).float()
- # Sensitivity maps are not necessarily available in the dataset.
- if "initial_kspace" in sample:
- # Shape: 2D: (coil, height, width, complex=2), 3D: (coil, slice, height, width, complex=2)
- sample["initial_kspace"] = T.to_tensor(sample["initial_kspace"]).float()
- if "initial_image" in sample:
- # Shape: 2D: (height, width), 3D: (slice, height, width)
- sample["initial_image"] = T.to_tensor(sample["initial_image"]).float()
-
- if "sensitivity_map" in sample:
- # Shape: 2D: (coil, height, width, complex=2), 3D: (coil, slice, height, width, complex=2)
- sample["sensitivity_map"] = T.to_tensor(sample["sensitivity_map"]).float()
- if "target" in sample:
- # Shape: 2D: (coil, height, width), 3D: (coil, slice, height, width)
- sample["target"] = torch.from_numpy(sample["target"]).float()
- if "sampling_mask" in sample:
- sample["sampling_mask"] = torch.from_numpy(sample["sampling_mask"]).bool()
- if "acs_mask" in sample:
- sample["acs_mask"] = torch.from_numpy(sample["acs_mask"]).bool()
- if "scaling_factor" in sample:
- sample["scaling_factor"] = torch.tensor(sample["scaling_factor"]).float()
- if "loglikelihood_scaling" in sample:
- # Shape: (coil, )
- sample["loglikelihood_scaling"] = torch.from_numpy(np.asarray(sample["loglikelihood_scaling"])).float()
-
- return sample
-
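-
-# Editorial sketch (not part of the original DIRECT source): casting a complex
-# numpy k-space to a float tensor with a trailing complex=2 axis.
-def _example_to_tensor() -> None:
-    """Cast a 2D multi-coil numpy k-space sample to torch tensors."""
-    kspace = np.random.rand(8, 64, 64) + 1j * np.random.rand(8, 64, 64)  # (coil, height, width)
-    sample = ToTensor()({"kspace": kspace})
-    assert sample["kspace"].shape == (8, 64, 64, 2)
-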
-
-class RegistrationSimulateReferenceType(DirectEnum):
- FROM_KEY = "from_key"
- ELASTIC = "elastic"
-
-
-# pylint: disable=too-many-arguments
-def build_supervised_mri_transforms(
- forward_operator: Callable,
- backward_operator: Callable,
- mask_func: Optional[Callable],
- target_acceleration: Optional[float] = None,
- crop: Optional[Union[tuple[int, int], str]] = None,
- crop_type: Optional[str] = "uniform",
- rescale: Optional[Union[tuple[int, int], list[int]]] = None,
- rescale_mode: Optional[RescaleMode] = RescaleMode.NEAREST,
- rescale_2d_if_3d: Optional[bool] = False,
- pad: Optional[Union[tuple[int, int], list[int]]] = None,
- image_center_crop: bool = True,
- random_rotation_degrees: Optional[Sequence[int]] = (-90, 90),
- random_rotation_probability: float = 0.0,
- random_flip_type: Optional[RandomFlipType] = RandomFlipType.RANDOM,
- random_flip_probability: float = 0.0,
- random_reverse_probability: float = 0.0,
- padding_eps: float = 0.0001,
- estimate_body_coil_image: bool = False,
- estimate_sensitivity_maps: bool = True,
- sensitivity_maps_type: SensitivityMapType = SensitivityMapType.RSS_ESTIMATE,
- sensitivity_maps_gaussian: Optional[float] = None,
- sensitivity_maps_espirit_threshold: Optional[float] = 0.05,
- sensitivity_maps_espirit_kernel_size: Optional[int] = 6,
- sensitivity_maps_espirit_crop: Optional[float] = 0.95,
- sensitivity_maps_espirit_max_iters: Optional[int] = 30,
- use_acs_as_mask: bool = False,
- delete_acs: bool = True,
- delete_kspace: bool = True,
- image_recon_type: ReconstructionType = ReconstructionType.RSS,
- compress_coils: Optional[int] = None,
- pad_coils: Optional[int] = None,
- scaling_key: TransformKey = TransformKey.MASKED_KSPACE,
- scale_percentile: Optional[float] = 0.99,
- registration: bool = False,
- registration_simulate_reference: Optional[RegistrationSimulateReferenceType] = None,
- registration_simulate_elastic_sigma: float = 3.0,
- registration_simulate_elastic_points: int = 3,
- registration_simulate_elastic_rotate: float = 0.0,
- registration_simulate_elastic_zoom: float = 0.0,
- registration_estimate_displacement: bool = True,
- registration_simulate_reference_from_key_index: int = 0,
- registration_moving_key: TransformKey = TransformKey.TARGET,
- demons_filter_type: DemonsFilterType = DemonsFilterType.SYMMETRIC_FORCES,
- demons_num_iterations: int = 100,
- demons_smooth_displacement_field: bool = True,
- demons_standard_deviations: float = 1.5,
- demons_intensity_difference_threshold: Optional[float] = None,
- demons_maximum_rms_error: Optional[float] = None,
- use_seed: bool = True,
-) -> DirectTransform:
- r"""Builds supervised MRI transforms.
-
- More specifically, the following transformations are applied:
-
- * Converts input to (complex-valued) tensor.
- * Applies k-space (center) crop if requested.
- * Applies k-space rescaling if requested.
- * Applies k-space padding if requested.
- * Applies random augmentations (rotation, flip, reverse) if requested.
- * Adds a sampling mask if `mask_func` is defined.
- * Compresses the coil dimension if requested.
- * Pads the coil dimension if requested.
- * Adds coil sensitivities and/or the body coil image.
- * Masks the fully sampled k-space, if there is a mask function or a mask in the sample.
- * Computes a scaling factor based on the masked k-space and normalizes data.
- * Computes a target (image).
- * Deletes the acs mask and the fully sampled k-space if requested.
-
- Parameters
- ----------
- forward_operator : Callable
- The forward operator, e.g. some form of FFT (centered or uncentered).
- backward_operator : Callable
- The backward operator, e.g. some form of inverse FFT (centered or uncentered).
- mask_func : Callable or None
- A function which creates a sampling mask of the appropriate shape.
- target_acceleration : float, optional
- Target acceleration factor. Default: None.
- crop : tuple[int, int] or str, optional
- If not None, this will transform the "kspace" to the image domain, crop it, and transform it back.
- If a tuple of integers is given then it will crop the backprojected kspace to that size. If
- "reconstruction_size" is given, then it will crop the backprojected kspace according to it, but
- a key "reconstruction_size" must be present in the sample. Default: None.
- crop_type : Optional[str]
- Type of cropping, either "gaussian" or "uniform". This will be ignored if `crop` is None. Default: "uniform".
- rescale : tuple or list, optional
- If not None, this will transform the "kspace" to the image domain, rescale it, and transform it back.
- Must correspond to (height, width). It is not recommended to use this in combination with `crop`.
- Default: None.
- rescale_mode : RescaleMode
- Mode to be used for rescaling. Can be RescaleMode.AREA, RescaleMode.BICUBIC, RescaleMode.BILINEAR,
- RescaleMode.NEAREST, RescaleMode.NEAREST_EXACT, or RescaleMode.TRILINEAR. Note that not all modes are
- supported for 2D or 3D data. Default: RescaleMode.NEAREST.
- rescale_2d_if_3d : bool, optional
- If True and k-space data is 3D, rescaling will be done only on the height
- and width dimensions, by combining the slice/time dimension with the batch dimension.
- This is ignored if `rescale` is None. Default: False.
- pad : tuple or list, optional
- If not None, this will zero-pad the "kspace" to the given size. Must correspond to (height, width)
- or (slice/time, height, width). Default: None.
- image_center_crop : bool
- If True the backprojected kspace will be cropped around the center, otherwise randomly.
- This will be ignored if `crop` is None. Default: True.
- random_rotation_degrees : Sequence[int], optional
- Default: (-90, 90).
- random_rotation_probability : float, optional
- If greater than 0.0, random rotations will be applied of `random_rotation_degrees` degrees, with probability
- `random_rotation_probability`. Default: 0.0.
- random_flip_type : RandomFlipType, optional
- Default: RandomFlipType.RANDOM.
- random_flip_probability : float, optional
- If greater than 0.0, random rotation of `random_flip_type` type, with probability `random_flip_probability`.
- Default: 0.0.
- random_reverse_probability : float
- If greater than 0.0, will perform random reversion along the time or slice dimension (2) with probability
- `random_reverse_probability`. Default: 0.0.
- padding_eps: float
- Padding epsilon. Default: 0.0001.
- estimate_body_coil_image : bool
- Estimate body coil image. Default: False.
- estimate_sensitivity_maps : bool
- Estimate sensitivity maps using the acs region. Default: True.
- sensitivity_maps_type : SensitivityMapType
- Can be SensitivityMapType.RSS_ESTIMATE, SensitivityMapType.UNIT or SensitivityMapType.ESPIRIT.
- Will be ignored if `estimate_sensitivity_maps` is False. Default: SensitivityMapType.RSS_ESTIMATE.
- sensitivity_maps_gaussian : float
- Optional sigma for gaussian weighting of sensitivity map.
- sensitivity_maps_espirit_threshold : float, optional
- Threshold for the calibration matrix when `type_of_map` is set to `SensitivityMapType.ESPIRIT`.
- Default: 0.05.
- sensitivity_maps_espirit_kernel_size : int, optional
- Kernel size for the calibration matrix when `type_of_map` is set to `SensitivityMapType.ESPIRIT`. Default: 6.
- sensitivity_maps_espirit_crop : float, optional
- Output eigenvalue cropping threshold when `type_of_map` is set to `SensitivityMapType.ESPIRIT`. Default: 0.95.
- sensitivity_maps_espirit_max_iters : int, optional
- Power method iterations when `type_of_map` is set to `SensitivityMapType.ESPIRIT`. Default: 30.
- use_acs_as_mask : bool
- If True, will use the acs region as the mask. Default: False.
- delete_acs : bool
- If True will delete key `acs_mask`. Default: True.
- delete_kspace : bool
- If True will delete key `kspace` (fully sampled k-space). Default: True.
- image_recon_type : ReconstructionType
- Type to reconstruct target image. Default: ReconstructionType.RSS.
- compress_coils : int, optional
- Number of coils to compress input k-space. It is not recommended to be used in combination with `pad_coils`.
- Default: None.
- pad_coils : int, optional
- Number of coils to pad data to. Default: None.
- scaling_key : TransformKey
- Key in sample to scale scalable items in sample. Default: TransformKey.MASKED_KSPACE.
- scale_percentile : float, optional
- Data will be rescaled with the given percentile. If None, the division is done by the maximum. Default: 0.99.
- registration : bool
- If True, will compute a displacement field between the target and the moving image. Default: False.
- registration_simulate_reference : RegistrationSimulateReferenceType
- If not None, will simulate a reference image for displacement field computation. Otherwise, this expects a key
- in the sample. Can be RegistrationSimulateReferenceType.FROM_KEY or RegistrationSimulateReferenceType.ELASTIC.
- Default: None.
- registration_simulate_elastic_sigma : float
- Standard deviation for the elastic simulation. Default: 3.0.
- registration_simulate_elastic_points : int
- Number of points for the elastic simulation. Default: 3.
- registration_simulate_elastic_rotate : float
- Rotation for the elastic simulation. Default: 0.0.
- registration_simulate_elastic_zoom : float
- Zoom for the elastic simulation. Default: 0.0.
- registration_estimate_displacement : bool
- If True, will estimate the displacement field between the target and the moving image using the
- demons algorithm. Default: True.
- registration_simulate_reference_from_key_index : int
- Index to drop from the key to simulate the reference image. Default: 0.
- registration_moving_key : TransformKey
- Key in sample to compute displacement field from. Default: TransformKey.TARGET.
- demons_filter_type : DemonsFilterType
- Type of filter to apply to the displacement field. Default: DemonsFilterType.SYMMETRIC_FORCES.
- demons_num_iterations : int
- Number of iterations for the demons algorithm. Default: 100.
- demons_smooth_displacement_field : bool
- If True, will smooth the displacement field. Default: True.
- demons_standard_deviations : float
- Standard deviation for the smoothing of the displacement field. Default: 1.5.
- demons_intensity_difference_threshold : float, optional
- Intensity difference threshold for the demons algorithm. Default: None.
- demons_maximum_rms_error : float, optional
- Maximum RMS error for the demons algorithm. Default: None.
- use_seed : bool
- If True, a pseudo-random number based on the filename is computed so that every slice of the volume gets
- the same mask every time. Default: True.
-
- Returns
- -------
- DirectTransform
- An MRI transformation object.
- """
- mri_transforms: list[Callable] = [ToTensor()]
- if crop:
- mri_transforms += [
- CropKspace(
- crop=crop,
- forward_operator=forward_operator,
- backward_operator=backward_operator,
- image_space_center_crop=image_center_crop,
- random_crop_sampler_type=crop_type,
- random_crop_sampler_use_seed=use_seed,
- )
- ]
- if rescale:
- mri_transforms += [
- RescaleKspace(
- shape=rescale,
- forward_operator=forward_operator,
- backward_operator=backward_operator,
- rescale_mode=rescale_mode,
- rescale_2d_if_3d=rescale_2d_if_3d,
- kspace_key=KspaceKey.KSPACE,
- )
- ]
- if pad:
- mri_transforms += [
- PadKspace(
- pad_shape=pad,
- forward_operator=forward_operator,
- backward_operator=backward_operator,
- kspace_key=KspaceKey.KSPACE,
- )
- ]
- if random_rotation_probability > 0.0:
- mri_transforms += [
- RandomRotation(
- degrees=random_rotation_degrees,
- p=random_rotation_probability,
- keys_to_rotate=(TransformKey.KSPACE, TransformKey.SENSITIVITY_MAP),
- )
- ]
- if random_flip_probability > 0.0:
- mri_transforms += [
- RandomFlip(
- flip=random_flip_type,
- p=random_flip_probability,
- keys_to_flip=(TransformKey.KSPACE, TransformKey.SENSITIVITY_MAP),
- )
- ]
- if random_reverse_probability > 0.0:
- mri_transforms += [
- RandomReverse(
- p=random_reverse_probability,
- keys_to_reverse=(TransformKey.KSPACE, TransformKey.SENSITIVITY_MAP),
- )
- ]
- if padding_eps > 0.0:
- mri_transforms += [
- ComputeZeroPadding(KspaceKey.KSPACE, "padding", padding_eps),
- ApplyZeroPadding(KspaceKey.KSPACE, "padding"),
- ]
- if mask_func:
- mri_transforms += [
- CreateSamplingMask(
- mask_func,
- shape=(None if (isinstance(crop, str)) else crop),
- use_seed=use_seed,
- return_acs=estimate_sensitivity_maps,
- ),
- ]
- if use_acs_as_mask:
- mri_transforms += [CopyKeys(keys=[TransformKey.ACS_MASK], new_keys=[TransformKey.SAMPLING_MASK])]
- if target_acceleration:
- mri_transforms += [AddTargetAcceleration(target_acceleration)]
- if compress_coils:
- mri_transforms += [CompressCoil(num_coils=compress_coils, kspace_key=KspaceKey.KSPACE)]
- if pad_coils:
- mri_transforms += [PadCoilDimension(pad_coils=pad_coils, key=KspaceKey.KSPACE)]
-
- if estimate_body_coil_image and mask_func is not None:
- mri_transforms.append(EstimateBodyCoilImage(mask_func, backward_operator=backward_operator, use_seed=use_seed))
- mri_transforms += [
- ApplyMask(
- sampling_mask_key=TransformKey.ACS_MASK,
- input_kspace_key=KspaceKey.KSPACE,
- target_kspace_key=KspaceKey.ACS_KSPACE,
- ),
- ]
- if estimate_sensitivity_maps:
- mri_transforms += [
- EstimateSensitivityMap(
- kspace_key=KspaceKey.ACS_KSPACE,
- backward_operator=backward_operator,
- type_of_map=sensitivity_maps_type,
- gaussian_sigma=sensitivity_maps_gaussian,
- espirit_threshold=sensitivity_maps_espirit_threshold,
- espirit_kernel_size=sensitivity_maps_espirit_kernel_size,
- espirit_crop=sensitivity_maps_espirit_crop,
- espirit_max_iters=sensitivity_maps_espirit_max_iters,
- )
- ]
- mri_transforms += [
- ApplyMask(
- sampling_mask_key=TransformKey.SAMPLING_MASK,
- input_kspace_key=KspaceKey.KSPACE,
- target_kspace_key=KspaceKey.MASKED_KSPACE,
- ),
- ]
- if registration:
- if registration_simulate_reference is not None:
- mri_transforms += [
- DropIndex(
- keys=[
- TransformKey.KSPACE,
- TransformKey.ACS_KSPACE,
- TransformKey.MASKED_KSPACE,
- TransformKey.ACS_MASK,
- TransformKey.SAMPLING_MASK,
- TransformKey.PADDING,
- TransformKey.SENSITIVITY_MAP,
- ],
- index=registration_simulate_reference_from_key_index,
- index_dim=1,
- store_deleted_keys=[TransformKey.REFERENCE_KSPACE],
- )
- ]
- mri_transforms += [
- ComputeScalingFactor(
- normalize_key=scaling_key, percentile=scale_percentile, scaling_factor_key=TransformKey.SCALING_FACTOR
- ),
- Normalize(
- scaling_factor_key=TransformKey.SCALING_FACTOR,
- keys_to_normalize=[
- KspaceKey.ACS_KSPACE,
- KspaceKey.KSPACE,
- KspaceKey.MASKED_KSPACE,
- KspaceKey.REFERENCE_KSPACE,
- ], # Normalize only the k-space keys that may be present in the sample at this point
- ),
- ]
- mri_transforms += [
- ComputeImage(
- kspace_key=KspaceKey.KSPACE,
- target_key=TransformKey.TARGET,
- backward_operator=backward_operator,
- type_reconstruction=image_recon_type,
- )
- ]
- if registration:
- if registration_simulate_reference is not None:
- mri_transforms += [
- ComputeImage(
- kspace_key=KspaceKey.REFERENCE_KSPACE,
- target_key=TransformKey.REFERENCE_IMAGE,
- backward_operator=backward_operator,
- type_reconstruction=image_recon_type,
- ),
- SqueezeKey(keys=[TransformKey.REFERENCE_IMAGE], dim=0),
- ]
- if registration_simulate_reference == RegistrationSimulateReferenceType.ELASTIC:
- mri_transforms += [
- RandomElasticDeformation(
- image_key=TransformKey.REFERENCE_IMAGE,
- target_key=TransformKey.REFERENCE_IMAGE,
- use_seed=use_seed,
- sigma=registration_simulate_elastic_sigma,
- points=registration_simulate_elastic_points,
- rotate=registration_simulate_elastic_rotate,
- zoom=registration_simulate_elastic_zoom,
- )
- ]
- if registration_estimate_displacement:
- mri_transforms += [
- Displacement(
- transform_type=DisplacementTransformType.MULTISCALE_DEMONS,
- demons_filter_type=demons_filter_type,
- demons_num_iterations=demons_num_iterations,
- demons_smooth_displacement_field=demons_smooth_displacement_field,
- demons_standard_deviations=demons_standard_deviations,
- demons_intensity_difference_threshold=demons_intensity_difference_threshold,
- demons_maximum_rms_error=demons_maximum_rms_error,
- reference_image_key=TransformKey.REFERENCE_IMAGE,
- moving_image_key=registration_moving_key,
- )
- ]
- if delete_acs:
- mri_transforms += [DeleteKeys(keys=[TransformKey.ACS_MASK, KspaceKey.ACS_KSPACE])]
- if delete_kspace:
- mri_transforms += [DeleteKeys(keys=[KspaceKey.KSPACE])]
-
- return Compose(mri_transforms)
-
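-
-# Editorial sketch (not part of the original DIRECT source): building and applying
-# a supervised pipeline. FastMRIRandomMaskFunc from direct.common.subsample is an
-# assumption here; any mask function with the same calling convention works.
-def _example_build_supervised_transforms() -> None:
-    """Build a supervised transform pipeline and run one 2D sample through it."""
-    from direct.common.subsample import FastMRIRandomMaskFunc
-
-    transforms = build_supervised_mri_transforms(
-        forward_operator=T.fft2,
-        backward_operator=T.ifft2,
-        mask_func=FastMRIRandomMaskFunc(accelerations=[4], center_fractions=[0.08]),
-    )
-    sample = {
-        "kspace": np.random.rand(8, 64, 64) + 1j * np.random.rand(8, 64, 64),
-        "filename": "example.h5",
-    }
-    sample = transforms(sample)  # adds "masked_kspace", "target", "scaling_factor", ...
-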
-
-class TransformsType(DirectEnum):
- SUPERVISED = "supervised"
- SSL_SSDU = "ssl_ssdu"
-
-
-# pylint: disable=too-many-arguments
-def build_mri_transforms(
- forward_operator: Callable,
- backward_operator: Callable,
- mask_func: Optional[Callable],
- target_acceleration: Optional[float] = None,
- crop: Optional[Union[tuple[int, int], str]] = None,
- crop_type: Optional[str] = "uniform",
- rescale: Optional[Union[tuple[int, int], list[int]]] = None,
- rescale_mode: Optional[RescaleMode] = RescaleMode.NEAREST,
- rescale_2d_if_3d: Optional[bool] = False,
- pad: Optional[Union[tuple[int, int], list[int]]] = None,
- image_center_crop: bool = True,
- random_rotation_degrees: Optional[Sequence[int]] = (-90, 90),
- random_rotation_probability: float = 0.0,
- random_flip_type: Optional[RandomFlipType] = RandomFlipType.RANDOM,
- random_flip_probability: float = 0.0,
- random_reverse_probability: float = 0.0,
- padding_eps: float = 0.0001,
- estimate_body_coil_image: bool = False,
- estimate_sensitivity_maps: bool = True,
- sensitivity_maps_type: SensitivityMapType = SensitivityMapType.RSS_ESTIMATE,
- sensitivity_maps_gaussian: Optional[float] = None,
- sensitivity_maps_espirit_threshold: Optional[float] = 0.05,
- sensitivity_maps_espirit_kernel_size: Optional[int] = 6,
- sensitivity_maps_espirit_crop: Optional[float] = 0.95,
- sensitivity_maps_espirit_max_iters: Optional[int] = 30,
- use_acs_as_mask: bool = False,
- delete_acs: bool = True,
- delete_kspace: bool = True,
- image_recon_type: ReconstructionType = ReconstructionType.RSS,
- compress_coils: Optional[int] = None,
- pad_coils: Optional[int] = None,
- scaling_key: TransformKey = TransformKey.MASKED_KSPACE,
- scale_percentile: Optional[float] = 0.99,
- registration: bool = False,
- registration_simulate_reference: Optional[RegistrationSimulateReferenceType] = None,
- registration_simulate_elastic_sigma: float = 3.0,
- registration_simulate_elastic_points: int = 3,
- registration_simulate_elastic_rotate: float = 0.0,
- registration_simulate_elastic_zoom: float = 0.0,
- registration_estimate_displacement: bool = True,
- registration_simulate_reference_from_key_index: int = 0,
- registration_moving_key: TransformKey = TransformKey.TARGET,
- demons_filter_type: DemonsFilterType = DemonsFilterType.SYMMETRIC_FORCES,
- demons_num_iterations: int = 100,
- demons_smooth_displacement_field: bool = True,
- demons_standard_deviations: float = 1.5,
- demons_intensity_difference_threshold: Optional[float] = None,
- demons_maximum_rms_error: Optional[float] = None,
- use_seed: bool = True,
- transforms_type: Optional[TransformsType] = TransformsType.SUPERVISED,
- mask_split_ratio: Union[float, list[float], tuple[float, ...]] = 0.4,
- mask_split_acs_region: Union[list[int], tuple[int, int]] = (0, 0),
- mask_split_keep_acs: Optional[bool] = False,
- mask_split_type: MaskSplitterType = MaskSplitterType.GAUSSIAN,
- mask_split_gaussian_std: float = 3.0,
- mask_split_half_direction: HalfSplitType = HalfSplitType.VERTICAL,
-) -> DirectTransform:
- r"""Build transforms for MRI.
-
- More specifically, the following transformations are applied:
-
- * Converts input to (complex-valued) tensor.
- * Applies k-space (center) crop if requested.
- * Applies k-space rescaling if requested.
- * Applies k-space padding if requested.
- * Applies random augmentations (rotation, flip, reverse) if requested.
- * Adds a sampling mask if `mask_func` is defined.
- * Compresses the coil dimension if requested.
- * Pads the coil dimension if requested.
- * Adds coil sensitivities and/or the body coil image.
- * Masks the fully sampled k-space, if there is a mask function or a mask in the sample.
- * Computes a scaling factor based on the masked k-space and normalizes data.
- * Computes a target (image).
- * Deletes the acs mask and the fully sampled k-space if requested.
- * Splits the mask if requested for self-supervised learning.
-
- Parameters
- ----------
- forward_operator : Callable
- The forward operator, e.g. some form of FFT (centered or uncentered).
- backward_operator : Callable
- The backward operator, e.g. some form of inverse FFT (centered or uncentered).
- mask_func : Callable or None
- A function which creates a sampling mask of the appropriate shape.
- target_acceleration : float, optional
- Target acceleration factor. Default: None.
- crop : tuple[int, int] or str, optional
- If not None, this will transform the "kspace" to the image domain, crop it, and transform it back.
- If a tuple of integers is given then it will crop the backprojected kspace to that size. If
- "reconstruction_size" is given, then it will crop the backprojected kspace according to it, but
- a key "reconstruction_size" must be present in the sample. Default: None.
- crop_type : Optional[str]
- Type of cropping, either "gaussian" or "uniform". This will be ignored if `crop` is None. Default: "uniform".
- rescale : tuple or list, optional
- If not None, this will transform the "kspace" to the image domain, rescale it, and transform it back.
- Must correspond to (height, width). It is not recommended to use this in combination with `crop`.
- Default: None.
- rescale_mode : RescaleMode
- Mode to be used for rescaling. Can be RescaleMode.AREA, RescaleMode.BICUBIC, RescaleMode.BILINEAR,
- RescaleMode.NEAREST, RescaleMode.NEAREST_EXACT, or RescaleMode.TRILINEAR. Note that not all modes are
- supported for 2D or 3D data. Default: RescaleMode.NEAREST.
- rescale_2d_if_3d : bool, optional
- If True and k-space data is 3D, rescaling will be done only on the height
- and width dimensions, by combining the slice/time dimension with the batch dimension.
- This is ignored if `rescale` is None. Default: False.
- pad : tuple or list, optional
- If not None, this will zero-pad the "kspace" to the given size. Must correspond to (height, width)
- or (slice/time, height, width). Default: None.
- image_center_crop : bool
- If True the backprojected kspace will be cropped around the center, otherwise randomly.
- This will be ignored if `crop` is None. Default: True.
- random_rotation_degrees : Sequence[int], optional
- Default: (-90, 90).
- random_rotation_probability : float, optional
- If greater than 0.0, random rotations will be applied of `random_rotation_degrees` degrees, with probability
- `random_rotation_probability`. Default: 0.0.
- random_flip_type : RandomFlipType, optional
- Default: RandomFlipType.RANDOM.
- random_flip_probability : float, optional
- If greater than 0.0, random rotation of `random_flip_type` type, with probability `random_flip_probability`.
- Default: 0.0.
- random_reverse_probability : float
- If greater than 0.0, will perform random reversion along the time or slice dimension (2) with probability
- `random_reverse_probability`. Default: 0.0.
- padding_eps: float
- Padding epsilon. Default: 0.0001.
- estimate_body_coil_image : bool
- Estimate body coil image. Default: False.
- estimate_sensitivity_maps : bool
- Estimate sensitivity maps using the acs region. Default: True.
- sensitivity_maps_type : SensitivityMapType
- Can be SensitivityMapType.RSS_ESTIMATE, SensitivityMapType.UNIT or SensitivityMapType.ESPIRIT.
- Will be ignored if `estimate_sensitivity_maps` is False. Default: SensitivityMapType.RSS_ESTIMATE.
- sensitivity_maps_gaussian : float
- Optional sigma for gaussian weighting of sensitivity map.
- sensitivity_maps_espirit_threshold : float, optional
- Threshold for the calibration matrix when `type_of_map` is set to `SensitivityMapType.ESPIRIT`.
- Default: 0.05.
- sensitivity_maps_espirit_kernel_size : int, optional
- Kernel size for the calibration matrix when `type_of_map` is set to `SensitivityMapType.ESPIRIT`. Default: 6.
- sensitivity_maps_espirit_crop : float, optional
- Output eigenvalue cropping threshold when `type_of_map` is set to `SensitivityMapType.ESPIRIT`. Default: 0.95.
- sensitivity_maps_espirit_max_iters : int, optional
- Power method iterations when `type_of_map` is set to `SensitivityMapType.ESPIRIT`. Default: 30.
- use_acs_as_mask : bool
- If True, will use the acs region as the mask. Default: False.
- delete_acs : bool
- If True will delete key `acs_mask`. Default: True.
- delete_kspace : bool
- If True will delete key `kspace` (fully sampled k-space). Default: True.
- image_recon_type : ReconstructionType
- Type to reconstruct target image. Default: ReconstructionType.RSS.
- compress_coils : int, optional
- Number of coils to compress input k-space. It is not recommended to be used in combination with `pad_coils`.
- Default: None.
- pad_coils : int, optional
- Number of coils to pad data to. Default: None.
- scaling_key : TransformKey
- Key in sample to scale scalable items in sample. Default: TransformKey.MASKED_KSPACE.
- scale_percentile : float, optional
- Data will be rescaled with the given percentile. If None, the division is done by the maximum. Default: 0.99.
- registration : bool
- If True, will compute a displacement field between the target and the moving image. Default: False.
- registration_simulate_reference : RegistrationSimulateReferenceType
- If not None, will simulate a reference image for displacement field computation. Otherwise, this expects a key
- in the sample. Can be RegistrationSimulateReferenceType.FROM_KEY or RegistrationSimulateReferenceType.ELASTIC.
- Default: None.
- registration_simulate_elastic_sigma : float
- Standard deviation for the elastic simulation. Default: 3.0.
- registration_simulate_elastic_points : int
- Number of points for the elastic simulation. Default: 3.
- registration_simulate_elastic_rotate : float
- Rotation for the elastic simulation. Default: 0.0.
- registration_simulate_elastic_zoom : float
- Zoom for the elastic simulation. Default: 0.0.
- registration_estimate_displacement : bool
- If True, will estimate the displacement field between the target and the moving image using the
- demons algorithm. Default: True.
- registration_simulate_reference_from_key_index : int
- Index to drop from the key to simulate the reference image. Default: 0.
- registration_moving_key : TransformKey
- Key in sample to compute displacement field from. Default: TransformKey.TARGET.
- demons_filter_type : DemonsFilterType
- Type of filter to apply to the displacement field. Default: DemonsFilterType.SYMMETRIC_FORCES.
- demons_num_iterations : int
- Number of iterations for the demons algorithm. Default: 100.
- demons_smooth_displacement_field : bool
- If True, will smooth the displacement field. Default: True.
- demons_standard_deviations : float
- Standard deviation for the smoothing of the displacement field. Default: 1.5.
- demons_intensity_difference_threshold : float, optional
- Intensity difference threshold for the demons algorithm. Default: None.
- demons_maximum_rms_error : float, optional
- Maximum RMS error for the demons algorithm. Default: None.
- use_seed : bool
- If True, a pseudo-random number based on the filename is computed so that every slice of the volume gets
- the same mask every time. Default: True.
- transforms_type : TransformsType, optional
- Can be `TransformsType.SUPERVISED` for supervised learning transforms or `TransformsType.SSL_SSDU` for
- self-supervised learning transforms. Default: `TransformsType.SUPERVISED`.
- mask_split_ratio : Union[float, list[float], tuple[float, ...]]
- The ratio(s) of the sampling mask splitting. If `transforms_type` is TransformsType.SUPERVISED, this is ignored.
- mask_split_acs_region : Union[list[int], tuple[int, int]]
- A rectangle for the acs region that will be used in the input mask. This applies only if `transforms_type` is
- set to TransformsType.SSL_SSDU. Default: (0, 0).
- mask_split_keep_acs : Optional[bool]
- If True, acs region according to the "acs_mask" of the sample will be used in both mask splits.
- This applies only if `transforms_type` is set to TransformsType.SSL_SSDU. Default: False.
- mask_split_type : MaskSplitterType
- How the sampling mask will be split. Can be MaskSplitterType.UNIFORM, MaskSplitterType.GAUSSIAN, or
- MaskSplitterType.HALF. This applies only if `transforms_type` is set to TransformsType.SSL_SSDU.
- Default: MaskSplitterType.GAUSSIAN.
- mask_split_gaussian_std : float
- Standard deviation of gaussian mask splitting. This applies only if `transforms_type` is
- set to TransformsType.SSL_SSDU. Ignored if `mask_split_type` is not set to MaskSplitterType.GAUSSIAN.
- Default: 3.0.
- mask_split_half_direction : HalfSplitType
- Split type if `mask_split_type` is `MaskSplitterType.HALF`. Can be `HalfSplitType.VERTICAL`,
- `HalfSplitType.HORIZONTAL`, `HalfSplitType.DIAGONAL_LEFT` or `HalfSplitType.DIAGONAL_RIGHT`.
- This applies only if `transforms_type` is set to `TransformsType.SSL_SSDU`. Ignored if `mask_split_type` is not
- set to `MaskSplitterType.HALF`. Default: `HalfSplitType.VERTICAL`.
-
- Returns
- -------
- DirectTransform
- An MRI transformation object.
- """
- logger = logging.getLogger(build_mri_transforms.__name__)
- logger.info("Creating %s MRI transforms.", transforms_type)
-
- if crop and rescale:
- logger.warning(
- "Rescale and crop are both given. Rescale will be applied after cropping. This is not recommended."
- )
-
- if compress_coils and pad_coils:
- logger.warning(
- "Compress coils and pad coils are both given. Compress coils will be applied before padding. "
- "This is not recommended."
- )
-
- mri_transforms = build_supervised_mri_transforms(
- forward_operator=forward_operator,
- backward_operator=backward_operator,
- mask_func=mask_func,
- target_acceleration=target_acceleration,
- crop=crop,
- crop_type=crop_type,
- rescale=rescale,
- rescale_mode=rescale_mode,
- rescale_2d_if_3d=rescale_2d_if_3d,
- pad=pad,
- image_center_crop=image_center_crop,
- random_rotation_degrees=random_rotation_degrees,
- random_rotation_probability=random_rotation_probability,
- random_flip_type=random_flip_type,
- random_flip_probability=random_flip_probability,
- random_reverse_probability=random_reverse_probability,
- padding_eps=padding_eps,
- estimate_sensitivity_maps=estimate_sensitivity_maps,
- sensitivity_maps_type=sensitivity_maps_type,
- estimate_body_coil_image=estimate_body_coil_image,
- sensitivity_maps_gaussian=sensitivity_maps_gaussian,
- sensitivity_maps_espirit_threshold=sensitivity_maps_espirit_threshold,
- sensitivity_maps_espirit_kernel_size=sensitivity_maps_espirit_kernel_size,
- sensitivity_maps_espirit_crop=sensitivity_maps_espirit_crop,
- sensitivity_maps_espirit_max_iters=sensitivity_maps_espirit_max_iters,
- use_acs_as_mask=use_acs_as_mask,
- delete_acs=delete_acs if transforms_type == TransformsType.SUPERVISED else False,
- delete_kspace=delete_kspace if transforms_type == TransformsType.SUPERVISED else False,
- image_recon_type=image_recon_type,
- compress_coils=compress_coils,
- pad_coils=pad_coils,
- scaling_key=scaling_key,
- scale_percentile=scale_percentile,
- registration=registration,
- registration_simulate_reference=registration_simulate_reference,
- registration_simulate_elastic_sigma=registration_simulate_elastic_sigma,
- registration_simulate_elastic_points=registration_simulate_elastic_points,
- registration_simulate_elastic_rotate=registration_simulate_elastic_rotate,
- registration_simulate_elastic_zoom=registration_simulate_elastic_zoom,
- registration_estimate_displacement=registration_estimate_displacement,
- registration_simulate_reference_from_key_index=registration_simulate_reference_from_key_index,
- registration_moving_key=registration_moving_key,
- demons_filter_type=demons_filter_type,
- demons_num_iterations=demons_num_iterations,
- demons_smooth_displacement_field=demons_smooth_displacement_field,
- demons_standard_deviations=demons_standard_deviations,
- demons_intensity_difference_threshold=demons_intensity_difference_threshold,
- demons_maximum_rms_error=demons_maximum_rms_error,
- use_seed=use_seed,
- ).transforms
-
- mri_transforms += [AddBooleanKeysModule(["is_ssl"], [transforms_type != TransformsType.SUPERVISED])]
-
- if transforms_type == TransformsType.SUPERVISED:
- return Compose(mri_transforms)
-
- mask_splitter_kwargs = {
- "ratio": mask_split_ratio,
- "acs_region": mask_split_acs_region,
- "keep_acs": mask_split_keep_acs,
- "use_seed": use_seed,
- "kspace_key": KspaceKey.MASKED_KSPACE,
- }
- mri_transforms += [
- (
- GaussianMaskSplitter(**mask_splitter_kwargs, std_scale=mask_split_gaussian_std)
- if mask_split_type == MaskSplitterType.GAUSSIAN
- else (
- UniformMaskSplitter(**mask_splitter_kwargs)
- if mask_split_type == MaskSplitterType.UNIFORM
- else HalfMaskSplitterModule(
- **{k: v for k, v in mask_splitter_kwargs.items() if k != "ratio"},
- direction=mask_split_half_direction,
- )
- )
- ),
- DeleteKeys([TransformKey.ACS_MASK]),
- ]
-
- mri_transforms += [
- RenameKeys(
- [
- SSLTransformMaskPrefixes.INPUT_ + TransformKey.MASKED_KSPACE,
- SSLTransformMaskPrefixes.TARGET_ + TransformKey.MASKED_KSPACE,
- ],
- ["input_kspace", "kspace"],
- ),
- DeleteKeys(["masked_kspace", "sampling_mask"]),
- ] # Rename keys for SSL engine
-
- mri_transforms += [
- ComputeImage(
- kspace_key=KspaceKey.KSPACE,
- target_key=TransformKey.TARGET,
- backward_operator=backward_operator,
- type_reconstruction=image_recon_type,
- )
- ]
-
- return Compose(mri_transforms)
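-
-
-# Editorial sketch (not part of the original DIRECT source): requesting SSL (SSDU)
-# transforms splits the sampling mask into disjoint input/target masks. As above,
-# FastMRIRandomMaskFunc is an assumed stand-in for any valid mask function.
-def _example_build_ssl_transforms() -> None:
-    """Build an SSDU-style self-supervised transform pipeline."""
-    from direct.common.subsample import FastMRIRandomMaskFunc
-
-    transforms = build_mri_transforms(
-        forward_operator=T.fft2,
-        backward_operator=T.ifft2,
-        mask_func=FastMRIRandomMaskFunc(accelerations=[4], center_fractions=[0.08]),
-        transforms_type=TransformsType.SSL_SSDU,
-        mask_split_type=MaskSplitterType.GAUSSIAN,
-        mask_split_ratio=0.4,
-    )
-    # Samples produced by this pipeline expose "input_kspace" (network input),
-    # "kspace" (loss target) and "is_ssl" set to True.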
+# Copyright (c) DIRECT Contributors
+
+"""The `direct.data.mri_transforms` module contains mri transformations utilized to transform or augment k-space data,
+used for DIRECT's training pipeline. They can be also used individually by importing them into python scripts."""
+
+from __future__ import annotations
+
+import contextlib
+import copy
+import functools
+import logging
+import random
+import warnings
+from typing import Any, Callable, Dict, Iterable, Optional, Sequence, Union
+
+import numpy as np
+import torch
+
+from direct.algorithms.mri_algorithms import EspiritCalibration
+from direct.data import transforms as T
+from direct.exceptions import ItemNotFoundException
+from direct.registration.elastic_deformation import RandomElasticDeformationModule
+from direct.registration.registration import DemonsFilterType, DisplacementModule, DisplacementTransformType
+from direct.ssl.ssl import (
+ GaussianMaskSplitterModule,
+ HalfMaskSplitterModule,
+ HalfSplitType,
+ MaskSplitterType,
+ SSLTransformMaskPrefixes,
+ UniformMaskSplitterModule,
+)
+from direct.types import DirectEnum, IntegerListOrTupleString, KspaceKey, TransformKey
+from direct.utils import DirectModule, DirectTransform
+from direct.utils.asserts import assert_complex
+
+logger = logging.getLogger(__name__)
+
+
+@contextlib.contextmanager
+def temp_seed(rng, seed):
+ state = rng.get_state()
+ rng.seed(seed)
+ try:
+ yield
+ finally:
+ rng.set_state(state)
+
+
+class Compose(DirectTransform):
+ """Compose several transformations together, for instance ClipAndScale and a flip.
+
+ Code based on torchvision: https://github.com/pytorch/vision, but got forked from there as torchvision has some
+ additional dependencies.
+ """
+
+ def __init__(self, transforms: Iterable[Callable]) -> None:
+ """Inits :class:`Compose`.
+
+ Parameters
+ ----------
+ transforms: Iterable[Callable]
+ List of transforms.
+ """
+ super().__init__()
+ self.transforms = transforms
+
+ def __call__(self, sample: dict[str, Any]) -> dict[str, Any]:
+ """Calls :class:`Compose`.
+
+ Parameters
+ ----------
+ sample: dict[str, Any]
+ Dict sample.
+
+ Returns
+ -------
+ dict[str, Any]
+ Dict sample transformed by `transforms`.
+ """
+ for transform in self.transforms:
+ sample = transform(sample)
+
+ return sample
+
+ def __repr__(self):
+ """Representation of :class:`Compose`."""
+ repr_string = self.__class__.__name__ + "("
+ for transform in self.transforms:
+ repr_string += "\n"
+ repr_string += f" {transform},"
+ repr_string = repr_string[:-1] + "\n)"
+ return repr_string
+
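+
+# Editorial sketch (not part of the original DIRECT source): chaining transforms
+# defined below into a single callable. Shapes are illustrative assumptions.
+def _example_compose() -> None:
+    """Apply a flip followed by a rotation through one Compose pipeline."""
+    pipeline = Compose([RandomFlip(p=1.0), RandomRotation(p=1.0)])
+    sample = {"kspace": torch.rand(8, 64, 64, 2)}  # (coil, height, width, complex)
+    sample = pipeline(sample)  # transforms are applied in list order
+    assert sample["kspace"].shape == (8, 64, 64, 2)
+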
+
+class RandomRotation(DirectTransform):
+ r"""Random :math:`k`-space rotation.
+
+ Performs a random rotation with probability :math:`p`. Rotation degrees must be multiples of 90.
+ """
+
+ def __init__(
+ self,
+ degrees: Sequence[int] = (-90, 90),
+ p: float = 0.5,
+ keys_to_rotate: tuple[TransformKey, ...] = (TransformKey.KSPACE,),
+ ) -> None:
+ r"""Inits :class:`RandomRotation`.
+
+ Parameters
+ ----------
+ degrees: sequence of ints
+ Degrees of rotation. Must be a multiple of 90. If len(degrees) > 1, then a degree will be chosen at random.
+ Default: (-90, 90).
+ p: float
+ Probability of rotation. Default: 0.5
+ keys_to_rotate : tuple of TransformKeys
+ Keys to rotate. Default: "kspace".
+ """
+ super().__init__()
+
+ assert all(degree % 90 == 0 for degree in degrees)
+
+ self.degrees = degrees
+ self.p = p
+ self.keys_to_rotate = keys_to_rotate
+
+ def __call__(self, sample: dict[str, Any]) -> dict[str, Any]:
+ """Calls :class:`RandomRotation`.
+
+ Parameters
+ ----------
+ sample: dict[str, Any]
+ Dict sample.
+
+ Returns
+ -------
+ dict[str, Any]
+ Sample with rotated values of `keys_to_rotate`.
+ """
+ if random.SystemRandom().random() <= self.p:
+ degree = random.SystemRandom().choice(self.degrees)
+ k = degree // 90
+ for key in self.keys_to_rotate:
+ if key in sample:
+ value = T.view_as_complex(sample[key].clone())
+ sample[key] = T.view_as_real(torch.rot90(value, k=k, dims=(-2, -1)))
+
+ # If rotated by an odd multiple of 90 degrees, the reconstruction size also needs to change
+ reconstruction_size = sample.get("reconstruction_size", None)
+ if reconstruction_size and (k % 2) == 1:
+ sample["reconstruction_size"] = (
+ reconstruction_size[:-3] + reconstruction_size[-3:-1][::-1] + reconstruction_size[-1:]
+ )
+
+ return sample
+
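+
+# Editorial sketch (not part of the original DIRECT source): a deterministic
+# 90-degree rotation; note that "reconstruction_size" is swapped accordingly.
+def _example_random_rotation() -> None:
+    """Rotate a non-square k-space by 90 degrees."""
+    rotate = RandomRotation(degrees=(90,), p=1.0)
+    sample = {"kspace": torch.rand(8, 48, 64, 2), "reconstruction_size": (48, 64, 1)}
+    sample = rotate(sample)
+    assert sample["kspace"].shape == (8, 64, 48, 2)
+    assert sample["reconstruction_size"] == (64, 48, 1)
+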
+
+class RandomFlipType(DirectEnum):
+ HORIZONTAL = "horizontal"
+ VERTICAL = "vertical"
+ RANDOM = "random"
+ BOTH = "both"
+
+
+class RandomFlip(DirectTransform):
+ r"""Random k-space flip transform.
+
+ Performs a random flip with probability :math:`p`. The flip can be horizontal, vertical, both, or a random
+ choice between horizontal and vertical.
+ """
+
+ def __init__(
+ self,
+ flip: RandomFlipType = RandomFlipType.RANDOM,
+ p: float = 0.5,
+ keys_to_flip: tuple[TransformKey, ...] = (TransformKey.KSPACE,),
+ ) -> None:
+ r"""Inits :class:`RandomFlip`.
+
+ Parameters
+ ----------
+ flip : RandomFlipType
+ Horizontal, vertical, both, or a random choice between horizontal and vertical. Default: RandomFlipType.RANDOM.
+ p : float
+ Probability of flip. Default: 0.5
+ keys_to_flip : tuple of TransformKeys
+ Keys to flip. Default: "kspace".
+ """
+ super().__init__()
+
+ self.flip = flip
+ self.p = p
+ self.keys_to_flip = keys_to_flip
+
+ def __call__(self, sample: dict[str, Any]) -> dict[str, Any]:
+ """Calls :class:`RandomFlip`.
+
+ Parameters
+ ----------
+ sample: dict[str, Any]
+ Dict sample.
+
+ Returns
+ -------
+ dict[str, Any]
+ Sample with flipped values of `keys_to_flip`.
+ """
+ if random.SystemRandom().random() <= self.p:
+ dims = (
+ (-2,)
+ if self.flip == "horizontal"
+ else (
+ (-1,)
+ if self.flip == "vertical"
+ else (-2, -1) if self.flip == "both" else (random.SystemRandom().choice([-2, -1]),)
+ )
+ )
+
+ for key in self.keys_to_flip:
+ if key in sample:
+ value = T.view_as_complex(sample[key].clone())
+ value = torch.flip(value, dims=dims)
+ sample[key] = T.view_as_real(value)
+
+ return sample
+
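+
+# Editorial sketch (not part of the original DIRECT source): forcing a specific
+# flip type instead of the default random choice.
+def _example_random_flip() -> None:
+    """Always flip the k-space along the "horizontal" axis."""
+    flip = RandomFlip(flip=RandomFlipType.HORIZONTAL, p=1.0)
+    sample = {"kspace": torch.rand(8, 64, 64, 2)}
+    sample = flip(sample)  # flips along dim -2 of the complex-viewed tensor
+    assert sample["kspace"].shape == (8, 64, 64, 2)
+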
+
+class RandomReverse(DirectTransform):
+ r"""Random reverse of the order along a given dimension of a PyTorch tensor."""
+
+ def __init__(
+ self,
+ dim: int = 1,
+ p: float = 0.5,
+ keys_to_reverse: tuple[TransformKey, ...] = (TransformKey.KSPACE,),
+ ) -> None:
+ r"""Inits :class:`RandomReverse`.
+
+ Parameters
+ ----------
+ dim : int
+ Dimension along which to reverse. Typically the time or slice dimension. Default: 1.
+ p : float
+ Probability of reversal. Default: 0.5.
+ keys_to_reverse : tuple of TransformKeys
+ Keys to reverse. Default: "kspace".
+ """
+ super().__init__()
+
+ self.dim = dim
+ self.p = p
+ self.keys_to_reverse = keys_to_reverse
+
+ def __call__(self, sample: dict[str, Any]) -> dict[str, Any]:
+ """Calls :class:`RandomReverse`.
+
+ Parameters
+ ----------
+ sample: dict[str, Any]
+ Dict sample.
+
+ Returns
+ -------
+ dict[str, Any]
+ Sample with reversed values of `keys_to_reverse`.
+ """
+ if random.SystemRandom().random() <= self.p:
+ dim = self.dim
+ for key in self.keys_to_reverse:
+ if key in sample:
+ tensor = sample[key].clone()
+
+ if dim < 0:
+ dim += tensor.dim()
+
+ tensor = T.view_as_complex(tensor)
+
+ index = [slice(None)] * tensor.dim()
+ index[dim] = torch.arange(tensor.size(dim) - 1, -1, -1, dtype=torch.long)
+
+ tensor = tensor[tuple(index)]
+
+ sample[key] = T.view_as_real(tensor)
+
+ return sample
+
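+
+# Editorial sketch (not part of the original DIRECT source): reversing the
+# slice/time axis of a dynamic or 3D k-space.
+def _example_random_reverse() -> None:
+    """Always reverse dimension 1 (slice/time) of the k-space."""
+    reverse = RandomReverse(dim=1, p=1.0)
+    sample = {"kspace": torch.rand(8, 10, 64, 64, 2)}  # (coil, slice/time, height, width, complex)
+    sample = reverse(sample)
+    assert sample["kspace"].shape == (8, 10, 64, 64, 2)
+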
+
+class CreateSamplingMask(DirectTransform):
+ """Data Transformer for training MRI reconstruction models.
+
+ Creates sampling mask.
+ """
+
+ def __init__(
+ self,
+ mask_func: Callable,
+ shape: Optional[tuple[int, ...]] = None,
+ use_seed: bool = True,
+ return_acs: bool = False,
+ ) -> None:
+ """Inits :class:`CreateSamplingMask`.
+
+ Parameters
+ ----------
+ mask_func: Callable
+ A function which creates a sampling mask of the appropriate shape.
+ shape: tuple, optional
+ Sampling mask shape. Default: None.
+ use_seed: bool
+ If True, a pseudo-random number based on the filename is computed so that every slice of the volume gets
+ the same mask every time. Default: True.
+ return_acs: bool
+ If True, it will generate an ACS mask. Default: False.
+ """
+ super().__init__()
+ self.mask_func = mask_func
+ self.shape = shape
+ self.use_seed = use_seed
+ self.return_acs = return_acs
+
+ def __call__(self, sample: dict[str, Any]) -> dict[str, Any]:
+ """Calls :class:`CreateSamplingMask`.
+
+ Parameters
+ ----------
+ sample: dict[str, Any]
+ Dict sample.
+
+ Returns
+ -------
+ dict[str, Any]
+ Sample with `sampling_mask` key.
+ """
+ if not self.shape:
+ shape = sample["kspace"].shape[1:]
+ elif any(_ is None for _ in self.shape): # Allow None as values.
+ kspace_shape = list(sample["kspace"].shape[1:-1])
+ shape = tuple(_ if _ else kspace_shape[idx] for idx, _ in enumerate(self.shape)) + (2,)
+ else:
+ shape = self.shape + (2,)
+
+ seed = None if not self.use_seed else tuple(map(ord, str(sample["filename"])))
+
+ sampling_mask = self.mask_func(shape=shape, seed=seed, return_acs=False)
+
+ if sampling_mask.ndim == 5:
+ acceleration = [
+ np.prod(sampling_mask[0, _].shape) / sampling_mask[0, _].sum() for _ in range(sampling_mask.shape[1])
+ ]
+ sample["acceleration"] = torch.tensor(acceleration, dtype=torch.float32).unsqueeze(0)
+ else:
+ sample["acceleration"] = (np.prod(sampling_mask.shape) / sampling_mask.sum()).unsqueeze(0)
+
+ if "padding" in sample:
+ sampling_mask = T.apply_padding(sampling_mask, sample["padding"])
+
+ # Shape 3D: (1, 1, height, width, 1), 2D: (1, height, width, 1)
+ sample["sampling_mask"] = sampling_mask
+
+ if self.return_acs:
+ sample["acs_mask"] = self.mask_func(shape=shape, seed=seed, return_acs=True)
+ if sampling_mask.ndim == 5:
+ center_fraction = [
+ sample["acs_mask"][0, _].sum() / np.prod(sample["acs_mask"][0, _].shape)
+ for _ in range(sample["acs_mask"].shape[1])
+ ]
+ sample["center_fraction"] = torch.tensor(center_fraction, dtype=torch.float32).unsqueeze(0)
+ else:
+ sample["center_fraction"] = (sample["acs_mask"].sum() / np.prod(sample["acs_mask"].shape)).unsqueeze(0)
+ return sample
+
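+
+# Editorial sketch (not part of the original DIRECT source): creating sampling and
+# ACS masks. FastMRIRandomMaskFunc from direct.common.subsample is an assumption;
+# any callable with the same calling convention can be used.
+def _example_create_sampling_mask() -> None:
+    """Add "sampling_mask", "acs_mask" and "acceleration" to a 2D sample."""
+    from direct.common.subsample import FastMRIRandomMaskFunc
+
+    transform = CreateSamplingMask(
+        FastMRIRandomMaskFunc(accelerations=[4], center_fractions=[0.08]),
+        use_seed=True,
+        return_acs=True,
+    )
+    sample = {"kspace": torch.rand(8, 64, 64, 2), "filename": "example.h5"}
+    sample = transform(sample)  # masks have shape (1, height, width, 1) for 2D data
+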
+
+class ApplyMaskModule(DirectModule):
+ """Data Transformer for training MRI reconstruction models.
+
+ Masks the input k-space (with key `input_kspace_key`) using a sampling mask with key `sampling_mask_key` onto
+ a new masked k-space with key `target_kspace_key`.
+ """
+
+ def __init__(
+ self,
+ sampling_mask_key: str = "sampling_mask",
+ input_kspace_key: KspaceKey = KspaceKey.KSPACE,
+ target_kspace_key: KspaceKey = KspaceKey.MASKED_KSPACE,
+ ) -> None:
+ """Inits :class:`ApplyMaskModule`.
+
+ Parameters
+ ----------
+ sampling_mask_key: str
+ Default: "sampling_mask".
+ input_kspace_key: KspaceKey
+ Default: KspaceKey.KSPACE.
+ target_kspace_key: KspaceKey
+ Default: KspaceKey.MASKED_KSPACE.
+ """
+ super().__init__()
+ self.logger = logging.getLogger(type(self).__name__)
+
+ self.sampling_mask_key = sampling_mask_key
+ self.input_kspace_key = input_kspace_key
+ self.target_kspace_key = target_kspace_key
+
+ def forward(self, sample: dict[str, Any]) -> dict[str, Any]:
+ """Forward pass of :class:`ApplyMaskModule`.
+
+ Applies mask with key `sampling_mask_key` onto kspace `input_kspace_key`. Result is stored as a tensor with
+ key `target_kspace_key`.
+
+ Parameters
+ ----------
+ sample: dict[str, Any]
+ Dict sample containing keys `sampling_mask_key` and `input_kspace_key`.
+
+ Returns
+ -------
+ dict[str, Any]
+ Sample with (new) key `target_kspace_key`.
+ """
+ if self.input_kspace_key not in sample:
+ raise ValueError(f"Key {self.input_kspace_key} corresponding to `input_kspace_key` not found in sample.")
+ input_kspace = sample[self.input_kspace_key]
+
+ if self.sampling_mask_key not in sample:
+ raise ValueError(f"Key {self.sampling_mask_key} corresponding to `sampling_mask_key` not found in sample.")
+ sampling_mask = sample[self.sampling_mask_key]
+
+ target_kspace, _ = T.apply_mask(input_kspace, sampling_mask)
+ sample[self.target_kspace_key] = target_kspace
+ return sample
+
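+
+# Editorial sketch (not part of the original DIRECT source): masking a fully
+# sampled k-space with a boolean sampling mask. Shapes are illustrative.
+def _example_apply_mask() -> None:
+    """Store the subsampled k-space under "masked_kspace"."""
+    module = ApplyMaskModule()
+    sample = {
+        "kspace": torch.rand(1, 8, 64, 64, 2),  # (batch, coil, height, width, complex)
+        "sampling_mask": torch.rand(1, 1, 64, 64, 1) > 0.5,
+    }
+    sample = module(sample)
+    assert sample["masked_kspace"].shape == sample["kspace"].shape
+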
+
+class CropKspace(DirectTransform):
+ """Data Transformer for training MRI reconstruction models.
+
+ Crops the k-space by:
+ * It first projects the k-space to the image-domain via the backward operator,
+ * It crops the back-projected k-space to specified shape or key,
+ * It transforms the cropped back-projected k-space to the k-space domain via the forward operator.
+ """
+
+ def __init__(
+ self,
+ crop: Union[str, tuple[int, ...], list[int]],
+ forward_operator: Callable = T.fft2,
+ backward_operator: Callable = T.ifft2,
+ image_space_center_crop: bool = False,
+ random_crop_sampler_type: Optional[str] = "uniform",
+ random_crop_sampler_use_seed: Optional[bool] = True,
+ random_crop_sampler_gaussian_sigma: Optional[list[float]] = None,
+ ) -> None:
+ """Inits :class:`CropKspace`.
+
+ Parameters
+ ----------
+ crop: tuple of ints or str
+ Shape to crop the input to or a string pointing to a crop key (e.g. `reconstruction_size`).
+ forward_operator: Callable
+ The forward operator, e.g. some form of FFT (centered or uncentered).
+ Default: :class:`direct.data.transforms.fft2`.
+ backward_operator: Callable
+ The backward operator, e.g. some form of inverse FFT (centered or uncentered).
+ Default: :class:`direct.data.transforms.ifft2`.
+ image_space_center_crop: bool
+ If set, the crop in the data will be taken in the center.
+ random_crop_sampler_type: Optional[str]
+ If "uniform" the random cropping will be done by uniformly sampling `crop`, as opposed to `gaussian` which
+ will sample from a gaussian distribution. If `image_space_center_crop` is True, then this is ignored.
+ Default: "uniform".
+ random_crop_sampler_use_seed: bool
+ If True, a pseudo-random number based on the filename is computed so that every slice of the volume
+ is cropped the same way. Default: True.
+ random_crop_sampler_gaussian_sigma: Optional[list[float]]
+ Standard deviation of the gaussian when `random_crop_sampler_type` is `gaussian`.
+ If `image_space_center_crop` is True, then this is ignored. Default: None.
+ """
+ super().__init__()
+ self.logger = logging.getLogger(type(self).__name__)
+
+ self.image_space_center_crop = image_space_center_crop
+
+ if not (isinstance(crop, (Iterable, str))):
+ raise ValueError(
+ f"Invalid input for `crop`. Received {crop}. Can be a list of tuple of integers or a string."
+ )
+ self.crop = crop
+
+ if image_space_center_crop:
+ self.crop_func = T.complex_center_crop
+ else:
+ self.crop_func = functools.partial(
+ T.complex_random_crop,
+ sampler=random_crop_sampler_type,
+ sigma=random_crop_sampler_gaussian_sigma,
+ )
+ self.random_crop_sampler_use_seed = random_crop_sampler_use_seed
+
+ self.forward_operator = forward_operator
+ self.backward_operator = backward_operator
+
+ def __call__(self, sample: dict[str, Any]) -> dict[str, Any]:
+ """Calls :class:`CropKspace`.
+
+ Parameters
+ ----------
+ sample: dict[str, Any]
+ Dict sample containing key `kspace`.
+
+ Returns
+ -------
+ dict[str, Any]
+ Cropped and masked sample.
+ """
+
+ kspace = sample["kspace"] # shape (coil, [slice/time], height, width, complex=2)
+
+ dim = self.spatial_dims.TWO_D if kspace.ndim == 4 else self.spatial_dims.THREE_D
+
+ backprojected_kspace = self.backward_operator(kspace, dim=dim) # shape (coil, [slice/time], height, width, complex=2)
+
+ if isinstance(self.crop, IntegerListOrTupleString):
+ crop_shape = IntegerListOrTupleString(self.crop)
+ elif isinstance(self.crop, str):
+ assert self.crop in sample, f"Key {self.crop} not found in sample."
+ crop_shape = sample[self.crop][:-1]
+ else:
+ if kspace.ndim == 5 and len(self.crop) == 2:
+ crop_shape = (kspace.shape[1],) + tuple(self.crop)
+ else:
+ crop_shape = tuple(self.crop)
+
+ cropper_args = {
+ "data_list": [backprojected_kspace],
+ "crop_shape": crop_shape,
+ "contiguous": False,
+ }
+ if not self.image_space_center_crop:
+ cropper_args["seed"] = (
+ None if not self.random_crop_sampler_use_seed else tuple(map(ord, str(sample["filename"])))
+ )
+ cropped_backprojected_kspace = self.crop_func(**cropper_args)
+
+ if "sampling_mask" in sample:
+ sample["sampling_mask"] = T.complex_center_crop(
+ sample["sampling_mask"], (1,) + tuple(crop_shape)[1:] if kspace.ndim == 5 else crop_shape
+ )
+ sample["acs_mask"] = T.complex_center_crop(
+ sample["acs_mask"], (1,) + tuple(crop_shape)[1:] if kspace.ndim == 5 else crop_shape
+ )
+
+ # Compute new k-space for the cropped_backprojected_kspace
+ # shape (coil, [slice/time], new_height, new_width, complex=2)
+ sample["kspace"] = self.forward_operator(cropped_backprojected_kspace, dim=dim) # The cropped kspace
+
+ return sample
+
+
+class RescaleMode(DirectEnum):
+ AREA = "area"
+ BICUBIC = "bicubic"
+ BILINEAR = "bilinear"
+ NEAREST = "nearest"
+ NEAREST_EXACT = "nearest-exact"
+ TRILINEAR = "trilinear"
+
+
+class RescaleKspace(DirectTransform):
+ """Rescale k-space (downsample/upsample) module.
+
+ Rescales the k-space by (see the sketch in the example below):
+ * projecting the k-space to the image domain via the backward operator,
+ * rescaling the back-projected k-space to the specified shape,
+ * transforming the rescaled back-projected k-space back to the k-space domain via the forward operator.
+
+ Parameters
+ ----------
+ shape : tuple or list of ints
+ Shape to rescale the input to. Must correspond to (height, width) or (slice/time, height, width).
+ forward_operator : Callable
+ The forward operator, e.g. some form of FFT (centered or uncentered).
+ Default: :class:`direct.data.transforms.fft2`.
+ backward_operator : Callable
+ The backward operator, e.g. some form of inverse FFT (centered or uncentered).
+ Default: :class:`direct.data.transforms.ifft2`.
+ rescale_mode : RescaleMode
+ Mode to be used for rescaling. Can be RescaleMode.AREA, RescaleMode.BICUBIC, RescaleMode.BILINEAR,
+ RescaleMode.NEAREST, RescaleMode.NEAREST_EXACT, or RescaleMode.TRILINEAR. Note that not all modes are
+ supported for 2D or 3D data. Default: RescaleMode.NEAREST.
+ kspace_key : KspaceKey
+ K-space key. Default: KspaceKey.KSPACE.
+ rescale_2d_if_3d : bool, optional
+ If True and input k-space data is 3D, rescaling will be done only on the height and width dimensions.
+ Default: None (equivalent to False).
+
+ Note
+ ----
+ If the input k-space data is 3D, rescaling will be done only on the height and width dimensions if
+ `rescale_2d_if_3d` is set to True.
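+
+ Example
+ -------
+ A minimal sketch with hypothetical shapes, downsampling a 2D multi-coil k-space::
+
+ import torch
+ sample = {"kspace": torch.randn(8, 320, 320, 2)} # (coil, height, width, complex=2)
+ sample = RescaleKspace(shape=(160, 160))(sample)
+ # sample["kspace"] now has spatial shape (160, 160)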
+ """
+
+ def __init__(
+ self,
+ shape: Union[tuple[int, int], list[int]],
+ forward_operator: Callable = T.fft2,
+ backward_operator: Callable = T.ifft2,
+ rescale_mode: RescaleMode = RescaleMode.NEAREST,
+ kspace_key: KspaceKey = KspaceKey.KSPACE,
+ rescale_2d_if_3d: Optional[bool] = None,
+ ) -> None:
+ """Inits :class:`RescaleKspace`.
+
+ Parameters
+ ----------
+ shape : tuple or list of ints
+ Shape to rescale the input to. Must correspond to (height, width) or (slice/time, height, width).
+ forward_operator : Callable
+ The forward operator, e.g. some form of FFT (centered or uncentered).
+ Default: :class:`direct.data.transforms.fft2`.
+ backward_operator : Callable
+ The backward operator, e.g. some form of inverse FFT (centered or uncentered).
+ Default: :class:`direct.data.transforms.ifft2`.
+ rescale_mode : RescaleMode
+ Mode to be used for rescaling. Can be RescaleMode.AREA, RescaleMode.BICUBIC, RescaleMode.BILINEAR,
+ RescaleMode.NEAREST, RescaleMode.NEAREST_EXACT, or RescaleMode.TRILINEAR. Note that not all modes are
+ supported for 2D or 3D data. Default: RescaleMode.NEAREST.
+ kspace_key : KspaceKey
+ K-space key. Default: KspaceKey.KSPACE.
+ rescale_2d_if_3d : bool, optional
+ If True and input k-space data is 3D, rescaling will be done only on the height and width dimensions,
+ by combining the slice/time dimension with the batch dimension.
+ Default: None (equivalent to False).
+ """
+ super().__init__()
+ self.logger = logging.getLogger(type(self).__name__)
+
+ if len(shape) not in [2, 3]:
+ raise ValueError(
+ f"Shape should be a list or tuple of two integers if 2D or three integers if 3D. "
+ f"Received: {shape}."
+ )
+ self.shape = shape
+ self.forward_operator = forward_operator
+ self.backward_operator = backward_operator
+ self.rescale_mode = rescale_mode
+ self.kspace_key = kspace_key
+
+ self.rescale_2d_if_3d = rescale_2d_if_3d
+ if rescale_2d_if_3d and len(shape) == 3:
+ raise ValueError("Shape cannot have a length of 3 when rescale_2d_if_3d is set to True.")
+
+ def __call__(self, sample: dict[str, Any]) -> dict[str, Any]:
+ """Calls :class:`RescaleKspace`.
+
+ Parameters
+ ----------
+ sample: dict[str, Any]
+ Dict sample containing key `kspace`.
+
+ Returns
+ -------
+ dict[str, Any]
+ Sample with rescaled k-space.
+ """
+ kspace = sample[self.kspace_key] # shape (coil, [slice/time], height, width, complex=2)
+
+ dim = self.spatial_dims.TWO_D if kspace.ndim == 4 else self.spatial_dims.THREE_D
+
+ backprojected_kspace = self.backward_operator(kspace, dim=dim)
+
+ if kspace.ndim == 5 and self.rescale_2d_if_3d:
+ backprojected_kspace = backprojected_kspace.permute(1, 0, 2, 3, 4)
+
+ if (kspace.ndim == 4) or (kspace.ndim == 5 and not self.rescale_2d_if_3d):
+ backprojected_kspace = backprojected_kspace.unsqueeze(0)
+
+ rescaled_backprojected_kspace = T.complex_image_resize(backprojected_kspace, self.shape, self.rescale_mode)
+
+ if (kspace.ndim == 4) or (kspace.ndim == 5 and not self.rescale_2d_if_3d):
+ rescaled_backprojected_kspace = rescaled_backprojected_kspace.squeeze(0)
+
+ if kspace.ndim == 5 and self.rescale_2d_if_3d:
+ rescaled_backprojected_kspace = rescaled_backprojected_kspace.permute(1, 0, 2, 3, 4)
+
+ # Compute new k-space from rescaled_backprojected_kspace
+ # shape (coil, [slice/time if rescale_2d_if_3d else new_slc_or_time], new_height, new_width, complex=2)
+ sample[self.kspace_key] = self.forward_operator(rescaled_backprojected_kspace, dim=dim) # The rescaled kspace
+
+ return sample
+
+
+class PadKspace(DirectTransform):
+ """Pad k-space with zeros to desired shape module.
+
+ Pads the k-space by (see the sketch in the example below):
+ * projecting the k-space to the image domain via the backward operator,
+ * zero-padding the back-projected k-space to the specified shape,
+ * transforming the padded back-projected k-space back to the k-space domain via the forward operator.
+
+ Parameters
+ ----------
+ pad_shape : tuple or list of ints
+ Shape to zero-pad the input to. Must correspond to (height, width) or (slice/time, height, width).
+ forward_operator : Callable
+ The forward operator, e.g. some form of FFT (centered or uncentered).
+ Default: :class:`direct.data.transforms.fft2`.
+ backward_operator : Callable
+ The backward operator, e.g. some form of inverse FFT (centered or uncentered).
+ Default: :class:`direct.data.transforms.ifft2`.
+ kspace_key : KspaceKey
+ K-space key. Default: KspaceKey.KSPACE.
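+
+ Example
+ -------
+ A minimal sketch with hypothetical shapes, zero-padding a 2D multi-coil k-space in the image domain::
+
+ import torch
+ sample = {"kspace": torch.randn(8, 320, 320, 2)}
+ sample = PadKspace(pad_shape=(384, 384))(sample)
+ # sample["kspace"] now has spatial shape (384, 384); sample["original_size"] holds (320, 320)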
+ """
+
+ def __init__(
+ self,
+ pad_shape: Union[tuple[int, ...], list[int]],
+ forward_operator: Callable = T.fft2,
+ backward_operator: Callable = T.ifft2,
+ kspace_key: KspaceKey = KspaceKey.KSPACE,
+ ) -> None:
+ """Inits :class:`RescaleKspace`.
+
+ Parameters
+ ----------
+ pad_shape : tuple or list of ints
+ Shape to zero-pad the input to. Must correspond to (height, width) or (slice/time, height, width).
+ forward_operator : Callable
+ The forward operator, e.g. some form of FFT (centered or uncentered).
+ Default: :class:`direct.data.transforms.fft2`.
+ backward_operator : Callable
+ The backward operator, e.g. some form of inverse FFT (centered or uncentered).
+ Default: :class:`direct.data.transforms.ifft2`.
+ kspace_key : KspaceKey
+ K-space key. Default: KspaceKey.KSPACE.
+ """
+ super().__init__()
+ self.logger = logging.getLogger(type(self).__name__)
+
+ if len(pad_shape) not in [2, 3]:
+ raise ValueError(f"Shape should be a list or tuple of two or three integers. Received: {pad_shape}.")
+
+ self.shape = pad_shape
+ self.forward_operator = forward_operator
+ self.backward_operator = backward_operator
+ self.kspace_key = kspace_key
+
+ def __call__(self, sample: dict[str, Any]) -> dict[str, Any]:
+ """Calls :class:`PadKspace`.
+
+ Parameters
+ ----------
+ sample: dict[str, Any]
+ Dict sample containing key `kspace`.
+
+ Returns
+ -------
+ dict[str, Any]
+ Sample with zero-padded k-space.
+ """
+ kspace = sample[self.kspace_key] # shape (coil, [slice or time], height, width, complex=2)
+ shape = kspace.shape
+
+ sample["original_size"] = shape[1:-1]
+
+ dim = self.spatial_dims.TWO_D if kspace.ndim == 4 else self.spatial_dims.THREE_D
+
+ backprojected_kspace = self.backward_operator(kspace, dim=dim)
+ backprojected_kspace = T.view_as_complex(backprojected_kspace)
+
+ padded_backprojected_kspace = T.pad_tensor(backprojected_kspace, self.shape)
+ padded_backprojected_kspace = T.view_as_real(padded_backprojected_kspace)
+
+ # shape (coil, [slice or time], height, width, complex=2)
+ sample[self.kspace_key] = self.forward_operator(padded_backprojected_kspace, dim=dim) # The padded kspace
+
+ return sample
+
+
+class ComputeZeroPadding(DirectTransform):
+ r"""Computes zero padding present in multi-coil kspace input.
+
+ Zero-padding is computed from multi-coil kspace with no signal contribution, i.e. its magnitude
+ is really close to zero:
+
+ .. math ::
+
+ \text{padding} = \sum_{i=1}^{n_c} |y_i| < \frac{1}{n_x \cdot n_y}
+ \sum_{j=1}^{n_x \cdot n_y} \big\{\sum_{i=1}^{n_c} |y_i|\big\} * \epsilon.
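+
+ For example, with the default ``eps = 0.0001`` an entry of the coil-combined k-space magnitude is flagged
+ as padding when it is smaller than ``0.0001`` times the mean magnitude. A minimal sketch with
+ hypothetical shapes::
+
+ import torch
+ sample = {"kspace": torch.randn(8, 320, 320, 2)}
+ sample = ComputeZeroPadding()(sample) # adds a boolean "padding" tensor of shape (1, 320, 320, 1)
+ sample = ApplyZeroPadding()(sample) # applies the computed padding mask to "kspace"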
+ """
+
+ def __init__(
+ self,
+ kspace_key: KspaceKey = KspaceKey.KSPACE,
+ padding_key: str = "padding",
+ eps: Optional[float] = 0.0001,
+ ) -> None:
+ """Inits :class:`ComputeZeroPadding`.
+
+ Parameters
+ ----------
+ kspace_key: KspaceKey
+ K-space key. Default: KspaceKey.KSPACE.
+ padding_key: str
+ Target key. Default: "padding".
+ eps: float
+ Epsilon to multiply the mean coil-combined magnitude with: entries below `mean * eps` are flagged as
+ padding, so a very small value flags (almost) nothing as padding. Default: 0.0001.
+ """
+ super().__init__()
+ self.kspace_key = kspace_key
+ self.padding_key = padding_key
+ self.eps = eps
+
+ def __call__(self, sample: dict[str, Any], coil_dim: int = 0) -> dict[str, Any]:
+ """Updates sample with a key `padding_key` with value a binary tensor.
+
+ Non-zero entries indicate samples in kspace with key `kspace_key` which have minor contribution, i.e. padding.
+
+ Parameters
+ ----------
+ sample : dict[str, Any]
+ Dict sample containing key `kspace_key`.
+ coil_dim : int
+ Coil dimension. Default: 0.
+
+ Returns
+ -------
+ sample : dict[str, Any]
+ Dict sample containing key `padding_key`.
+ """
+ if self.eps is None:
+ return sample
+ shape = sample[self.kspace_key].shape
+
+ kspace = T.modulus(sample[self.kspace_key].clone()).sum(coil_dim)
+
+ if len(shape) == 5: # Check if 3D data
+ # Assumes that slice dim is 0
+ kspace = kspace.sum(0)
+
+ padding = (kspace < (torch.mean(kspace) * self.eps)).to(kspace.device)
+
+ if len(shape) == 5:
+ padding = padding.unsqueeze(0)
+
+ padding = padding.unsqueeze(coil_dim).unsqueeze(-1)
+ sample[self.padding_key] = padding
+
+ return sample
+
+
+class ApplyZeroPadding(DirectTransform):
+ """Applies zero padding present in multi-coil kspace input."""
+
+ def __init__(self, kspace_key: KspaceKey = KspaceKey.KSPACE, padding_key: str = "padding") -> None:
+ """Inits :class:`ApplyZeroPadding`.
+
+ Parameters
+ ----------
+ kspace_key: KspaceKey
+ K-space key. Default: KspaceKey.KSPACE.
+ padding_key: str
+ Target key. Default: "padding".
+ """
+ super().__init__()
+ self.kspace_key = kspace_key
+ self.padding_key = padding_key
+
+ def __call__(self, sample: dict[str, Any], coil_dim: int = 0) -> dict[str, Any]:
+ """Applies zero padding on `kspace_key` with value a binary tensor.
+
+ Parameters
+ ----------
+ sample : dict[str, Any]
+ Dict sample containing key `kspace_key`.
+ coil_dim : int
+ Coil dimension. Default: 0.
+
+ Returns
+ -------
+ sample : dict[str, Any]
+ Dict sample with zero padding applied to the k-space under `kspace_key`.
+ """
+
+ sample[self.kspace_key] = T.apply_padding(sample[self.kspace_key], sample[self.padding_key])
+
+ return sample
+
+
+class ReconstructionType(DirectEnum):
+ """Reconstruction method for :class:`ComputeImage` transform."""
+
+ IFFT = "ifft"
+ RSS = "rss"
+ COMPLEX = "complex"
+ COMPLEX_MOD = "complex_mod"
+ SENSE = "sense"
+ SENSE_MOD = "sense_mod"
+
+
+class ComputeImageModule(DirectModule):
+ """Compute Image transform."""
+
+ def __init__(
+ self,
+ kspace_key: KspaceKey,
+ target_key: str,
+ backward_operator: Callable,
+ type_reconstruction: ReconstructionType = ReconstructionType.RSS,
+ ) -> None:
+ """Inits :class:`ComputeImageModule`.
+
+ Parameters
+ ----------
+ kspace_key: KspaceKey
+ K-space key.
+ target_key: str
+ Target key.
+ backward_operator: callable
+ The backward operator, e.g. some form of inverse FFT (centered or uncentered).
+ type_reconstruction: ReconstructionType
+ Type of reconstruction. Can be ReconstructionType.RSS, ReconstructionType.COMPLEX,
+ ReconstructionType.COMPLEX_MOD, ReconstructionType.SENSE, ReconstructionType.SENSE_MOD or
+ ReconstructionType.IFFT. Default: ReconstructionType.RSS.
+ """
+ super().__init__()
+ self.backward_operator = backward_operator
+ self.kspace_key = kspace_key
+ self.target_key = target_key
+ self.type_reconstruction = type_reconstruction
+
+ def forward(self, sample: dict[str, Any]) -> dict[str, Any]:
+ """Forward pass of :class:`ComputeImageModule`.
+
+ Parameters
+ ----------
+ sample: dict[str, Any]
+ Contains key kspace_key with value a torch.Tensor of shape (coil, \*spatial_dims, complex=2).
+
+ Returns
+ -------
+ sample: dict
+ Contains key target_key with value a torch.Tensor of shape (\*spatial_dims) if `type_reconstruction` is
+ ReconstructionType.RSS, ReconstructionType.COMPLEX_MOD, ReconstructionType.SENSE_MOD,
+ and of shape (\*spatial_dims, complex_dim=2) otherwise.
+ """
+ kspace_data = sample[self.kspace_key]
+ dim = self.spatial_dims.TWO_D if kspace_data.ndim == 5 else self.spatial_dims.THREE_D
+ # Get complex-valued data solution
+ image = self.backward_operator(kspace_data, dim=dim)
+ if self.type_reconstruction == ReconstructionType.IFFT:
+ sample[self.target_key] = image
+ elif self.type_reconstruction in [
+ ReconstructionType.COMPLEX,
+ ReconstructionType.COMPLEX_MOD,
+ ]:
+ sample[self.target_key] = image.sum(self.coil_dim)
+ elif self.type_reconstruction == ReconstructionType.RSS:
+ sample[self.target_key] = T.root_sum_of_squares(image, dim=self.coil_dim)
+ else:
+ if "sensitivity_map" not in sample:
+ raise ItemNotFoundException(
+ "sensitivity map",
+ "Sensitivity map is required for SENSE reconstruction.",
+ )
+ sample[self.target_key] = T.complex_multiplication(T.conjugate(sample["sensitivity_map"]), image).sum(
+ self.coil_dim
+ )
+ if self.type_reconstruction in [
+ ReconstructionType.COMPLEX_MOD,
+ ReconstructionType.SENSE_MOD,
+ ]:
+ sample[self.target_key] = T.modulus(sample[self.target_key], self.complex_dim)
+ return sample
+
+
+class EstimateBodyCoilImage(DirectTransform):
+ """Estimates body coil image."""
+
+ def __init__(self, mask_func: Callable, backward_operator: Callable, use_seed: bool = True) -> None:
+ """Inits :class:`EstimateBodyCoilImage'.
+
+ Parameters
+ ----------
+ mask_func: Callable
+ A function which creates a sampling mask of the appropriate shape.
+ backward_operator: callable
+ The backward operator, e.g. some form of inverse FFT (centered or uncentered).
+ use_seed: bool
+ If true, a pseudo-random number based on the filename is computed so that every slice of the volume gets
+ the same mask every time. Default: True.
+ """
+ super().__init__()
+ self.mask_func = mask_func
+ self.use_seed = use_seed
+ self.backward_operator = backward_operator
+
+ def __call__(self, sample: dict[str, Any], coil_dim: int = 0) -> dict[str, Any]:
+ """Calls :class:`EstimateBodyCoilImage`.
+
+ Parameters
+ ----------
+ sample: dict[str, Any]
+ Contains key kspace_key with value a torch.Tensor of shape (coil, ..., complex=2).
+ coil_dim: int
+ Coil dimension. Default: 0.
+
+ Returns
+ -------
+ sample: dict[str, Any]
+ Contains key "body_coil_image".
+ """
+ kspace = sample["kspace"]
+
+ # We need to create an ACS mask based on the shape of this kspace, as it can be cropped.
+ seed = None if not self.use_seed else tuple(map(ord, str(sample["filename"])))
+ kspace_shape = tuple(sample["kspace"].shape[-3:])
+ acs_mask = self.mask_func(shape=kspace_shape, seed=seed, return_acs=True)
+
+ kspace = acs_mask * kspace + 0.0
+ dim = self.spatial_dims.TWO_D if kspace.ndim == 4 else self.spatial_dims.THREE_D
+ acs_image = self.backward_operator(kspace, dim=dim)
+
+ sample["body_coil_image"] = T.root_sum_of_squares(acs_image, dim=coil_dim)
+ return sample
+
+
+class SensitivityMapType(DirectEnum):
+ ESPIRIT = "espirit"
+ RSS_ESTIMATE = "rss_estimate"
+ UNIT = "unit"
+
+
+class EstimateSensitivityMapModule(DirectModule):
+ """Data Transformer for training MRI reconstruction models.
+
+ Estimates sensitivity maps given masked k-space data using one of three methods:
+
+ * Unit: unit sensitivity map in case of single coil acquisition.
+ * RSS-estimate: sensitivity maps estimated by using the root-sum-of-squares of the autocalibration-signal.
+ * ESPIRIT: sensitivity maps estimated with the ESPIRIT method [1]_. Note that this is currently not
+ implemented for 3D data, and attempting to use it in such cases will result in a NotImplementedError.
+
+ References
+ ----------
+
+ .. [1] Uecker M, Lai P, Murphy MJ, Virtue P, Elad M, Pauly JM, Vasanawala SS, Lustig M. ESPIRiT--an eigenvalue
+ approach to autocalibrating parallel MRI: where SENSE meets GRAPPA. Magn Reson Med. 2014 Mar;71(3):990-1001.
+ doi: 10.1002/mrm.24751. PMID: 23649942; PMCID: PMC4142121.
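+
+ Example
+ -------
+ A minimal sketch with hypothetical shapes. The module operates on batched tensors; the wrapped
+ `EstimateSensitivityMap` defined further below handles unbatched samples::
+
+ import torch
+ sample = {KspaceKey.ACS_KSPACE: torch.randn(1, 8, 320, 320, 2)} # (batch, coil, height, width, complex=2)
+ sample = EstimateSensitivityMapModule().forward(sample)
+ # adds a normalized sensitivity map under TransformKey.SENSITIVITY_MAP, same shape as the input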
+ """
+
+ def __init__(
+ self,
+ kspace_key: KspaceKey = KspaceKey.ACS_KSPACE,
+ backward_operator: Callable = T.ifft2,
+ type_of_map: Optional[SensitivityMapType] = SensitivityMapType.RSS_ESTIMATE,
+ gaussian_sigma: Optional[float] = None,
+ espirit_threshold: Optional[float] = 0.05,
+ espirit_kernel_size: Optional[int] = 6,
+ espirit_crop: Optional[float] = 0.95,
+ espirit_max_iters: Optional[int] = 30,
+ ) -> None:
+ """Inits :class:`EstimateSensitivityMapModule`.
+
+ Parameters
+ ----------
+ kspace_key: KspaceKey
+ K-space key to compute the ACS image from. If `kspace_key` is not `KspaceKey.ACS_KSPACE`,
+ the ACS mask should be provided in the sample. Default: KspaceKey.ACS_KSPACE.
+ backward_operator: callable
+ The backward operator, e.g. some form of inverse FFT (centered or uncentered).
+ type_of_map: SensitivityMapType, optional
+ Type of map to estimate. Can be SensitivityMapType.RSS_ESTIMATE, SensitivityMapType.UNIT or
+ SensitivityMapType.ESPIRIT. Default: SensitivityMapType.RSS_ESTIMATE.
+ gaussian_sigma: float, optional
+ If non-zero, the ACS k-space is weighted with a gaussian mask of this standard deviation along the
+ width dimension before the ACS image is calculated. Default: None.
+ espirit_threshold: float, optional
+ Threshold for the calibration matrix when `type_of_map` is set to `SensitivityMapType.ESPIRIT`.
+ Default: 0.05.
+ espirit_kernel_size: int, optional
+ Kernel size for the calibration matrix when `type_of_map` is set to `SensitivityMapType.ESPIRIT`.
+ Default: 6.
+ espirit_crop: float, optional
+ Output eigenvalue cropping threshold when `type_of_map` is set to `SensitivityMapType.ESPIRIT`.
+ Default: 0.95.
+ espirit_max_iters: int, optional
+ Power method iterations when `type_of_map` is set to `SensitivityMapType.ESPIRIT`. Default: 30.
+ """
+ super().__init__()
+ self.backward_operator = backward_operator
+ self.kspace_key = kspace_key
+ self.type_of_map = type_of_map
+
+ # RSS estimate attributes
+ self.gaussian_sigma = gaussian_sigma
+ # Espirit attributes
+ if type_of_map == SensitivityMapType.ESPIRIT:
+ self.espirit_calibrator = EspiritCalibration(
+ backward_operator,
+ espirit_threshold,
+ espirit_kernel_size,
+ espirit_crop,
+ espirit_max_iters,
+ kspace_key,
+ )
+ self.espirit_threshold = espirit_threshold
+ self.espirit_kernel_size = espirit_kernel_size
+ self.espirit_crop = espirit_crop
+ self.espirit_max_iters = espirit_max_iters
+
+ def estimate_acs_image(self, sample: dict[str, Any], width_dim: int = -2) -> torch.Tensor:
+ """Estimates the autocalibration (ACS) image by sampling the k-space using the ACS mask.
+
+ Parameters
+ ----------
+ sample: dict[str, Any]
+ Sample dictionary.
+ width_dim: int
+ Dimension corresponding to width. Default: -2.
+
+ Returns
+ -------
+ acs_image: torch.Tensor
+ Estimate of the ACS image.
+ """
+ kspace_data = sample[self.kspace_key]
+
+ if self.kspace_key != KspaceKey.ACS_KSPACE:
+ if TransformKey.ACS_MASK not in sample:
+ raise ValueError("ACS mask is required for estimating ACS image from k-space but not found.")
+ kspace_data = kspace_data * sample[TransformKey.ACS_MASK]
+
+ if self.gaussian_sigma == 0 or not self.gaussian_sigma:
+ kspace_acs = kspace_data + 0.0 # + 0.0 removes the sign of zeros.
+ else:
+ gaussian_mask = torch.linspace(-1, 1, kspace_data.size(width_dim), dtype=kspace_data.dtype)
+ gaussian_mask = torch.exp(-((gaussian_mask / self.gaussian_sigma) ** 2))
+ gaussian_mask_shape = torch.ones(len(kspace_data.shape)).int()
+ gaussian_mask_shape[width_dim] = kspace_data.size(width_dim)
+ gaussian_mask = gaussian_mask.reshape(tuple(gaussian_mask_shape))
+ kspace_acs = kspace_data * gaussian_mask + 0.0
+
+ # Get complex-valued data solution
+ # Shape (batch, [slice/time], coil, height, width, complex=2)
+ dim = self.spatial_dims.TWO_D if kspace_data.ndim == 5 else self.spatial_dims.THREE_D
+ acs_image = self.backward_operator(kspace_acs, dim=dim)
+
+ return acs_image
+
+ def forward(self, sample: dict[str, Any]) -> dict[str, Any]:
+ """Calculates sensitivity maps for the input sample.
+
+ Parameters
+ ----------
+ sample: dict[str, Any]
+ Must contain key matching kspace_key with value a (complex) torch.Tensor
+ of shape (coil, height, width, complex=2).
+
+ Returns
+ -------
+ sample: dict[str, Any]
+ Sample with key "sensitivity_map" with value the estimated sensitivity map.
+ """
+ kspace = sample[self.kspace_key] # shape (batch, coil, [slice/time], height, width, complex=2)
+
+ if kspace.shape[self.coil_dim] == 1:
+ warnings.warn(
+ "Estimating sensitivity maps for single-coil data. This warning will be displayed only once."
+ )
+ if "sensitivity_map" in sample:
+ warnings.warn(
+ "`sensitivity_map` is given, but will be overwritten. This warning will be displayed only once."
+ )
+
+ if self.type_of_map == SensitivityMapType.UNIT:
+ sensitivity_map = torch.zeros(kspace.shape).float()
+ # Assumes complex channel is last
+ assert_complex(kspace, complex_last=True)
+ sensitivity_map[..., 0] = 1.0
+ # Shape (batch, coil, [slice/time], height, width, complex=2)
+ sensitivity_map = sensitivity_map.to(kspace.device)
+
+ elif self.type_of_map == SensitivityMapType.RSS_ESTIMATE:
+ # Shape (batch, coil, [slice/time], height, width, complex=2)
+ acs_image = self.estimate_acs_image(sample)
+ # Shape (batch, [slice/time], height, width)
+ acs_image_rss = T.root_sum_of_squares(acs_image, dim=self.coil_dim)
+ # Shape (batch, 1, [slice/time], height, width, 1)
+ acs_image_rss = acs_image_rss.unsqueeze(self.coil_dim).unsqueeze(self.complex_dim)
+ # Shape (batch, coil, [slice/time], height, width, complex=2)
+ sensitivity_map = T.safe_divide(acs_image, acs_image_rss)
+ else:
+ if sample[self.kspace_key].ndim > 5:
+ raise NotImplementedError(
+ "ESPIRIT sensitivity map estimation is not yet implemented for 3D data "
+ "in EstimateSensitivityMapModule."
+ )
+ sensitivity_map = self.espirit_calibrator(sample)
+
+ sensitivity_map_norm = torch.sqrt(
+ (sensitivity_map**2).sum(self.complex_dim).sum(self.coil_dim)
+ ) # shape (batch, [slice/time], height, width)
+ sensitivity_map_norm = sensitivity_map_norm.unsqueeze(self.coil_dim).unsqueeze(self.complex_dim)
+
+ sample[TransformKey.SENSITIVITY_MAP] = T.safe_divide(sensitivity_map, sensitivity_map_norm)
+ return sample
+
+
+class AddBooleanKeysModule(DirectModule):
+ """Adds keys with boolean values to sample."""
+
+ def __init__(self, keys: list[str], values: list[bool]) -> None:
+ """Inits :class:`AddBooleanKeysModule`.
+
+ Parameters
+ ----------
+ keys : list[str]
+ A list of keys to be added.
+ values : list[bool]
+ A list of values corresponding to the keys.
+ """
+ super().__init__()
+ self.keys = keys
+ self.values = values
+
+ def forward(self, sample: dict[str, Any]) -> dict[str, Any]:
+ """Adds boolean keys to the input sample dictionary.
+
+ Parameters
+ ----------
+ sample : dict[str, Any]
+ The input sample dictionary.
+
+ Returns
+ -------
+ dict[str, Any]
+ The modified sample with added boolean keys.
+ """
+ for key, value in zip(self.keys, self.values):
+ sample[key] = value
+
+ return sample
+
+
+class CopyKeysModule(DirectModule):
+ """Copy keys to a new name from the sample if present."""
+
+ def __init__(self, keys: list[str], new_keys: list[str]) -> None:
+ """Inits :class:`CopyKeysModule`.
+
+ Parameters
+ ----------
+ keys: List[str]
+ Key(s) to copy.
+ new_keys: List[str]
+ Key(s) to create.
+ """
+ super().__init__()
+ self.keys = keys
+ self.new_keys = new_keys
+
+ def forward(self, sample: dict[str, Any]) -> dict[str, Any]:
+ """Forward pass of :class:`CopyKeysModule`.
+
+ Parameters
+ ----------
+ sample: Dict[str, Any]
+ Dictionary to look for keys and copy them with a new name.
+
+ Returns
+ -------
+ Dict[str, Any]
+ Dictionary with copied specified keys.
+ """
+ for key, new_key in zip(self.keys, self.new_keys):
+ if key in sample:
+ if isinstance(sample[key], np.ndarray):
+ sample[new_key] = sample[key].copy() # Copy NumPy array
+ elif isinstance(sample[key], torch.Tensor):
+ sample[new_key] = sample[key].detach().clone() # Copy Torch tensor
+ else:
+ sample[new_key] = copy.deepcopy(sample[key])
+ return sample
+
+
+class CompressCoilModule(DirectModule):
+ """Compresses k-space coils using SVD."""
+
+ def __init__(self, kspace_key: KspaceKey, num_coils: int) -> None:
+ """Inits :class:`CompressCoilModule`.
+
+ Parameters
+ ----------
+ kspace_key : KspaceKey
+ K-space key.
+ num_coils : int
+ Number of coils to compress.
+ """
+ super().__init__()
+ self.kspace_key = kspace_key
+ self.num_coils = num_coils
+
+ def forward(self, sample: dict[str, Any]) -> dict[str, Any]:
+ """Performs coil compression to input k-space.
+
+ Parameters
+ ----------
+ sample : dict[str, Any]
+ Dict sample containing key `kspace_key`. Assumes coil dimension is first axis.
+
+ Returns
+ -------
+ sample : dict[str, Any]
+ Dict sample with `kspace_key` compressed to num_coils.
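+
+ Example
+ -------
+ A minimal sketch with hypothetical shapes, compressing 16 coils down to 4 via SVD::
+
+ import torch
+ sample = {KspaceKey.KSPACE: torch.randn(1, 16, 320, 320, 2)} # (batch, coil, height, width, complex=2)
+ sample = CompressCoilModule(KspaceKey.KSPACE, num_coils=4).forward(sample)
+ # sample[KspaceKey.KSPACE] now has shape (1, 4, 320, 320, 2)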
+ """
+ k_space = sample[self.kspace_key].clone() # shape (batch, coil, [slice/time], height, width, complex=2)
+
+ if k_space.shape[1] <= self.num_coils:
+ return sample
+
+ ndim = k_space.ndim
+
+ k_space = torch.view_as_complex(k_space)
+
+ if ndim == 6: # If 3D sample reshape slice into batch dimension as sensitivities are computed 2D
+ num_slice_or_time = k_space.shape[2]
+ k_space = k_space.permute(0, 2, 1, 3, 4)
+ k_space = k_space.reshape(k_space.shape[0] * num_slice_or_time, *k_space.shape[2:])
+
+ shape = k_space.shape
+
+ # Reshape the k-space data to combine spatial dimensions
+ k_space_reshaped = k_space.reshape(shape[0], shape[1], -1)
+
+ # Compute the coil combination matrix using Singular Value Decomposition (SVD)
+ U, _, _ = torch.linalg.svd(k_space_reshaped, full_matrices=False)
+
+ # Select the top ncoils_new singular vectors from the decomposition
+ U_new = U[:, :, : self.num_coils]
+
+ # Perform coil compression
+ compressed_k_space = torch.matmul(U_new.transpose(1, 2), k_space_reshaped)
+
+ # Reshape the compressed k-space back to its original shape
+ compressed_k_space = compressed_k_space.reshape(shape[0], self.num_coils, *shape[2:])
+
+ if ndim == 6:
+ compressed_k_space = compressed_k_space.reshape(
+ shape[0] // num_slice_or_time, num_slice_or_time, self.num_coils, *shape[2:]
+ ).permute(0, 2, 1, 3, 4)
+
+ compressed_k_space = torch.view_as_real(compressed_k_space)
+ sample[self.kspace_key] = compressed_k_space # shape (batch, new coil, [slice/time], height, width, complex=2)
+
+ return sample
+
+
+class DeleteKeysModule(DirectModule):
+ """Remove keys from the sample if present."""
+
+ def __init__(self, keys: list[str]) -> None:
+ """Inits :class:`DeleteKeys`.
+
+ Parameters
+ ----------
+ keys: list[str]
+ Key(s) to delete.
+ """
+ super().__init__()
+ self.keys = keys
+
+ def forward(self, sample: dict[str, Any]) -> dict[str, Any]:
+ """Forward pass of :class:`DeleteKeys`.
+
+ Parameters
+ ----------
+ sample: dict[str, Any]
+ Dictionary to look for keys and remove them.
+
+ Returns
+ -------
+ dict[str, Any]
+ Dictionary with deleted specified keys.
+ """
+ for key in self.keys:
+ if key in sample:
+ del sample[key]
+
+ return sample
+
+
+class IndexSelectionMode(DirectEnum):
+ RANDOM = "random"
+ CUSTOM = "custom"
+ RANGE = "range"
+
+
+class IndexSelectionModule(DirectModule):
+ """Randomly selects indices from the sample.
+
+ Parameters
+ ----------
+ key: TransformKey
+ Key to select indices from.
+ mode: IndexSelectionMode
+ Mode of index selection. Can be IndexSelectionMode.RANDOM, IndexSelectionMode.CUSTOM or
+ IndexSelectionMode.RANGE. Default: IndexSelectionMode.CUSTOM.
+ indices: list[int], optional
+ List of indices to select if mode is IndexSelectionMode.CUSTOM, or a (start, stop) range if mode is
+ IndexSelectionMode.RANGE. Default: None.
+ num_indices: int
+ Number of indices to select if mode is IndexSelectionMode.RANDOM. Default: None.
+ out_key: TransformKey, optional
+ Key to store the selected indices. If None, the indices are stored in the same key.
+ Default: None.
+ index_dim: int
+ Dimension along which to select indices. Default: 0.
+ use_seed: bool
+ If true, a pseudo-random number based on the filename is computed so that every slice of the volume gets
+ the same mask every time. Default: True.
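+
+ Example
+ -------
+ A minimal sketch with hypothetical shapes, keeping two custom time indices of a dynamic k-space
+ (assuming `TransformKey.KSPACE` denotes the k-space key)::
+
+ import torch
+ sample = {TransformKey.KSPACE: torch.randn(8, 10, 320, 320, 2)}
+ module = IndexSelectionModule(TransformKey.KSPACE, mode=IndexSelectionMode.CUSTOM, indices=[0, 1], index_dim=1)
+ sample = module.forward(sample) # time dimension reduced from 10 to 2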
+ """
+
+ def __init__(
+ self,
+ key: TransformKey,
+ mode: IndexSelectionMode = IndexSelectionMode.CUSTOM,
+ indices: Optional[list[int]] = None,
+ num_indices: Optional[int] = None,
+ out_key: Optional[TransformKey] = None,
+ index_dim: int = 0,
+ use_seed: bool = True,
+ ) -> None:
+ """Inits :class:`IndexSelection`.
+
+ Parameters
+ ----------
+ key: TransformKey
+ Key to select indices from.
+ mode: IndexSelectionMode
+ Mode of index selection. Can be IndexSelectionMode.RANDOM, IndexSelectionMode.CUSTOM or
+ IndexSelectionMode.RANGE. Default: IndexSelectionMode.CUSTOM.
+ indices: list[int], optional
+ List of indices to select if mode is IndexSelectionMode.CUSTOM, or a (start, stop) range if mode is
+ IndexSelectionMode.RANGE. Default: None.
+ num_indices: int
+ Number of indices to select if mode is IndexSelectionMode.RANDOM. Default: None.
+ out_key: TransformKey, optional
+ Key to store the selected indices. If None, the indices are stored in the same key.
+ Default: None.
+ index_dim: int
+ Dimension along which to select indices. Default: 0.
+ use_seed: bool
+ If true, a pseudo-random number based on the filename is computed so that every slice of the volume gets
+ the same mask every time. Default: True.
+ """
+ super().__init__()
+ self.key = key
+ self.out_key = out_key if out_key is not None else key
+ self.mode = mode
+ self.indices = indices
+ self.num_indices = num_indices
+ self.index_dim = index_dim
+ self.use_seed = use_seed
+ self.rng = np.random.RandomState()
+
+ def forward(self, sample: dict[str, Any]) -> dict[str, Any]:
+ """Forward pass of :class:`IndexSelection`.
+
+ Parameters
+ ----------
+ sample: dict[str, Any]
+ Dictionary to look for key and select indices from.
+
+ Returns
+ -------
+ dict[str, Any]
+ Dictionary with randomly selected indices.
+ """
+ if self.key not in sample:
+ return sample
+
+ if self.mode == IndexSelectionMode.RANDOM:
+ seed = None if not self.use_seed else tuple(map(ord, str(sample["filename"])))
+ with temp_seed(self.rng, seed):
+ num_to_keep = max(min(self.num_indices, sample[self.key].shape[self.index_dim]), 1)
+ start = self.rng.randint(0, sample[self.key].shape[self.index_dim] - num_to_keep)
+ keep_indices = torch.arange(start, start + num_to_keep, device=sample[self.key].device)
+ else:
+ if self.mode == IndexSelectionMode.CUSTOM:
+ keep_indices = torch.tensor(
+ [idx for idx in self.indices if np.abs(idx) < sample[self.key].shape[self.index_dim]],
+ device=sample[self.key].device,
+ )
+ else:
+ keep_indices = torch.arange(self.indices[0], self.indices[1], device=sample[self.key].device)
+ num_to_keep = len(keep_indices)
+
+ sample[self.out_key] = sample[self.key].index_select(self.index_dim, keep_indices)
+
+ if num_to_keep == 1:
+ sample[self.out_key] = sample[self.out_key].squeeze(self.index_dim)
+
+ return sample
+
+
+class DropIndexModule(DirectModule):
+ """Drop indices from the sample.
+
+ Parameters
+ ----------
+ keys: list[TransformKey]
+ Key(s) to drop indices from.
+ index: int
+ Index to drop.
+ index_dim: int, list[int]
+ Dimension(s) along which to drop indices. If a list, must have the same length as `keys`. Default: 1.
+ store_deleted_keys: list[TransformKey], optional
+ Key(s) to store the deleted indices. If None, the deleted indices are not stored. If the length does not
+ match `keys`, the remaining keys are set to None. Default: None.
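+
+ Example
+ -------
+ A minimal sketch with hypothetical shapes and key names, dropping time frame 0 from the k-space and
+ storing it under a separate key::
+
+ import torch
+ sample = {TransformKey.KSPACE: torch.randn(8, 10, 320, 320, 2)}
+ module = DropIndexModule([TransformKey.KSPACE], index=0, index_dim=1, store_deleted_keys=["dropped_kspace"])
+ sample = module.forward(sample)
+ # the time dimension of the k-space is now 9; sample["dropped_kspace"] holds the dropped frame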
+ """
+
+ def __init__(
+ self,
+ keys: list[TransformKey],
+ index: int,
+ index_dim: int | list[int] = 1,
+ store_deleted_keys: Optional[list[TransformKey]] = None,
+ ) -> None:
+ """Inits :class:`DropIndexModule`.
+
+ Parameters
+ ----------
+ keys: list[TransformKey]
+ Key(s) to drop indices from.
+ index: int
+ Index to drop.
+ index_dim: int, list[int]
+ Dimension(s) along which to drop indices. If a list, must have the same length as `keys`. Default: 1.
+ store_deleted_keys: list[TransformKey], optional
+ Key(s) to store the deleted indices. If None, the deleted indices are not stored. If the length does not
+ match `keys`, the remaining keys are set to None. Default: None.
+ """
+ super().__init__()
+ self.keys = keys
+ self.index = index
+ self.index_dim = [index_dim] * len(keys) if isinstance(index_dim, int) else index_dim
+ self.store_deleted_keys = store_deleted_keys
+ if self.store_deleted_keys is not None and len(keys) > len(self.store_deleted_keys):
+ self.store_deleted_keys = store_deleted_keys + [None] * (len(keys) - len(self.store_deleted_keys))
+
+ def forward(self, sample: dict[str, Any]) -> dict[str, Any]:
+ """Forward pass of :class:`DropIndexModule`.
+
+ Parameters
+ ----------
+ sample: dict[str, Any]
+ Dictionary to look for key and drop indices from.
+
+ Returns
+ -------
+ dict[str, Any]
+ Dictionary with dropped index.
+ """
+
+ for i, key in enumerate(self.keys):
+ if key not in sample:
+ continue
+ # This might be helpful, for instance, in case a single mask is used for all time frames
+ if sample[key].shape[self.index_dim[i]] == 1:
+ continue
+ if self.store_deleted_keys is not None:
+ deleted_key = self.store_deleted_keys[i]
+ if deleted_key:
+ sample[deleted_key] = sample[key].index_select(
+ self.index_dim[i],
+ torch.tensor(
+ [idx for idx in range(sample[key].shape[self.index_dim[i]]) if idx == self.index],
+ device=sample[key].device,
+ ),
+ )
+ sample[key] = sample[key].index_select(
+ self.index_dim[i],
+ torch.tensor(
+ [idx for idx in range(sample[key].shape[self.index_dim[i]]) if idx != self.index],
+ device=sample[key].device,
+ ),
+ )
+
+ return sample
+
+
+class SqueezeKeyModule(DirectModule):
+ """Squeeze the specified key(s) in the sample.
+
+ Parameters
+ ----------
+ keys: list[TransformKey]
+ Key(s) to squeeze.
+ dim: int
+ Dimension to squeeze.
+ """
+
+ def __init__(self, keys: list[TransformKey], dim: int) -> None:
+ """Inits :class:`SqueezeKeyModule`.
+
+ Parameters
+ ----------
+ keys: list[TransformKey]
+ Key(s) to squeeze.
+ dim: int
+ Dimension to squeeze.
+ """
+ super().__init__()
+ self.keys = keys
+ self.dim = dim
+
+ def forward(self, sample: dict[str, Any]) -> dict[str, Any]:
+ """Forward pass of :class:`SqueezeKeyModule`.
+
+ Parameters
+ ----------
+ sample: dict[str, Any]
+ Dictionary to look for keys to squeeze.
+
+ Returns
+ -------
+ dict[str, Any]
+ Dictionary with squeezed specified keys.
+ """
+ for key in self.keys:
+ if key in sample:
+ sample[key] = sample[key].squeeze(self.dim)
+ return sample
+
+
+class RenameKeysModule(DirectModule):
+ """Rename keys from the sample if present."""
+
+ def __init__(self, old_keys: list[str], new_keys: list[str]) -> None:
+ """Inits :class:`RenameKeys`.
+
+ Parameters
+ ----------
+ old_keys: list[str]
+ Key(s) to rename.
+ new_keys: list[str]
+ Key(s) to replace old keys.
+ """
+ super().__init__()
+ self.old_keys = old_keys
+ self.new_keys = new_keys
+
+ def forward(self, sample: dict[str, Any]) -> dict[str, Any]:
+ """Forward pass of :class:`RenameKeys`.
+
+ Parameters
+ ----------
+ sample: dict[str, Any]
+ Dictionary to look for keys and rename them.
+
+ Returns
+ -------
+ dict[str, Any]
+ Dictionary with renamed specified keys.
+ """
+ for old_key, new_key in zip(self.old_keys, self.new_keys):
+ if old_key in sample:
+ sample[new_key] = sample.pop(old_key)
+
+ return sample
+
+
+class PadCoilDimensionModule(DirectModule):
+ """Pad the coils by zeros to a given number of coils.
+
+ Useful if you want to collate volumes with different numbers of coils.
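+
+ Example
+ -------
+ A minimal sketch with hypothetical shapes; 5-coil data is padded with zero-filled coils up to 8 so it
+ can be collated with 8-coil volumes::
+
+ import torch
+ sample = {"masked_kspace": torch.randn(1, 5, 320, 320, 2)} # (batch, coil, height, width, complex=2)
+ sample = PadCoilDimensionModule(pad_coils=8).forward(sample)
+ # sample["masked_kspace"] now has shape (1, 8, 320, 320, 2)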
+ """
+
+ def __init__(
+ self,
+ pad_coils: Optional[int] = None,
+ key: str = "masked_kspace",
+ coil_dim: int = 1,
+ ) -> None:
+ """Inits :class:`PadCoilDimensionModule`.
+
+ Parameters
+ ----------
+ pad_coils: int, optional
+ Number of coils to pad to. Default: None.
+ key: str
+ Key to pad in sample. Default: "masked_kspace".
+ coil_dim: int
+ Coil dimension along which the pad will be done. Default: 1.
+ """
+ super().__init__()
+ self.num_coils = pad_coils
+ self.key = key
+ self.coil_dim = coil_dim
+
+ def forward(self, sample: dict[str, Any]) -> dict[str, Any]:
+ """Forward pass of :class:`PadCoilDimensionModule`.
+
+ Parameters
+ ----------
+ sample: dict[str, Any]
+ Dictionary with key `self.key`.
+
+ Returns
+ -------
+ sample: dict[str, Any]
+ Dictionary with padded coils of sample[self.key] if self.num_coils is not None.
+ """
+ if not self.num_coils:
+ return sample
+
+ if self.key not in sample:
+ return sample
+
+ data = sample[self.key]
+
+ curr_num_coils = data.shape[self.coil_dim]
+ if curr_num_coils > self.num_coils:
+ raise ValueError(
+ f"Tried to pad to {self.num_coils} coils, but already have {curr_num_coils} for "
+ f"{sample['filename']}."
+ )
+ if curr_num_coils == self.num_coils:
+ return sample
+
+ shape = data.shape
+ num_coils = shape[self.coil_dim]
+ padding_data_shape = list(shape).copy()
+ padding_data_shape[self.coil_dim] = max(self.num_coils - num_coils, 0)
+ zeros = torch.zeros(padding_data_shape, dtype=data.dtype, device=data.device)
+ sample[self.key] = torch.cat([zeros, data], dim=self.coil_dim)
+
+ return sample
+
+
+class ComputeScalingFactorModule(DirectModule):
+ """Calculates scaling factor.
+
+ The scaling factor for the input data is based either on a percentile or on the maximum of the data
+ under `normalize_key`.
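+
+ Example
+ -------
+ A minimal sketch with hypothetical shapes; the module expects a leading batch dimension and produces
+ one scaling factor per batch element::
+
+ import torch
+ sample = {TransformKey.MASKED_KSPACE: torch.randn(1, 8, 320, 320, 2)}
+ sample = ComputeScalingFactorModule().forward(sample)
+ # sample[TransformKey.SCALING_FACTOR] is a tensor of shape (1,)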
+ """
+
+ def __init__(
+ self,
+ normalize_key: Union[None, TransformKey] = TransformKey.MASKED_KSPACE,
+ percentile: Union[None, float] = 0.99,
+ scaling_factor_key: TransformKey = TransformKey.SCALING_FACTOR,
+ ) -> None:
+ """Inits :class:`ComputeScalingFactorModule`.
+
+ Parameters
+ ----------
+ normalize_key : TransformKey or None
+ Key name to compute the data for. If the maximum has to be computed on the ACS, ensure the reconstruction
+ on the ACS is available (typically `body_coil_image`). Default: "masked_kspace".
+ percentile : float or None
+ Rescale data with the given percentile. If None, the division is done by the maximum. Default: 0.99.
+ scaling_factor_key : TransformKey
+ Name of how the scaling factor will be stored. Default: "scaling_factor".
+ """
+ super().__init__()
+ self.normalize_key = normalize_key
+ self.percentile = percentile
+ self.scaling_factor_key = scaling_factor_key
+
+ def forward(self, sample: dict[str, Any]) -> dict[str, Any]:
+ """Forward pass of :class:`ComputeScalingFactorModule`.
+
+ Parameters
+ ----------
+ sample: dict[str, Any]
+ Sample with key `normalize_key` to compute scaling_factor.
+
+ Returns
+ -------
+ sample: dict[str, Any]
+ Sample with key `scaling_factor_key`.
+ """
+ if self.normalize_key == "scaling_factor": # This is a real-valued given number
+ scaling_factor = sample["scaling_factor"]
+ elif not self.normalize_key:
+ kspace = sample["masked_kspace"]
+ scaling_factor = torch.tensor([1.0] * kspace.size(0), device=kspace.device, dtype=kspace.dtype)
+ else:
+ data = sample[self.normalize_key]
+ scaling_factor: Union[list, torch.Tensor] = []
+ # Compute the maximum and scale the input
+ if self.percentile:
+ for _ in range(data.size(0)):
+ # Used in case the k-space is padded (e.g. for batches)
+ non_padded_coil_data = data[_][data[_].sum(dim=tuple(range(1, data[_].ndim))).bool()]
+ tview = -1.0 * T.modulus(non_padded_coil_data).view(-1)
+ s, _ = torch.kthvalue(tview, int((1 - self.percentile) * tview.size()[0]) + 1)
+ scaling_factor += [-1.0 * s]
+ scaling_factor = torch.tensor(scaling_factor, dtype=data.dtype, device=data.device)
+ else:
+ scaling_factor = T.modulus(data).amax(dim=list(range(data.ndim))[1:-1])
+ sample[self.scaling_factor_key] = scaling_factor
+ return sample
+
+
+class NormalizeModule(DirectModule):
+ """Normalize the input data."""
+
+ def __init__(
+ self,
+ scaling_factor_key: TransformKey = TransformKey.SCALING_FACTOR,
+ keys_to_normalize: Optional[list[TransformKey]] = None,
+ ) -> None:
+ """Inits :class:`NormalizeModule`.
+
+ Parameters
+ ----------
+ scaling_factor_key : TransformKey
+ Name of scaling factor key expected in sample. Default: 'scaling_factor'.
+ keys_to_normalize : list[TransformKey], optional
+ Keys to normalize. If None, a default list of scalable keys is used. Default: None.
+ """
+ super().__init__()
+ self.scaling_factor_key = scaling_factor_key
+
+ self.keys_to_normalize = (
+ [
+ "masked_kspace",
+ "target",
+ "kspace",
+ "body_coil_image", # sensitivity_map does not require normalization.
+ "initial_image",
+ "initial_kspace",
+ ]
+ if keys_to_normalize is None
+ else keys_to_normalize
+ )
+
+ def forward(self, sample: dict[str, Any]) -> dict[str, Any]:
+ """Forward pass of :class:`NormalizeModule`.
+
+ Parameters
+ ----------
+ sample: dict[str, Any]
+ Sample to normalize.
+
+ Returns
+ -------
+ sample: dict[str, Any]
+ Sample with normalized values if their respective key is in `keys_to_normalize` and key
+ `scaling_factor_key` exists in sample.
+ """
+ scaling_factor = sample.get(self.scaling_factor_key, None)
+ # Normalize data
+ if scaling_factor is not None:
+ for key in sample.keys():
+ if key not in self.keys_to_normalize:
+ continue
+ sample[key] = T.safe_divide(
+ sample[key],
+ scaling_factor.reshape(-1, *[1 for _ in range(sample[key].ndim - 1)]),
+ )
+
+ sample["scaling_diff"] = 0.0
+ return sample
+
+
+class WhitenDataModule(DirectModule):
+ """Whitens complex data Module."""
+
+ def __init__(self, epsilon: float = 1e-10, key: str = "complex_image") -> None:
+ """Inits :class:`WhitenDataModule`.
+
+ Parameters
+ ----------
+ epsilon: float
+ Default: 1e-10.
+ key: str
+ Key to whiten. Default: "complex_image".
+ """
+ super().__init__()
+ self.epsilon = epsilon
+ self.key = key
+
+ def complex_whiten(self, complex_image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Whiten complex image.
+
+ Parameters
+ ----------
+ complex_image: torch.Tensor
+ Complex image tensor to whiten.
+
+ Returns
+ -------
+ mean, std, whitened_image: tuple[torch.Tensor, torch.Tensor, torch.Tensor]
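+
+ Example
+ -------
+ A minimal sketch with a hypothetical shape; after whitening, the real and imaginary channels are
+ decorrelated and scaled to roughly unit variance::
+
+ import torch
+ mean, std, white = WhitenDataModule().complex_whiten(torch.randn(320, 320, 2))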
+ """
+ # From: https://github.com/facebookresearch/fastMRI
+ # blob/da1528585061dfbe2e91ebbe99a5d4841a5c3f43/banding_removal/fastmri/data/transforms.py#L464 # noqa
+ real = complex_image[..., 0]
+ imag = complex_image[..., 1]
+
+ # Center around mean.
+ mean = complex_image.mean()
+ centered_complex_image = complex_image - mean
+
+ # Determine covariance between real and imaginary.
+ n_elements = real.nelement()
+ real_real = (real.mul(real).sum() - real.mean().mul(real.mean())) / n_elements
+ real_imag = (real.mul(imag).sum() - real.mean().mul(imag.mean())) / n_elements
+ imag_imag = (imag.mul(imag).sum() - imag.mean().mul(imag.mean())) / n_elements
+ eig_input = torch.Tensor([[real_real, real_imag], [real_imag, imag_imag]])
+
+ # Remove correlation by rotating around covariance eigenvectors.
+ eig_values, eig_vecs = torch.linalg.eig(eig_input)
+
+ # Scale by eigenvalues for unit variance.
+ std = (eig_values.real + self.epsilon).sqrt()
+ whitened_image = torch.matmul(centered_complex_image, eig_vecs.real) / std
+
+ return mean, std, whitened_image
+
+ def forward(self, sample: dict[str, Any]) -> dict[str, Any]:
+ """Forward pass of :class:`WhitenDataModule`.
+
+ Parameters
+ ----------
+ sample: dict[str, Any]
+ Sample with key `key`.
+
+ Returns
+ -------
+ sample: dict[str, Any]
+ Sample with value of `key` whitened.
+ """
+ _, _, whitened_image = self.complex_whiten(sample[self.key])
+ sample[self.key] = whitened_image
+ return sample
+
+
+class AddTargetAcceleration(DirectTransform):
+ """This will replace the acceleration factor in the sample with the target acceleration factor."""
+
+ def __init__(self, target_acceleration: float):
+ super().__init__()
+ self.target_acceleration = target_acceleration
+
+ def __call__(self, sample: dict[str, Any]):
+ sample["acceleration"][:] = self.target_acceleration
+ return sample
+
+
+class ModuleWrapper:
+ class SubWrapper:
+ def __init__(self, transform: Callable, toggle_dims: bool) -> None:
+ self.toggle_dims = toggle_dims
+ self._transform = transform
+
+ def __call__(self, sample: dict[str, Any]) -> dict[str, Any]:
+ if self.toggle_dims:
+ for k, v in sample.items():
+ if isinstance(v, (torch.Tensor, np.ndarray)):
+ sample[k] = v[None]
+ else:
+ sample[k] = [v]
+
+ sample = self._transform.forward(sample)
+
+ if self.toggle_dims:
+ for k, v in sample.items():
+ if isinstance(v, (torch.Tensor, np.ndarray)):
+ sample[k] = v.squeeze(0)
+ else:
+ sample[k] = v[0]
+
+ return sample
+
+ def __repr__(self) -> str:
+ return self._transform.__repr__()
+
+ def __init__(self, module: Callable, toggle_dims: bool) -> None:
+ self._module = module
+ self.toggle_dims = toggle_dims
+
+ def __call__(self, *args, **kwargs) -> SubWrapper:
+ return self.SubWrapper(self._module(*args, **kwargs), toggle_dims=self.toggle_dims)
+
+
+ApplyMask = ModuleWrapper(ApplyMaskModule, toggle_dims=False)
+ComputeImage = ModuleWrapper(ComputeImageModule, toggle_dims=True)
+EstimateSensitivityMap = ModuleWrapper(EstimateSensitivityMapModule, toggle_dims=True)
+CopyKeys = ModuleWrapper(CopyKeysModule, toggle_dims=False)
+DeleteKeys = ModuleWrapper(DeleteKeysModule, toggle_dims=False)
+RenameKeys = ModuleWrapper(RenameKeysModule, toggle_dims=False)
+IndexSelection = ModuleWrapper(IndexSelectionModule, toggle_dims=False)
+DropIndex = ModuleWrapper(DropIndexModule, toggle_dims=False)
+SqueezeKey = ModuleWrapper(SqueezeKeyModule, toggle_dims=False)
+CompressCoil = ModuleWrapper(CompressCoilModule, toggle_dims=True)
+PadCoilDimension = ModuleWrapper(PadCoilDimensionModule, toggle_dims=True)
+ComputeScalingFactor = ModuleWrapper(ComputeScalingFactorModule, toggle_dims=True)
+Normalize = ModuleWrapper(NormalizeModule, toggle_dims=False)
+WhitenData = ModuleWrapper(WhitenDataModule, toggle_dims=False)
+GaussianMaskSplitter = ModuleWrapper(GaussianMaskSplitterModule, toggle_dims=True)
+UniformMaskSplitter = ModuleWrapper(UniformMaskSplitterModule, toggle_dims=True)
+Displacement = ModuleWrapper(DisplacementModule, toggle_dims=True)
+RandomElasticDeformation = ModuleWrapper(RandomElasticDeformationModule, toggle_dims=True)
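+
+# A minimal usage sketch (hypothetical shapes): each wrapper above instantiates its underlying module and,
+# when `toggle_dims=True`, temporarily adds a leading batch dimension to every tensor/array in the sample
+# (and wraps other values in one-element lists) before calling `forward`, undoing this afterwards:
+#
+# compute_image = ComputeImage(
+# kspace_key=KspaceKey.MASKED_KSPACE,
+# target_key="target",
+# backward_operator=T.ifft2,
+# type_reconstruction=ReconstructionType.RSS,
+# )
+# sample = compute_image(sample) # the sample's tensors stay unbatched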
+
+
+class ToTensor(DirectTransform):
+ """Transforms all np.array-like values in sample to torch.tensors."""
+
+ def __call__(self, sample: dict[str, Any]) -> dict[str, Any]:
+ """Calls :class:`ToTensor`.
+
+ Parameters
+ ----------
+ sample: dict[str, Any]
+ Contains key 'kspace' with value a np.array of shape (coil, height, width) (2D)
+ or (coil, slice, height, width) (3D)
+
+ Returns
+ -------
+ sample: dict[str, Any]
+ Contains key 'kspace' with value a torch.Tensor of shape (coil, height, width) (2D)
+ or (coil, slice, height, width) (3D)
+ """
+
+ ndim = sample["kspace"].ndim - 1
+
+ if ndim not in [2, 3]:
+ raise ValueError(f"Can only cast 2D and 3D data (+coil) to tensor. Got {ndim}.")
+
+ # Shape: 2D: (coil, height, width, complex=2), 3D: (coil, slice, height, width, complex=2)
+ sample["kspace"] = T.to_tensor(sample["kspace"]).float()
+ # Sensitivity maps are not necessarily available in the dataset.
+ if "initial_kspace" in sample:
+ # Shape: 2D: (coil, height, width, complex=2), 3D: (coil, slice, height, width, complex=2)
+ sample["initial_kspace"] = T.to_tensor(sample["initial_kspace"]).float()
+ if "initial_image" in sample:
+ # Shape: 2D: (height, width), 3D: (slice, height, width)
+ sample["initial_image"] = T.to_tensor(sample["initial_image"]).float()
+
+ if "sensitivity_map" in sample:
+ # Shape: 2D: (coil, height, width, complex=2), 3D: (coil, slice, height, width, complex=2)
+ sample["sensitivity_map"] = T.to_tensor(sample["sensitivity_map"]).float()
+ if "target" in sample:
+ # Shape: 2D: (coil, height, width), 3D: (coil, slice, height, width)
+ sample["target"] = torch.from_numpy(sample["target"]).float()
+ if "sampling_mask" in sample:
+ sample["sampling_mask"] = torch.from_numpy(sample["sampling_mask"]).bool()
+ if "acs_mask" in sample:
+ sample["acs_mask"] = torch.from_numpy(sample["acs_mask"]).bool()
+ if "scaling_factor" in sample:
+ sample["scaling_factor"] = torch.tensor(sample["scaling_factor"]).float()
+ if "loglikelihood_scaling" in sample:
+ # Shape: (coil, )
+ sample["loglikelihood_scaling"] = torch.from_numpy(np.asarray(sample["loglikelihood_scaling"])).float()
+
+ return sample
+
+
+class RegistrationSimulateReferenceType(DirectEnum):
+ FROM_KEY = "from_key"
+ ELASTIC = "elastic"
+
+
+# pylint: disable=too-many-arguments
+def build_supervised_mri_transforms(
+ forward_operator: Callable,
+ backward_operator: Callable,
+ mask_func: Optional[Callable],
+ target_acceleration: Optional[float] = None,
+ crop: Optional[Union[tuple[int, int], str]] = None,
+ crop_type: Optional[str] = "uniform",
+ rescale: Optional[Union[tuple[int, int], list[int]]] = None,
+ rescale_mode: Optional[RescaleMode] = RescaleMode.NEAREST,
+ rescale_2d_if_3d: Optional[bool] = False,
+ pad: Optional[Union[tuple[int, int], list[int]]] = None,
+ image_center_crop: bool = True,
+ random_rotation_degrees: Optional[Sequence[int]] = (-90, 90),
+ random_rotation_probability: float = 0.0,
+ random_flip_type: Optional[RandomFlipType] = RandomFlipType.RANDOM,
+ random_flip_probability: float = 0.0,
+ random_reverse_probability: float = 0.0,
+ padding_eps: float = 0.0001,
+ estimate_body_coil_image: bool = False,
+ estimate_sensitivity_maps: bool = True,
+ sensitivity_maps_type: SensitivityMapType = SensitivityMapType.RSS_ESTIMATE,
+ sensitivity_maps_gaussian: Optional[float] = None,
+ sensitivity_maps_espirit_threshold: Optional[float] = 0.05,
+ sensitivity_maps_espirit_kernel_size: Optional[int] = 6,
+ sensitivity_maps_espirit_crop: Optional[float] = 0.95,
+ sensitivity_maps_espirit_max_iters: Optional[int] = 30,
+ use_acs_as_mask: bool = False,
+ delete_acs: bool = True,
+ delete_kspace: bool = True,
+ image_recon_type: ReconstructionType = ReconstructionType.RSS,
+ compress_coils: Optional[int] = None,
+ pad_coils: Optional[int] = None,
+ scaling_key: TransformKey = TransformKey.MASKED_KSPACE,
+ scale_percentile: Optional[float] = 0.99,
+ registration: bool = False,
+ registration_simulate_reference: Optional[RegistrationSimulateReferenceType] = None,
+ registration_simulate_elastic_sigma: float = 3.0,
+ registration_simulate_elastic_points: int = 3,
+ registration_simulate_elastic_rotate: float = 0.0,
+ registration_simulate_elastic_zoom: float = 0.0,
+ registration_estimate_displacement: bool = True,
+ registration_simulate_reference_from_key_index: int = 0,
+ registration_moving_key: TransformKey = TransformKey.TARGET,
+ demons_filter_type: DemonsFilterType = DemonsFilterType.SYMMETRIC_FORCES,
+ demons_num_iterations: int = 100,
+ demons_smooth_displacement_field: bool = True,
+ demons_standard_deviations: float = 1.5,
+ demons_intensity_difference_threshold: Optional[float] = None,
+ demons_maximum_rms_error: Optional[float] = None,
+ use_seed: bool = True,
+) -> DirectTransform:
+ r"""Builds supervised MRI transforms.
+
+ More specifically, the following transformations are applied:
+
+ * Converts input to (complex-valued) tensor.
+ * Applies k-space (center) crop if requested.
+ * Applies k-space rescaling if requested.
+ * Applies k-space padding if requested.
+ * Applies random augmentations (rotation, flip, reverse) if requested.
+ * Adds a sampling mask if `mask_func` is defined.
+ * Compresses the coil dimension if requested.
+ * Pads the coil dimension if requested.
+ * Adds coil sensitivity maps and / or the body coil image.
+ * Masks the fully sampled k-space, if there is a mask function or a mask in the sample.
+ * Computes a scaling factor based on the masked k-space and normalizes data.
+ * Computes a target (image).
+ * Deletes the acs mask and the fully sampled k-space if requested.
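+
+ A minimal invocation sketch (``my_mask_func`` is a hypothetical mask callable; most arguments keep
+ their defaults)::
+
+ transforms = build_supervised_mri_transforms(
+ forward_operator=T.fft2,
+ backward_operator=T.ifft2,
+ mask_func=my_mask_func,
+ crop=(320, 320),
+ )
+ sample = transforms(sample) # behaves as a single composed transform on a data sample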
+
+ Parameters
+ ----------
+ forward_operator : Callable
+ The forward operator, e.g. some form of FFT (centered or uncentered).
+ backward_operator : Callable
+ The backward operator, e.g. some form of inverse FFT (centered or uncentered).
+ mask_func : Callable or None
+ A function which creates a sampling mask of the appropriate shape.
+ target_acceleration : float, optional
+ Target acceleration factor. Default: None.
+ crop : tuple[int, int] or str, Optional
+ If not None, this will transform the "kspace" to an image domain, crop it, and transform it back.
+ If a tuple of integers is given then it will crop the backprojected kspace to that size. If
+ "reconstruction_size" is given, then it will crop the backprojected kspace according to it, but
+ a key "reconstruction_size" must be present in the sample. Default: None.
+ crop_type : Optional[str]
+ Type of cropping, either "gaussian" or "uniform". This will be ignored if `crop` is None. Default: "uniform".
+ rescale : tuple or list, optional
+ If not None, this will transform the "kspace" to the image domain, rescale it, and transform it back.
+ Must correspond to (height, width). It is not recommended to be used in combination with `crop`.
+ Default: None.
+ rescale_mode : RescaleMode
+ Mode to be used for rescaling. Can be RescaleMode.AREA, RescaleMode.BICUBIC, RescaleMode.BILINEAR,
+ RescaleMode.NEAREST, RescaleMode.NEAREST_EXACT, or RescaleMode.TRILINEAR. Note that not all modes are
+ supported for 2D or 3D data. Default: RescaleMode.NEAREST.
+ rescale_2d_if_3d : bool, optional
+ If True and k-space data is 3D, rescaling will be done only on the height
+ and width dimensions, by combining the slice/time dimension with the batch dimension.
+ This is ignored if `rescale` is None. Default: False.
+ pad : tuple or list, optional
+ If not None, this will zero-pad the "kspace" to the given size. Must correspond to (height, width)
+ or (slice/time, height, width). Default: None.
+ image_center_crop : bool
+ If True the backprojected kspace will be cropped around the center, otherwise randomly.
+ This will be ignored if `crop` is None. Default: True.
+ random_rotation_degrees : Sequence[int], optional
+ Default: (-90, 90).
+ random_rotation_probability : float, optional
+        If greater than 0.0, random rotations of `random_rotation_degrees` degrees will be applied with probability
+        `random_rotation_probability`. Default: 0.0.
+ random_flip_type : RandomFlipType, optional
+ Default: RandomFlipType.RANDOM.
+ random_flip_probability : float, optional
+        If greater than 0.0, random flips of `random_flip_type` type will be applied with probability `random_flip_probability`.
+ Default: 0.0.
+ random_reverse_probability : float
+        If greater than 0.0, will perform a random reversal along the time or slice dimension (dim 2) with probability
+ `random_reverse_probability`. Default: 0.0.
+    padding_eps : float
+ Padding epsilon. Default: 0.0001.
+ estimate_body_coil_image : bool
+ Estimate body coil image. Default: False.
+ estimate_sensitivity_maps : bool
+ Estimate sensitivity maps using the acs region. Default: True.
+    sensitivity_maps_type : SensitivityMapType
+ Can be SensitivityMapType.RSS_ESTIMATE, SensitivityMapType.UNIT or SensitivityMapType.ESPIRIT.
+ Will be ignored if `estimate_sensitivity_maps` is False. Default: SensitivityMapType.RSS_ESTIMATE.
+ sensitivity_maps_gaussian : float
+ Optional sigma for gaussian weighting of sensitivity map.
+ sensitivity_maps_espirit_threshold : float, optional
+ Threshold for the calibration matrix when `type_of_map` is set to `SensitivityMapType.ESPIRIT`.
+ Default: 0.05.
+ sensitivity_maps_espirit_kernel_size : int, optional
+ Kernel size for the calibration matrix when `type_of_map` is set to `SensitivityMapType.ESPIRIT`. Default: 6.
+ sensitivity_maps_espirit_crop : float, optional
+ Output eigenvalue cropping threshold when `type_of_map` is set to `SensitivityMapType.ESPIRIT`. Default: 0.95.
+ sensitivity_maps_espirit_max_iters : int, optional
+ Power method iterations when `type_of_map` is set to `SensitivityMapType.ESPIRIT`. Default: 30.
+ use_acs_as_mask : bool
+ If True, will use the acs region as the mask. Default: False.
+ delete_acs : bool
+        If True will delete keys `acs_mask` and `acs_kspace`. Default: True.
+ delete_kspace : bool
+ If True will delete key `kspace` (fully sampled k-space). Default: True.
+ image_recon_type : ReconstructionType
+ Type to reconstruct target image. Default: ReconstructionType.RSS.
+ compress_coils : int, optional
+ Number of coils to compress input k-space. It is not recommended to be used in combination with `pad_coils`.
+ Default: None.
+ pad_coils : int
+        Number of coils to pad data to. Default: None.
+ scaling_key : TransformKey
+ Key in sample to scale scalable items in sample. Default: TransformKey.MASKED_KSPACE.
+ scale_percentile : float, optional
+ Data will be rescaled with the given percentile. If None, the division is done by the maximum. Default: 0.99.
+ registration : bool
+ If True, will compute a displacement field between the target and the moving image. Default: False.
+ registration_simulate_reference : RegistrationSimulateReferenceType
+ If not None, will simulate a reference image for displacement field computation. Otherwise, this expects a key
+ in the sample. Can be RegistrationSimulateReferenceType.FROM_KEY or RegistrationSimulateReferenceType.ELASTIC.
+ Default: None.
+ registration_simulate_elastic_sigma : float
+ Standard deviation for the elastic simulation. Default: 3.0.
+ registration_simulate_elastic_points : int
+ Number of points for the elastic simulation. Default: 3.
+ registration_simulate_elastic_rotate : float
+ Rotation for the elastic simulation. Default: 0.0.
+    registration_simulate_elastic_zoom : float
+        Zoom for the elastic simulation. Default: 0.0.
+    registration_estimate_displacement : bool
+        If True, will estimate the displacement field between the target and the moving image using the
+        demons algorithm. Default: True.
+    registration_simulate_reference_from_key_index : int
+        Index to drop from the key to simulate the reference image. Default: 0.
+    registration_moving_key : TransformKey
+        Key in sample to compute displacement field from. Default: TransformKey.TARGET.
+ demons_filter_type : DemonsFilterType
+ Type of filter to apply to the displacement field. Default: DemonsFilterType.SYMMETRIC_FORCES.
+ demons_num_iterations : int
+ Number of iterations for the demons algorithm. Default: 100.
+ demons_smooth_displacement_field : bool
+ If True, will smooth the displacement field. Default: True.
+ demons_standard_deviations : float
+ Standard deviation for the smoothing of the displacement field. Default: 1.5.
+ demons_intensity_difference_threshold : float, optional
+ Intensity difference threshold for the demons algorithm. Default: None.
+ demons_maximum_rms_error : float, optional
+ Maximum RMS error for the demons algorithm. Default: None.
+ use_seed : bool
+        If True, a pseudo-random number based on the filename is computed so that every slice of the volume gets
+ the same mask every time. Default: True.
+
+ Returns
+ -------
+ DirectTransform
+ An MRI transformation object.
+ """
+ mri_transforms: list[Callable] = [ToTensor()]
+ if crop:
+ mri_transforms += [
+ CropKspace(
+ crop=crop,
+ forward_operator=forward_operator,
+ backward_operator=backward_operator,
+ image_space_center_crop=image_center_crop,
+ random_crop_sampler_type=crop_type,
+ random_crop_sampler_use_seed=use_seed,
+ )
+ ]
+ if rescale:
+ mri_transforms += [
+ RescaleKspace(
+ shape=rescale,
+ forward_operator=forward_operator,
+ backward_operator=backward_operator,
+ rescale_mode=rescale_mode,
+ rescale_2d_if_3d=rescale_2d_if_3d,
+ kspace_key=KspaceKey.KSPACE,
+ )
+ ]
+ if pad:
+ mri_transforms += [
+ PadKspace(
+ pad_shape=pad,
+ forward_operator=forward_operator,
+ backward_operator=backward_operator,
+ kspace_key=KspaceKey.KSPACE,
+ )
+ ]
+ if random_rotation_probability > 0.0:
+ mri_transforms += [
+ RandomRotation(
+ degrees=random_rotation_degrees,
+ p=random_rotation_probability,
+ keys_to_rotate=(TransformKey.KSPACE, TransformKey.SENSITIVITY_MAP),
+ )
+ ]
+ if random_flip_probability > 0.0:
+ mri_transforms += [
+ RandomFlip(
+ flip=random_flip_type,
+ p=random_flip_probability,
+ keys_to_flip=(TransformKey.KSPACE, TransformKey.SENSITIVITY_MAP),
+ )
+ ]
+ if random_reverse_probability > 0.0:
+ mri_transforms += [
+ RandomReverse(
+ p=random_reverse_probability,
+ keys_to_reverse=(TransformKey.KSPACE, TransformKey.SENSITIVITY_MAP),
+ )
+ ]
+ if padding_eps > 0.0:
+ mri_transforms += [
+ ComputeZeroPadding(KspaceKey.KSPACE, "padding", padding_eps),
+ ApplyZeroPadding(KspaceKey.KSPACE, "padding"),
+ ]
+ if mask_func:
+ mri_transforms += [
+ CreateSamplingMask(
+ mask_func,
+ shape=(None if (isinstance(crop, str)) else crop),
+ use_seed=use_seed,
+ return_acs=estimate_sensitivity_maps,
+ ),
+ ]
+ if use_acs_as_mask:
+ mri_transforms += [CopyKeys(keys=[TransformKey.ACS_MASK], new_keys=[TransformKey.SAMPLING_MASK])]
+ if target_acceleration:
+ mri_transforms += [AddTargetAcceleration(target_acceleration)]
+ if compress_coils:
+ mri_transforms += [CompressCoil(num_coils=compress_coils, kspace_key=KspaceKey.KSPACE)]
+ if pad_coils:
+ mri_transforms += [PadCoilDimension(pad_coils=pad_coils, key=KspaceKey.KSPACE)]
+
+ if estimate_body_coil_image and mask_func is not None:
+ mri_transforms.append(EstimateBodyCoilImage(mask_func, backward_operator=backward_operator, use_seed=use_seed))
+ mri_transforms += [
+ ApplyMask(
+ sampling_mask_key=TransformKey.ACS_MASK,
+ input_kspace_key=KspaceKey.KSPACE,
+ target_kspace_key=KspaceKey.ACS_KSPACE,
+ ),
+ ]
+ if estimate_sensitivity_maps:
+ mri_transforms += [
+ EstimateSensitivityMap(
+ kspace_key=KspaceKey.ACS_KSPACE,
+ backward_operator=backward_operator,
+ type_of_map=sensitivity_maps_type,
+ gaussian_sigma=sensitivity_maps_gaussian,
+ espirit_threshold=sensitivity_maps_espirit_threshold,
+ espirit_kernel_size=sensitivity_maps_espirit_kernel_size,
+ espirit_crop=sensitivity_maps_espirit_crop,
+ espirit_max_iters=sensitivity_maps_espirit_max_iters,
+ )
+ ]
+ mri_transforms += [
+ ApplyMask(
+ sampling_mask_key=TransformKey.SAMPLING_MASK,
+ input_kspace_key=KspaceKey.KSPACE,
+ target_kspace_key=KspaceKey.MASKED_KSPACE,
+ ),
+ ]
+ if registration:
+ if registration_simulate_reference is not None:
+ mri_transforms += [
+ DropIndex(
+ keys=[
+ TransformKey.KSPACE,
+ TransformKey.ACS_KSPACE,
+ TransformKey.MASKED_KSPACE,
+ TransformKey.ACS_MASK,
+ TransformKey.SAMPLING_MASK,
+ TransformKey.PADDING,
+ TransformKey.SENSITIVITY_MAP,
+ TransformKey.ACCELERATION,
+ TransformKey.CENTER_FRACTION,
+ ],
+ index=registration_simulate_reference_from_key_index,
+ index_dim=1,
+ store_deleted_keys=[TransformKey.REFERENCE_KSPACE],
+ )
+ ]
+ mri_transforms += [
+ ComputeScalingFactor(
+ normalize_key=scaling_key, percentile=scale_percentile, scaling_factor_key=TransformKey.SCALING_FACTOR
+ ),
+ Normalize(
+ scaling_factor_key=TransformKey.SCALING_FACTOR,
+ keys_to_normalize=[
+ KspaceKey.ACS_KSPACE,
+ KspaceKey.KSPACE,
+ KspaceKey.MASKED_KSPACE,
+ KspaceKey.REFERENCE_KSPACE,
+            ],  # Normalize whichever of these k-space keys are present in the sample
+ ),
+ ]
+ mri_transforms += [
+ ComputeImage(
+ kspace_key=KspaceKey.KSPACE,
+ target_key=TransformKey.TARGET,
+ backward_operator=backward_operator,
+ type_reconstruction=image_recon_type,
+ )
+ ]
+ if registration:
+ if registration_simulate_reference is not None:
+ mri_transforms += [
+ ComputeImage(
+ kspace_key=KspaceKey.REFERENCE_KSPACE,
+ target_key=TransformKey.REFERENCE_IMAGE,
+ backward_operator=backward_operator,
+ type_reconstruction=image_recon_type,
+ ),
+ SqueezeKey(keys=[TransformKey.REFERENCE_IMAGE], dim=0),
+ ]
+ if registration_simulate_reference == RegistrationSimulateReferenceType.ELASTIC:
+ mri_transforms += [
+ RandomElasticDeformation(
+ image_key=TransformKey.REFERENCE_IMAGE,
+ target_key=TransformKey.REFERENCE_IMAGE,
+ use_seed=use_seed,
+ sigma=registration_simulate_elastic_sigma,
+ points=registration_simulate_elastic_points,
+ rotate=registration_simulate_elastic_rotate,
+ zoom=registration_simulate_elastic_zoom,
+ )
+ ]
+ if registration_estimate_displacement:
+ mri_transforms += [
+ Displacement(
+ transform_type=DisplacementTransformType.MULTISCALE_DEMONS,
+ demons_filter_type=demons_filter_type,
+ demons_num_iterations=demons_num_iterations,
+ demons_smooth_displacement_field=demons_smooth_displacement_field,
+ demons_standard_deviations=demons_standard_deviations,
+ demons_intensity_difference_threshold=demons_intensity_difference_threshold,
+ demons_maximum_rms_error=demons_maximum_rms_error,
+ reference_image_key=TransformKey.REFERENCE_IMAGE,
+ moving_image_key=registration_moving_key,
+ )
+ ]
+ if delete_acs:
+ mri_transforms += [DeleteKeys(keys=[TransformKey.ACS_MASK, KspaceKey.ACS_KSPACE])]
+ if delete_kspace:
+ mri_transforms += [DeleteKeys(keys=[KspaceKey.KSPACE])]
+
+ return Compose(mri_transforms)
+
+
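+# A minimal usage sketch (editorial addition, not from the original source). It assumes
+# centered FFT operators from `direct.data.transforms` and a masking function built via
+# `direct.common.subsample.build_masking_function`; the exact names and arguments here
+# are assumptions for illustration only.
+#
+#     from functools import partial
+#
+#     from direct.common.subsample import build_masking_function
+#     from direct.data.transforms import fft2, ifft2
+#
+#     mask_func = build_masking_function(
+#         name="FastMRIRandom", accelerations=(4,), center_fractions=(0.08,)
+#     )
+#     transforms = build_supervised_mri_transforms(
+#         forward_operator=partial(fft2, centered=True),
+#         backward_operator=partial(ifft2, centered=True),
+#         mask_func=mask_func,
+#         crop=(320, 320),
+#     )
+#     sample = transforms(sample)  # `sample` is a dict containing a "kspace" entry.
+
+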
+class TransformsType(DirectEnum):
+ SUPERVISED = "supervised"
+ SSL_SSDU = "ssl_ssdu"
+
+
+# pylint: disable=too-many-arguments
+def build_mri_transforms(
+ forward_operator: Callable,
+ backward_operator: Callable,
+ mask_func: Optional[Callable],
+ target_acceleration: Optional[float] = None,
+ crop: Optional[Union[tuple[int, int], str]] = None,
+ crop_type: Optional[str] = "uniform",
+ rescale: Optional[Union[tuple[int, int], list[int]]] = None,
+ rescale_mode: Optional[RescaleMode] = RescaleMode.NEAREST,
+ rescale_2d_if_3d: Optional[bool] = False,
+ pad: Optional[Union[tuple[int, int], list[int]]] = None,
+ image_center_crop: bool = True,
+ random_rotation_degrees: Optional[Sequence[int]] = (-90, 90),
+ random_rotation_probability: float = 0.0,
+ random_flip_type: Optional[RandomFlipType] = RandomFlipType.RANDOM,
+ random_flip_probability: float = 0.0,
+ random_reverse_probability: float = 0.0,
+ padding_eps: float = 0.0001,
+ estimate_body_coil_image: bool = False,
+ estimate_sensitivity_maps: bool = True,
+ sensitivity_maps_type: SensitivityMapType = SensitivityMapType.RSS_ESTIMATE,
+ sensitivity_maps_gaussian: Optional[float] = None,
+ sensitivity_maps_espirit_threshold: Optional[float] = 0.05,
+ sensitivity_maps_espirit_kernel_size: Optional[int] = 6,
+ sensitivity_maps_espirit_crop: Optional[float] = 0.95,
+ sensitivity_maps_espirit_max_iters: Optional[int] = 30,
+ use_acs_as_mask: bool = False,
+ delete_acs: bool = True,
+ delete_kspace: bool = True,
+ image_recon_type: ReconstructionType = ReconstructionType.RSS,
+ compress_coils: Optional[int] = None,
+ pad_coils: Optional[int] = None,
+ scaling_key: TransformKey = TransformKey.MASKED_KSPACE,
+ scale_percentile: Optional[float] = 0.99,
+ registration: bool = False,
+ registration_simulate_reference: Optional[RegistrationSimulateReferenceType] = None,
+ registration_simulate_elastic_sigma: float = 3.0,
+ registration_simulate_elastic_points: int = 3,
+ registration_simulate_elastic_rotate: float = 0.0,
+ registration_simulate_elastic_zoom: float = 0.0,
+ registration_estimate_displacement: bool = True,
+ registration_simulate_reference_from_key_index: int = 0,
+ registration_moving_key: TransformKey = TransformKey.TARGET,
+ demons_filter_type: DemonsFilterType = DemonsFilterType.SYMMETRIC_FORCES,
+ demons_num_iterations: int = 100,
+ demons_smooth_displacement_field: bool = True,
+ demons_standard_deviations: float = 1.5,
+ demons_intensity_difference_threshold: Optional[float] = None,
+ demons_maximum_rms_error: Optional[float] = None,
+ use_seed: bool = True,
+ transforms_type: Optional[TransformsType] = TransformsType.SUPERVISED,
+ mask_split_ratio: Union[float, list[float], tuple[float, ...]] = 0.4,
+ mask_split_acs_region: Union[list[int], tuple[int, int]] = (0, 0),
+ mask_split_keep_acs: Optional[bool] = False,
+ mask_split_type: MaskSplitterType = MaskSplitterType.GAUSSIAN,
+ mask_split_gaussian_std: float = 3.0,
+ mask_split_half_direction: HalfSplitType = HalfSplitType.VERTICAL,
+) -> DirectTransform:
+ r"""Build transforms for MRI.
+
+ More specifically, the following transformations are applied:
+
+ * Converts input to (complex-valued) tensor.
+ * Applies k-space (center) crop if requested.
+ * Applies k-space rescaling if requested.
+ * Applies k-space padding if requested.
+ * Applies random augmentations (rotation, flip, reverse) if requested.
+ * Adds a sampling mask if `mask_func` is defined.
+    * Compresses the coil dimension if requested.
+    * Pads the coil dimension if requested.
+    * Adds coil sensitivities and/or the body coil image.
+ * Masks the fully sampled k-space, if there is a mask function or a mask in the sample.
+ * Computes a scaling factor based on the masked k-space and normalizes data.
+ * Computes a target (image).
+ * Deletes the acs mask and the fully sampled k-space if requested.
+ * Splits the mask if requested for self-supervised learning.
+
+ Parameters
+ ----------
+ forward_operator : Callable
+ The forward operator, e.g. some form of FFT (centered or uncentered).
+ backward_operator : Callable
+ The backward operator, e.g. some form of inverse FFT (centered or uncentered).
+ mask_func : Callable or None
+ A function which creates a sampling mask of the appropriate shape.
+ target_acceleration : float, optional
+ Target acceleration factor. Default: None.
+    crop : tuple[int, int] or str, optional
+ If not None, this will transform the "kspace" to an image domain, crop it, and transform it back.
+ If a tuple of integers is given then it will crop the backprojected kspace to that size. If
+ "reconstruction_size" is given, then it will crop the backprojected kspace according to it, but
+ a key "reconstruction_size" must be present in the sample. Default: None.
+ crop_type : Optional[str]
+ Type of cropping, either "gaussian" or "uniform". This will be ignored if `crop` is None. Default: "uniform".
+ rescale : tuple or list, optional
+ If not None, this will transform the "kspace" to the image domain, rescale it, and transform it back.
+        Must correspond to (height, width). Using this in combination with `crop` is not recommended.
+        Default: None.
+ rescale_mode : RescaleMode
+ Mode to be used for rescaling. Can be RescaleMode.AREA, RescaleMode.BICUBIC, RescaleMode.BILINEAR,
+ RescaleMode.NEAREST, RescaleMode.NEAREST_EXACT, or RescaleMode.TRILINEAR. Note that not all modes are
+ supported for 2D or 3D data. Default: RescaleMode.NEAREST.
+ rescale_2d_if_3d : bool, optional
+ If True and k-space data is 3D, rescaling will be done only on the height
+ and width dimensions, by combining the slice/time dimension with the batch dimension.
+ This is ignored if `rescale` is None. Default: False.
+ pad : tuple or list, optional
+ If not None, this will zero-pad the "kspace" to the given size. Must correspond to (height, width)
+ or (slice/time, height, width). Default: None.
+ image_center_crop : bool
+ If True the backprojected kspace will be cropped around the center, otherwise randomly.
+ This will be ignored if `crop` is None. Default: True.
+ random_rotation_degrees : Sequence[int], optional
+ Default: (-90, 90).
+ random_rotation_probability : float, optional
+        If greater than 0.0, random rotations of `random_rotation_degrees` degrees will be applied with probability
+        `random_rotation_probability`. Default: 0.0.
+ random_flip_type : RandomFlipType, optional
+ Default: RandomFlipType.RANDOM.
+ random_flip_probability : float, optional
+        If greater than 0.0, random flips of `random_flip_type` type will be applied with probability `random_flip_probability`.
+ Default: 0.0.
+ random_reverse_probability : float
+        If greater than 0.0, will perform a random reversal along the time or slice dimension (dim 2) with probability
+ `random_reverse_probability`. Default: 0.0.
+    padding_eps : float
+ Padding epsilon. Default: 0.0001.
+ estimate_body_coil_image : bool
+ Estimate body coil image. Default: False.
+ estimate_sensitivity_maps : bool
+ Estimate sensitivity maps using the acs region. Default: True.
+    sensitivity_maps_type : SensitivityMapType
+ Can be SensitivityMapType.RSS_ESTIMATE, SensitivityMapType.UNIT or SensitivityMapType.ESPIRIT.
+ Will be ignored if `estimate_sensitivity_maps` is False. Default: SensitivityMapType.RSS_ESTIMATE.
+ sensitivity_maps_gaussian : float
+ Optional sigma for gaussian weighting of sensitivity map.
+ sensitivity_maps_espirit_threshold : float, optional
+ Threshold for the calibration matrix when `type_of_map` is set to `SensitivityMapType.ESPIRIT`.
+ Default: 0.05.
+ sensitivity_maps_espirit_kernel_size : int, optional
+ Kernel size for the calibration matrix when `type_of_map` is set to `SensitivityMapType.ESPIRIT`. Default: 6.
+ sensitivity_maps_espirit_crop : float, optional
+ Output eigenvalue cropping threshold when `type_of_map` is set to `SensitivityMapType.ESPIRIT`. Default: 0.95.
+ sensitivity_maps_espirit_max_iters : int, optional
+ Power method iterations when `type_of_map` is set to `SensitivityMapType.ESPIRIT`. Default: 30.
+ use_acs_as_mask : bool
+ If True, will use the acs region as the mask. Default: False.
+ delete_acs : bool
+        If True will delete keys `acs_mask` and `acs_kspace`. Default: True.
+ delete_kspace : bool
+ If True will delete key `kspace` (fully sampled k-space). Default: True.
+ image_recon_type : ReconstructionType
+ Type to reconstruct target image. Default: ReconstructionType.RSS.
+ compress_coils : int, optional
+ Number of coils to compress input k-space. It is not recommended to be used in combination with `pad_coils`.
+ Default: None.
+ pad_coils : int
+        Number of coils to pad data to. Default: None.
+ scaling_key : TransformKey
+ Key in sample to scale scalable items in sample. Default: TransformKey.MASKED_KSPACE.
+ scale_percentile : float, optional
+ Data will be rescaled with the given percentile. If None, the division is done by the maximum. Default: 0.99.
+ registration : bool
+ If True, will compute a displacement field between the target and the moving image. Default: False.
+ registration_simulate_reference : RegistrationSimulateReferenceType
+ If not None, will simulate a reference image for displacement field computation. Otherwise, this expects a key
+ in the sample. Can be RegistrationSimulateReferenceType.FROM_KEY or RegistrationSimulateReferenceType.ELASTIC.
+ Default: None.
+ registration_simulate_elastic_sigma : float
+ Standard deviation for the elastic simulation. Default: 3.0.
+ registration_simulate_elastic_points : int
+ Number of points for the elastic simulation. Default: 3.
+ registration_simulate_elastic_rotate : float
+ Rotation for the elastic simulation. Default: 0.0.
+ registration_simulate_elastic_zoom : float
+ Zoom for the elastic simulation. Default: 0.0.
+ registration_estimate_displacement : bool
+ If True, will estimate the displacement field between the target and the moving image using the
+        demons algorithm. Default: True.
+ registration_simulate_reference_from_key_index : int
+ Index to drop from the key to simulate the reference image. Default: 0.
+ registration_moving_key : TransformKey
+ Key in sample to compute displacement field from. Default: TransformKey.TARGET.
+ demons_filter_type : DemonsFilterType
+ Type of filter to apply to the displacement field. Default: DemonsFilterType.SYMMETRIC_FORCES.
+ demons_num_iterations : int
+ Number of iterations for the demons algorithm. Default: 100.
+ demons_smooth_displacement_field : bool
+ If True, will smooth the displacement field. Default: True.
+ demons_standard_deviations : float
+ Standard deviation for the smoothing of the displacement field. Default: 1.5.
+ demons_intensity_difference_threshold : float, optional
+ Intensity difference threshold for the demons algorithm. Default: None.
+ demons_maximum_rms_error : float, optional
+ Maximum RMS error for the demons algorithm. Default: None.
+ use_seed : bool
+        If True, a pseudo-random number based on the filename is computed so that every slice of the volume gets
+ the same mask every time. Default: True.
+ transforms_type : TransformsType, optional
+ Can be `TransformsType.SUPERVISED` for supervised learning transforms or `TransformsType.SSL_SSDU` for
+ self-supervised learning transforms. Default: `TransformsType.SUPERVISED`.
+ mask_split_ratio : Union[float, list[float], tuple[float, ...]]
+        The ratio(s) of the sampling mask splitting. If `transforms_type` is TransformsType.SUPERVISED, this is ignored.
+ mask_split_acs_region : Union[list[int], tuple[int, int]]
+ A rectangle for the acs region that will be used in the input mask. This applies only if `transforms_type` is
+        set to TransformsType.SSL_SSDU. Default: (0, 0).
+ mask_split_keep_acs : Optional[bool]
+ If True, acs region according to the "acs_mask" of the sample will be used in both mask splits.
+        This applies only if `transforms_type` is set to TransformsType.SSL_SSDU. Default: False.
+ mask_split_type : MaskSplitterType
+        How the sampling mask will be split. Can be MaskSplitterType.UNIFORM, MaskSplitterType.GAUSSIAN, or
+        MaskSplitterType.HALF. This applies only if `transforms_type` is set to TransformsType.SSL_SSDU.
+        Default: MaskSplitterType.GAUSSIAN.
+ mask_split_gaussian_std : float
+        Standard deviation of gaussian mask splitting. This applies only if `transforms_type` is
+        set to TransformsType.SSL_SSDU. Ignored if `mask_split_type` is not set to MaskSplitterType.GAUSSIAN.
+ Default: 3.0.
+ mask_split_half_direction : HalfSplitType
+ Split type if `mask_split_type` is `MaskSplitterType.HALF`. Can be `HalfSplitType.VERTICAL`,
+ `HalfSplitType.HORIZONTAL`, `HalfSplitType.DIAGONAL_LEFT` or `HalfSplitType.DIAGONAL_RIGHT`.
+        This applies only if `transforms_type` is set to `TransformsType.SSL_SSDU`. Ignored if `mask_split_type` is not
+ set to `MaskSplitterType.HALF`. Default: `HalfSplitType.VERTICAL`.
+
+ Returns
+ -------
+ DirectTransform
+ An MRI transformation object.
+ """
+ logger = logging.getLogger(build_mri_transforms.__name__)
+ logger.info("Creating %s MRI transforms.", transforms_type)
+
+ if crop and rescale:
+ logger.warning(
+ "Rescale and crop are both given. Rescale will be applied after cropping. This is not recommended."
+ )
+
+ if compress_coils and pad_coils:
+ logger.warning(
+ "Compress coils and pad coils are both given. Compress coils will be applied before padding. "
+ "This is not recommended."
+ )
+
+ mri_transforms = build_supervised_mri_transforms(
+ forward_operator=forward_operator,
+ backward_operator=backward_operator,
+ mask_func=mask_func,
+ target_acceleration=target_acceleration,
+ crop=crop,
+ crop_type=crop_type,
+ rescale=rescale,
+ rescale_mode=rescale_mode,
+ rescale_2d_if_3d=rescale_2d_if_3d,
+ pad=pad,
+ image_center_crop=image_center_crop,
+ random_rotation_degrees=random_rotation_degrees,
+ random_rotation_probability=random_rotation_probability,
+ random_flip_type=random_flip_type,
+ random_flip_probability=random_flip_probability,
+ random_reverse_probability=random_reverse_probability,
+ padding_eps=padding_eps,
+ estimate_sensitivity_maps=estimate_sensitivity_maps,
+ sensitivity_maps_type=sensitivity_maps_type,
+ estimate_body_coil_image=estimate_body_coil_image,
+ sensitivity_maps_gaussian=sensitivity_maps_gaussian,
+ sensitivity_maps_espirit_threshold=sensitivity_maps_espirit_threshold,
+ sensitivity_maps_espirit_kernel_size=sensitivity_maps_espirit_kernel_size,
+ sensitivity_maps_espirit_crop=sensitivity_maps_espirit_crop,
+ sensitivity_maps_espirit_max_iters=sensitivity_maps_espirit_max_iters,
+ use_acs_as_mask=use_acs_as_mask,
+ delete_acs=delete_acs if transforms_type == TransformsType.SUPERVISED else False,
+ delete_kspace=delete_kspace if transforms_type == TransformsType.SUPERVISED else False,
+ image_recon_type=image_recon_type,
+ compress_coils=compress_coils,
+ pad_coils=pad_coils,
+ scaling_key=scaling_key,
+ scale_percentile=scale_percentile,
+ registration=registration,
+ registration_simulate_reference=registration_simulate_reference,
+ registration_simulate_elastic_sigma=registration_simulate_elastic_sigma,
+ registration_simulate_elastic_points=registration_simulate_elastic_points,
+ registration_simulate_elastic_rotate=registration_simulate_elastic_rotate,
+ registration_simulate_elastic_zoom=registration_simulate_elastic_zoom,
+ registration_estimate_displacement=registration_estimate_displacement,
+ registration_simulate_reference_from_key_index=registration_simulate_reference_from_key_index,
+ registration_moving_key=registration_moving_key,
+ demons_filter_type=demons_filter_type,
+ demons_num_iterations=demons_num_iterations,
+ demons_smooth_displacement_field=demons_smooth_displacement_field,
+ demons_standard_deviations=demons_standard_deviations,
+ demons_intensity_difference_threshold=demons_intensity_difference_threshold,
+ demons_maximum_rms_error=demons_maximum_rms_error,
+ use_seed=use_seed,
+ ).transforms
+
+ mri_transforms += [AddBooleanKeysModule(["is_ssl"], [transforms_type != TransformsType.SUPERVISED])]
+
+ if transforms_type == TransformsType.SUPERVISED:
+ return Compose(mri_transforms)
+
+ mask_splitter_kwargs = {
+ "ratio": mask_split_ratio,
+ "acs_region": mask_split_acs_region,
+ "keep_acs": mask_split_keep_acs,
+ "use_seed": use_seed,
+ "kspace_key": KspaceKey.MASKED_KSPACE,
+ }
+    if mask_split_type == MaskSplitterType.GAUSSIAN:
+        mask_splitter = GaussianMaskSplitter(**mask_splitter_kwargs, std_scale=mask_split_gaussian_std)
+    elif mask_split_type == MaskSplitterType.UNIFORM:
+        mask_splitter = UniformMaskSplitter(**mask_splitter_kwargs)
+    else:
+        # MaskSplitterType.HALF does not use a splitting ratio.
+        mask_splitter = HalfMaskSplitterModule(
+            **{k: v for k, v in mask_splitter_kwargs.items() if k != "ratio"},
+            direction=mask_split_half_direction,
+        )
+    mri_transforms += [mask_splitter, DeleteKeys([TransformKey.ACS_MASK])]
+ mri_transforms += [
+ RenameKeys(
+ [
+ SSLTransformMaskPrefixes.INPUT_ + TransformKey.MASKED_KSPACE,
+ SSLTransformMaskPrefixes.TARGET_ + TransformKey.MASKED_KSPACE,
+ ],
+ ["input_kspace", "kspace"],
+ ),
+ DeleteKeys(["masked_kspace", "sampling_mask"]),
+ ] # Rename keys for SSL engine
+
+ mri_transforms += [
+ ComputeImage(
+ kspace_key=KspaceKey.KSPACE,
+ target_key=TransformKey.TARGET,
+ backward_operator=backward_operator,
+ type_reconstruction=image_recon_type,
+ )
+ ]
+
+ return Compose(mri_transforms)
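+
+
+# Sketch (editorial addition, not part of the original file): building the SSL_SSDU variant.
+# The operators and `mask_func` are assumed to be constructed as in the sketch above. With
+# `TransformsType.SSL_SSDU`, the sampling mask is split into input/target masks and the
+# sample keys are renamed to "input_kspace"/"kspace" for the SSL engine.
+#
+#     ssl_transforms = build_mri_transforms(
+#         forward_operator=partial(fft2, centered=True),
+#         backward_operator=partial(ifft2, centered=True),
+#         mask_func=mask_func,
+#         transforms_type=TransformsType.SSL_SSDU,
+#         mask_split_type=MaskSplitterType.GAUSSIAN,
+#         mask_split_ratio=0.4,
+#     )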
diff --git a/direct/nn/mri_models.py b/direct/nn/mri_models.py
index 26ec0286..e984f526 100644
--- a/direct/nn/mri_models.py
+++ b/direct/nn/mri_models.py
@@ -210,7 +210,7 @@ def _do_iteration(
return DoIterationOutput(
output_image=(
- (output_image, registered_image)
+ (output_image, registered_image, displacement_field)
if (self.ndim == 3 and "registration_model" in self.models)
else output_image
),
@@ -998,7 +998,7 @@ def reconstruct_volumes( # type: ignore
iteration_output = self._do_iteration(data, loss_fns=loss_fns, regularizer_fns=regularizer_fns)
output = iteration_output.output_image
if "registration_model" in self.models:
- output, registered_output = output
+ output, registered_output, displacement_field = output
sampling_mask = iteration_output.sampling_mask
if sampling_mask is not None:
@@ -1012,6 +1012,7 @@ def reconstruct_volumes( # type: ignore
resolution=resolution,
complex_axis=self._complex_dim,
)
+
if "registration_model" in self.models:
registered_output_abs = _process_output(
registered_output,
@@ -1019,6 +1020,10 @@ def reconstruct_volumes( # type: ignore
resolution=resolution,
complex_axis=self._complex_dim,
)
+ if resolution is not None:
+ output_df = T.center_crop(displacement_field, resolution)
+ else:
+ output_df = displacement_field
if add_target:
target_abs = _process_output(
@@ -1044,6 +1049,7 @@ def reconstruct_volumes( # type: ignore
curr_registration_volume = torch.zeros(
*(volume_size, *registered_output_abs.shape[1:]), dtype=registered_output_abs.dtype
)
+ curr_df_volume = torch.zeros(*(volume_size, *output_df.shape[1:]), dtype=output_df.dtype)
curr_mask = (
torch.zeros(*(volume_size, *sampling_mask.shape[1:]), dtype=sampling_mask.dtype)
@@ -1061,6 +1067,8 @@ def reconstruct_volumes( # type: ignore
curr_registration_volume[instance_counter : instance_counter + output_abs.shape[0], ...] = (
registered_output_abs.cpu()
)
+ curr_df_volume[instance_counter : instance_counter + output_abs.shape[0], ...] = output_df.cpu()
+
if sampling_mask is not None:
curr_mask[instance_counter : instance_counter + output_abs.shape[0], ...] = sampling_mask.cpu()
if add_target:
@@ -1088,7 +1096,7 @@ def reconstruct_volumes( # type: ignore
del data
if "registration_model" in self.models:
- curr_volume = (curr_volume, curr_registration_volume)
+ curr_volume = (curr_volume, curr_registration_volume, curr_df_volume)
if add_target and "registration_model" in self.models:
curr_target = (curr_target, curr_registration_target)
@@ -1123,9 +1131,10 @@ def reconstruct_and_evaluate( # type: ignore
):
volume, target, mask, volume_loss_dict, filename = output
if isinstance(volume, tuple):
- volume, registration_volume = volume
+ volume, registration_volume, displacement_field = volume
else:
registration_volume = None
+ displacement_field = None
if isinstance(target, tuple):
target, registration_target = target
else:
@@ -1176,7 +1185,11 @@ def reconstruct_and_evaluate( # type: ignore
inf_losses.append(volume_loss_dict)
out.append(
- (volume if registration_volume is None else (volume, registration_volume), mask, filename),
+ (
+ volume if registration_volume is None else (volume, registration_volume, displacement_field),
+ mask,
+ filename,
+ ),
)
# Average loss dict
@@ -1235,7 +1248,7 @@ def evaluate( # type: ignore
):
volume, target, mask, volume_loss_dict, filename = output
if isinstance(volume, tuple):
- volume, registration_volume = volume
+ volume, registration_volume, _ = volume
else:
registration_volume = None
if isinstance(target, tuple):
diff --git a/direct/nn/registration/config.py b/direct/nn/registration/config.py
index ae9f334c..36395a68 100644
--- a/direct/nn/registration/config.py
+++ b/direct/nn/registration/config.py
@@ -13,6 +13,8 @@
class RegistrationModelConfig(ModelConfig):
warp_num_integration_steps: int = 1
train_end_to_end: bool = False
+ # TODO: Needs to be defined outside of the config
+ reg_loss_factor: float = 1.0 # Regularization loss weight factor
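+    # Illustrative assumption (not from the source), e.g.:
+    #     config = RegistrationModelConfig(train_end_to_end=True, reg_loss_factor=0.1)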
@dataclass
diff --git a/direct/nn/vsharp/config.py b/direct/nn/vsharp/config.py
index 58e37d54..b18cbd2f 100644
--- a/direct/nn/vsharp/config.py
+++ b/direct/nn/vsharp/config.py
@@ -1,53 +1,56 @@
-# Copyright (c) DIRECT Contributors
-
-from __future__ import annotations
-
-from dataclasses import dataclass
-
-from direct.config.defaults import ModelConfig
-from direct.nn.types import ActivationType, InitType, ModelName
-
-
-@dataclass
-class VSharpNetConfig(ModelConfig):
- num_steps: int = 10
- num_steps_dc_gd: int = 8
- image_init: InitType = InitType.SENSE
- no_parameter_sharing: bool = True
- auxiliary_steps: int = 0
- image_model_architecture: ModelName = ModelName.UNET
- initializer_channels: tuple[int, ...] = (32, 32, 64, 64)
- initializer_dilations: tuple[int, ...] = (1, 1, 2, 4)
- initializer_multiscale: int = 1
- initializer_activation: ActivationType = ActivationType.PRELU
- image_resnet_hidden_channels: int = 128
- image_resnet_num_blocks: int = 15
- image_resnet_batchnorm: bool = True
- image_resnet_scale: float = 0.1
- image_unet_num_filters: int = 32
- image_unet_num_pool_layers: int = 4
- image_unet_dropout: float = 0.0
- image_didn_hidden_channels: int = 16
- image_didn_num_dubs: int = 6
- image_didn_num_convs_recon: int = 9
- image_conv_hidden_channels: int = 64
- image_conv_n_convs: int = 15
- image_conv_activation: str = ActivationType.RELU
- image_conv_batchnorm: bool = False
-
-
-@dataclass
-class VSharpNet3DConfig(ModelConfig):
- num_steps: int = 8
- num_steps_dc_gd: int = 6
- image_init: InitType = InitType.SENSE
- no_parameter_sharing: bool = True
- auxiliary_steps: int = -1
- initializer_channels: tuple[int, ...] = (32, 32, 64, 64)
- initializer_dilations: tuple[int, ...] = (1, 1, 2, 4)
- initializer_multiscale: int = 1
- initializer_activation: ActivationType = ActivationType.PRELU
- unet_num_filters: int = 32
- unet_num_pool_layers: int = 4
- unet_dropout: float = 0.0
- unet_norm: bool = False
+# Copyright (c) DIRECT Contributors
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+from direct.config.defaults import ModelConfig
+from direct.nn.types import ActivationType, InitType, ModelName
+from direct.nn.vsharp.vsharp import LagrangeMultipliersInitialization
+
+
+@dataclass
+class VSharpNetConfig(ModelConfig):
+ num_steps: int = 10
+ num_steps_dc_gd: int = 8
+ image_init: InitType = InitType.SENSE
+ no_parameter_sharing: bool = True
+ auxiliary_steps: int = 0
+ lagrange_initialization: LagrangeMultipliersInitialization = LagrangeMultipliersInitialization.LEARNED
+ image_model_architecture: ModelName = ModelName.UNET
+ initializer_channels: tuple[int, ...] = (32, 32, 64, 64)
+ initializer_dilations: tuple[int, ...] = (1, 1, 2, 4)
+ initializer_multiscale: int = 1
+ initializer_activation: ActivationType = ActivationType.PRELU
+ image_resnet_hidden_channels: int = 128
+ image_resnet_num_blocks: int = 15
+ image_resnet_batchnorm: bool = True
+ image_resnet_scale: float = 0.1
+ image_unet_num_filters: int = 32
+ image_unet_num_pool_layers: int = 4
+ image_unet_dropout: float = 0.0
+ image_didn_hidden_channels: int = 16
+ image_didn_num_dubs: int = 6
+ image_didn_num_convs_recon: int = 9
+ image_conv_hidden_channels: int = 64
+ image_conv_n_convs: int = 15
+ image_conv_activation: str = ActivationType.RELU
+ image_conv_batchnorm: bool = False
+
+
+@dataclass
+class VSharpNet3DConfig(ModelConfig):
+ num_steps: int = 8
+ num_steps_dc_gd: int = 6
+ image_init: InitType = InitType.SENSE
+ no_parameter_sharing: bool = True
+ auxiliary_steps: int = -1
+ lagrange_initialization: LagrangeMultipliersInitialization = LagrangeMultipliersInitialization.LEARNED
+ initializer_channels: tuple[int, ...] = (32, 32, 64, 64)
+ initializer_dilations: tuple[int, ...] = (1, 1, 2, 4)
+ initializer_multiscale: int = 1
+ initializer_activation: ActivationType = ActivationType.PRELU
+ unet_num_filters: int = 32
+ unet_num_pool_layers: int = 4
+ unet_dropout: float = 0.0
+ unet_norm: bool = False
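+
+
+# Editorial note: `lagrange_initialization` selects how the ADMM Lagrange multipliers u^0
+# are initialized (see LagrangeMultipliersInitialization in direct.nn.vsharp.vsharp).
+# Illustrative assumption, not from the source:
+#
+#     config = VSharpNetConfig(lagrange_initialization=LagrangeMultipliersInitialization.ZEROS)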
diff --git a/direct/nn/vsharp/vsharp.py b/direct/nn/vsharp/vsharp.py
index fb417f99..1f92b5bb 100644
--- a/direct/nn/vsharp/vsharp.py
+++ b/direct/nn/vsharp/vsharp.py
@@ -1,612 +1,640 @@
-# Copyright (c) DIRECT Contributors
-
-"""This module provides the implementation of vSHARP model.
-
-Most specifically, vSHARP is the variable Splitting Half-quadratic ADMM algorithm for Reconstruction
-of inverse-Problems (vSHARPP) model as presented in [1]_.
-
-
-References
-----------
-
-.. [1] George Yiasemis et. al. vSHARP: variable Splitting Half-quadratic ADMM algorithm for Reconstruction
- of inverse-Problems (2023). https://arxiv.org/abs/2309.09954.
-
-"""
-
-
-from __future__ import annotations
-
-from typing import Any, Callable
-
-import numpy as np
-import torch
-import torch.nn.functional as F
-from torch import nn
-
-from direct.constants import COMPLEX_SIZE
-from direct.data.transforms import apply_mask, expand_operator, reduce_operator
-from direct.nn.get_nn_model_config import ModelName, _get_model_config, _get_relu_activation
-from direct.nn.types import ActivationType, InitType
-from direct.nn.unet.unet_3d import NormUnetModel3d, UnetModel3d
-
-
-class LagrangeMultipliersInitializer(nn.Module):
- """A convolutional neural network model that initializers the Lagrange multiplier of the :class:`vSHARPNet` [1]_.
-
- More specifically, it produces an initial value for the Lagrange Multiplier based on the zero-filled image:
-
- .. math::
-
- u^0 = \mathcal{G}_{\psi}(x^0).
-
- References
- ----------
- .. [1] George Yiasemis et al., "VSHARP: Variable Splitting Half-quadratic ADMM Algorithm for Reconstruction
- of Inverse Problems" (2023). https://arxiv.org/abs/2309.09954.
- """
-
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- channels: tuple[int, ...],
- dilations: tuple[int, ...],
- multiscale_depth: int = 1,
- activation: ActivationType = ActivationType.PRELU,
- ) -> None:
- """Inits :class:`LagrangeMultipliersInitializer`.
-
- Parameters
- ----------
- in_channels : int
- Number of input channels.
- out_channels : int
- Number of output channels.
- channels : tuple of ints
- Tuple of integers specifying the number of output channels for each convolutional layer in the network.
- dilations : tuple of ints
- Tuple of integers specifying the dilation factor for each convolutional layer in the network.
- multiscale_depth : int
- Number of multiscale features to include in the output. Default: 1.
- activation : ActivationType
- Activation function to use on the output. Default: ActivationType.PRELU.
- """
- super().__init__()
-
- # Define convolutional blocks
- self.conv_blocks = nn.ModuleList()
- tch = in_channels
- for curr_channels, curr_dilations in zip(channels, dilations):
- block = nn.Sequential(
- nn.ReplicationPad2d(curr_dilations),
- nn.Conv2d(tch, curr_channels, 3, padding=0, dilation=curr_dilations),
- )
- tch = curr_channels
- self.conv_blocks.append(block)
-
- # Define output block
- tch = np.sum(channels[-multiscale_depth:])
- block = nn.Conv2d(tch, out_channels, 1, padding=0)
- self.out_block = nn.Sequential(block)
-
- self.multiscale_depth = multiscale_depth
-
- self.activation = _get_relu_activation(activation)
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- """Forward pass of :class:`LagrangeMultipliersInitializer`.
-
- Parameters
- ----------
- x : torch.Tensor
- Input tensor of shape (batch_size, in_channels, height, width).
-
- Returns
- -------
- torch.Tensor
- Output tensor of shape (batch_size, out_channels, height, width).
- """
-
- features = []
- for block in self.conv_blocks:
- x = F.relu(block(x), inplace=True)
- if self.multiscale_depth > 1:
- features.append(x)
-
- if self.multiscale_depth > 1:
- x = torch.cat(features[-self.multiscale_depth :], dim=1)
-
- return self.activation(self.out_block(x))
-
-
-class VSharpNet(nn.Module):
- """
- Variable Splitting Half-quadratic ADMM algorithm for Reconstruction of Parallel MRI [1]_.
-
- Variable Splitting Half Quadratic VSharpNet is a deep learning model that solves
- the augmented Lagrangian derivation of the variable half quadratic splitting problem
- using ADMM (Alternating Direction Method of Multipliers). It is specifically designed
- for solving inverse problems in magnetic resonance imaging (MRI).
-
- The VSharpNet model incorporates an iterative optimization algorithm that consists of
- three steps: z-step, x-step, and u-step. These steps are detailed mathematically as follows:
-
- .. math::
-
- z^{t+1} = \mathrm{argmin}_{z} \\lambda \mathcal{G}(z) + \\frac{\\rho}{2} || x^{t} - z +
- \\frac{u^t}{\\rho} ||_2^2 \\quad \mathrm{[z-step]}
-
- .. math::
-
- x^{t+1} = \mathrm{argmin}_{x} \\frac{1}{2} || \mathcal{A}_{\mathbf{U},\mathbf{S}}(x) - \\tilde{y} ||_2^2 +
- \\frac{\\rho}{2} || x - z^{t+1} + \\frac{u^t}{\\rho} ||_2^2 \\quad \mathrm{[x-step]}
-
- .. math::
-
- u^{t+1} = u^t + \\rho (x^{t+1} - z^{t+1}) \\quad \mathrm{[u-step]}
-
- During the z-step, the model minimizes the augmented Lagrangian function with respect to z, utilizing
- DL-based denoisers. In the x-step, it optimizes x by minimizing the data consistency term through
- unrolling a gradient descent scheme (DC-GD). The u-step involves updating the Lagrange multiplier u.
- These steps are iterated for a specified number of cycles.
-
- The model includes an initializer for Lagrange multipliers.
-
- It also allows for outputting auxiliary steps.
-
- :class:`VSharpNet` is tailored for 2D MRI data reconstruction.
-
- References
- ----------
- .. [1] George Yiasemis et al., "VSHARP: Variable Splitting Half-quadratic ADMM Algorithm for Reconstruction
- of Inverse Problems" (2023). https://arxiv.org/abs/2309.09954.
- """
-
- def __init__(
- self,
- forward_operator: Callable[[tuple[Any, ...]], torch.Tensor],
- backward_operator: Callable[[tuple[Any, ...]], torch.Tensor],
- num_steps: int,
- num_steps_dc_gd: int,
- image_init: InitType = InitType.SENSE,
- no_parameter_sharing: bool = True,
- image_model_architecture: ModelName = ModelName.UNET,
- initializer_channels: tuple[int, ...] = (32, 32, 64, 64),
- initializer_dilations: tuple[int, ...] = (1, 1, 2, 4),
- initializer_multiscale: int = 1,
- initializer_activation: ActivationType = ActivationType.PRELU,
- auxiliary_steps: int = 0,
- **kwargs,
- ) -> None:
- """Inits :class:`VSharpNet`.
-
- Parameters
- ----------
- forward_operator : Callable[[tuple[Any, ...]], torch.Tensor]
- Forward operator function.
- backward_operator : Callable[[tuple[Any, ...]], torch.Tensor]
- Backward operator function.
- num_steps : int
- Number of steps in the ADMM algorithm.
- num_steps_dc_gd : int
- Number of steps in the Data Consistency using Gradient Descent step of ADMM.
- image_init : str
- Image initialization method. Default: 'sense'.
- no_parameter_sharing : bool
- Flag indicating whether parameter sharing is enabled in the denoiser blocks.
- image_model_architecture : ModelName
- Image model architecture. Default: ModelName.UNET.
- initializer_channels : tuple[int, ...]
- Tuple of integers specifying the number of output channels for each convolutional layer in the
- Lagrange multiplier initializer. Default: (32, 32, 64, 64).
- initializer_dilations : tuple[int, ...]
- Tuple of integers specifying the dilation factor for each convolutional layer in the Lagrange multiplier
- initializer. Default: (1, 1, 2, 4).
- initializer_multiscale : int
- Number of multiscale features to include in the Lagrange multiplier initializer output. Default: 1.
- initializer_activation : ActivationType
- Activation type for the Lagrange multiplier initializer. Default: ActivationType.PRELU.
- auxiliary_steps : int
- Number of auxiliary steps to output. Can be -1 or a positive integer lower or equal to `num_steps`.
- If -1, it uses all steps. If I, the last I steps will be used.
- **kwargs: Additional keyword arguments.
- Can be `model_name` or `image_model_` where `` represent parameters of the selected
- image model architecture beyond the standard parameters.
- Depending on the `image_model_architecture` chosen, different kwargs will be applicable.
- """
- # pylint: disable=too-many-locals
- super().__init__()
- for extra_key in kwargs:
- if extra_key != "model_name" and not extra_key.startswith("image_"):
- raise ValueError(f"{type(self).__name__} got key `{extra_key}` which is not supported.")
- self.num_steps = num_steps
- self.num_steps_dc_gd = num_steps_dc_gd
-
- self.no_parameter_sharing = no_parameter_sharing
-
- image_model, image_model_kwargs = _get_model_config(
- image_model_architecture,
- in_channels=COMPLEX_SIZE * 3,
- out_channels=COMPLEX_SIZE,
- **{k.replace("image_", ""): v for (k, v) in kwargs.items() if "image_" in k},
- )
-
- self.denoiser_blocks = nn.ModuleList()
- for _ in range(num_steps if self.no_parameter_sharing else 1):
- self.denoiser_blocks.append(image_model(**image_model_kwargs))
-
- self.initializer = LagrangeMultipliersInitializer(
- in_channels=COMPLEX_SIZE,
- out_channels=COMPLEX_SIZE,
- channels=initializer_channels,
- dilations=initializer_dilations,
- multiscale_depth=initializer_multiscale,
- activation=initializer_activation,
- )
-
- self.learning_rate_eta = nn.Parameter(torch.ones(num_steps_dc_gd, requires_grad=True))
- nn.init.trunc_normal_(self.learning_rate_eta, 0.0, 1.0, 0.0)
-
- self.rho = nn.Parameter(torch.ones(num_steps, requires_grad=True))
- nn.init.trunc_normal_(self.rho, 0, 0.1, 0.0)
-
- self.forward_operator = forward_operator
- self.backward_operator = backward_operator
-
- if image_init not in [InitType.SENSE, InitType.ZERO_FILLED]:
- raise ValueError(
- f"Unknown image_initialization. Expected `InitType.SENSE` or `InitType.ZERO_FILLED`. "
- f"Got {image_init}."
- )
-
- self.image_init = image_init
-
- if not ((auxiliary_steps == -1) or (0 < auxiliary_steps <= num_steps)):
- raise ValueError(
- f"Number of auxiliary steps should be -1 to use all steps or a positive"
- f" integer <= than `num_steps`. Received {auxiliary_steps}."
- )
- if auxiliary_steps == -1:
- self.auxiliary_steps = list(range(num_steps))
- else:
- self.auxiliary_steps = list(range(num_steps - min(auxiliary_steps, num_steps), num_steps))
-
- self._coil_dim = 1
- self._complex_dim = -1
- self._spatial_dims = (2, 3)
-
- def forward(
- self,
- masked_kspace: torch.Tensor,
- sensitivity_map: torch.Tensor,
- sampling_mask: torch.Tensor,
- ) -> list[torch.Tensor]:
- """Computes forward pass of :class:`VSharpNet`.
-
- Parameters
- ----------
- masked_kspace: torch.Tensor
- Masked k-space of shape (N, coil, height, width, complex=2).
- sensitivity_map: torch.Tensor
- Sensitivity map of shape (N, coil, height, width, complex=2).
- sampling_mask: torch.Tensor
- Sampling mask of shape (N, 1, height, width, 1).
-
- Returns
- -------
- out : list of torch.Tensors
- List of output images of shape (N, height, width, complex=2).
- """
- out = []
- if self.image_init == InitType.SENSE:
- x = reduce_operator(
- coil_data=self.backward_operator(masked_kspace, dim=self._spatial_dims),
- sensitivity_map=sensitivity_map,
- dim=self._coil_dim,
- )
- else:
- x = self.backward_operator(masked_kspace, dim=self._spatial_dims).sum(self._coil_dim)
-
- z = x.clone()
-
- u = self.initializer(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
-
- for admm_step in range(self.num_steps):
- z = self.denoiser_blocks[admm_step if self.no_parameter_sharing else 0](
- torch.cat(
- [z, x, u / self.rho[admm_step]],
- dim=self._complex_dim,
- ).permute(0, 3, 1, 2)
- ).permute(0, 2, 3, 1)
-
- for dc_gd_step in range(self.num_steps_dc_gd):
- dc = apply_mask(
- self.forward_operator(expand_operator(x, sensitivity_map, self._coil_dim), dim=self._spatial_dims)
- - masked_kspace,
- sampling_mask,
- return_mask=False,
- )
- dc = self.backward_operator(dc, dim=self._spatial_dims)
- dc = reduce_operator(dc, sensitivity_map, self._coil_dim)
-
- x = x - self.learning_rate_eta[dc_gd_step] * (dc + self.rho[admm_step] * (x - z) + u)
-
- if admm_step in self.auxiliary_steps:
- out.append(x)
-
- u = u + self.rho[admm_step] * (x - z)
-
- return out
-
-
-class LagrangeMultipliersInitializer3D(torch.nn.Module):
- """A convolutional neural network model that initializes the Lagrange multiplier of :class:`VSharpNet3D`.
-
- This is an extension to 3D data of :class:`LagrangeMultipliersInitializer`.
- """
-
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- channels: tuple[int, ...],
- dilations: tuple[int, ...],
- multiscale_depth: int = 1,
- activation: ActivationType = ActivationType.PRELU,
- ):
- """Initializes LagrangeMultipliersInitializer3D.
-
- Parameters
- ----------
- in_channels : int
- Number of input channels.
- out_channels : int
- Number of output channels.
- channels : tuple of ints
- Tuple of integers specifying the number of output channels for each convolutional layer in the network.
- dilations : tuple of ints
- Tuple of integers specifying the dilation factor for each convolutional layer in the network.
- multiscale_depth : int
- Number of multiscale features to include in the output. Default: 1.
- activation : ActivationType
- Activation function to use on the output. Default: ActivationType.PRELU.
- """
- super().__init__()
-
- # Define convolutional blocks
- self.conv_blocks = nn.ModuleList()
- tch = in_channels
- for curr_channels, curr_dilations in zip(channels, dilations):
- block = nn.Sequential(
- nn.ReplicationPad3d(curr_dilations),
- nn.Conv3d(tch, curr_channels, 3, padding=0, dilation=curr_dilations),
- )
- tch = curr_channels
- self.conv_blocks.append(block)
-
- # Define output block
- tch = np.sum(channels[-multiscale_depth:])
- block = nn.Conv3d(tch, out_channels, 1, padding=0)
- self.out_block = nn.Sequential(block)
-
- self.multiscale_depth = multiscale_depth
- self.activation = _get_relu_activation(activation)
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- """Forward pass of :class:`LagrangeMultipliersInitializer3D`.
-
- Parameters
- ----------
- x : torch.Tensor
- Input tensor of shape (batch_size, in_channels, z, x, y).
-
- Returns
- -------
- torch.Tensor
- Output tensor of shape (batch_size, out_channels, z, x, y).
- """
-
- features = []
- for block in self.conv_blocks:
- x = F.relu(block(x), inplace=True)
- if self.multiscale_depth > 1:
- features.append(x)
-
- if self.multiscale_depth > 1:
- x = torch.cat(features[-self.multiscale_depth :], dim=1)
-
- return self.activation(self.out_block(x))
-
-
-class VSharpNet3D(nn.Module):
- """VharpNet 3D version using 3D U-Nets as denoisers.
-
- This is an extension to 3D of :class:`VSharpNet`. For the original paper refer to [1]_.
-
- References
- ----------
- .. [1] George Yiasemis et al., "VSHARP: Variable Splitting Half-quadratic ADMM Algorithm for Reconstruction
- of Inverse Problems" (2023). https://arxiv.org/abs/2309.09954.
- """
-
- def __init__(
- self,
- forward_operator: Callable[[tuple[Any, ...]], torch.Tensor],
- backward_operator: Callable[[tuple[Any, ...]], torch.Tensor],
- num_steps: int,
- num_steps_dc_gd: int,
- image_init: InitType = InitType.SENSE,
- no_parameter_sharing: bool = True,
- initializer_channels: tuple[int, ...] = (32, 32, 64, 64),
- initializer_dilations: tuple[int, ...] = (1, 1, 2, 4),
- initializer_multiscale: int = 1,
- initializer_activation: ActivationType = ActivationType.PRELU,
- auxiliary_steps: int = -1,
- unet_num_filters: int = 32,
- unet_num_pool_layers: int = 4,
- unet_dropout: float = 0.0,
- unet_norm: bool = False,
- **kwargs,
- ):
- """Inits :class:`VSharpNet3D`.
-
- Parameters
- ----------
- forward_operator : Callable[[tuple[Any, ...]], torch.Tensor]
- Forward operator function.
- backward_operator : Callable[[tuple[Any, ...]], torch.Tensor]
- Backward operator function.
- num_steps : int
- Number of steps in the ADMM algorithm.
- num_steps_dc_gd : int
- Number of steps in the Data Consistency using Gradient Descent step of ADMM.
- image_init : str
- Image initialization method. Default: 'sense'.
- no_parameter_sharing : bool
- Flag indicating whether parameter sharing is enabled in the denoiser blocks.
- initializer_channels : tuple[int, ...]
- Tuple of integers specifying the number of output channels for each convolutional layer in the
- Lagrange multiplier initializer. Default: (32, 32, 64, 64).
- initializer_dilations : tuple[int, ...]
- Tuple of integers specifying the dilation factor for each convolutional layer in the Lagrange multiplier
- initializer. Default: (1, 1, 2, 4).
- initializer_multiscale : int
- Number of multiscale features to include in the Lagrange multiplier initializer output. Default: 1.
- initializer_activation : ActivationType
- Activation type for the Lagrange multiplier initializer. Default: ActivationType.PReLU.
- auxiliary_steps : int
- Number of auxiliary steps to output. Can be -1 or a positive integer lower or equal to `num_steps`.
- If -1, it uses all steps. If I, the last I steps will be used.
- unet_num_filters : int
- U-Net denoisers number of output channels of the first convolutional layer. Default: 32.
- unet_num_pool_layers : int
- U-Net denoisers number of down-sampling and up-sampling layers (depth). Default: 4.
- unet_dropout : float
- U-Net denoisers dropout probability. Default: 0.0
- unet_norm : bool
- Whether to use normalized U-Net as denoiser or not. Default: False.
- **kwargs: Additional keyword arguments.
- Can be `model_name`.
- """
- # pylint: disable=too-many-locals
- super().__init__()
- for extra_key in kwargs:
- if extra_key != "model_name":
- raise ValueError(f"{type(self).__name__} got key `{extra_key}` which is not supported.")
- self.num_steps = num_steps
- self.num_steps_dc_gd = num_steps_dc_gd
-
- self.no_parameter_sharing = no_parameter_sharing
-
- self.denoiser_blocks = nn.ModuleList()
- for _ in range(num_steps if self.no_parameter_sharing else 1):
- self.denoiser_blocks.append(
- (UnetModel3d if not unet_norm else NormUnetModel3d)(
- in_channels=COMPLEX_SIZE * 3,
- out_channels=COMPLEX_SIZE,
- num_filters=unet_num_filters,
- num_pool_layers=unet_num_pool_layers,
- dropout_probability=unet_dropout,
- )
- )
-
- self.initializer = LagrangeMultipliersInitializer3D(
- in_channels=COMPLEX_SIZE,
- out_channels=COMPLEX_SIZE,
- channels=initializer_channels,
- dilations=initializer_dilations,
- multiscale_depth=initializer_multiscale,
- activation=initializer_activation,
- )
-
- self.learning_rate_eta = nn.Parameter(torch.ones(num_steps_dc_gd, requires_grad=True))
- nn.init.trunc_normal_(self.learning_rate_eta, 0.0, 1.0, 0.0)
-
- self.rho = nn.Parameter(torch.ones(num_steps, requires_grad=True))
- nn.init.trunc_normal_(self.rho, 0, 0.1, 0.0)
-
- self.forward_operator = forward_operator
- self.backward_operator = backward_operator
-
- if image_init not in ["sense", "zero_filled"]:
- raise ValueError(f"Unknown image_initialization. Expected 'sense' or 'zero_filled'. " f"Got {image_init}.")
-
- self.image_init = image_init
-
- if not (auxiliary_steps == -1 or 0 < auxiliary_steps <= num_steps):
- raise ValueError(
- f"Number of auxiliary steps should be -1 to use all steps or a positive"
- f" integer <= than `num_steps`. Received {auxiliary_steps}."
- )
- if auxiliary_steps == -1:
- self.auxiliary_steps = list(range(num_steps))
- else:
- self.auxiliary_steps = list(range(num_steps - min(auxiliary_steps, num_steps), num_steps))
-
- self._coil_dim = 1
- self._complex_dim = -1
- self._spatial_dims = (3, 4)
-
- def forward(
- self,
- masked_kspace: torch.Tensor,
- sensitivity_map: torch.Tensor,
- sampling_mask: torch.Tensor,
- ) -> list[torch.Tensor]:
- """Computes forward pass of :class:`VSharpNet3D`.
-
- Parameters
- ----------
- masked_kspace : torch.Tensor
- Masked k-space of shape (N, coil, slice, height, width, complex=2).
- sensitivity_map : torch.Tensor
- Sensitivity map of shape (N, coil, slice, height, width, complex=2).
- sampling_mask : torch.Tensor
- Sampling mask of shape (N, 1, 1 or slice, height, width, 1).
-
- Returns
- -------
- out : list of torch.Tensors
- List of output images each of shape (N, slice, height, width, complex=2).
- """
- out = []
- if self.image_init == InitType.SENSE:
- x = reduce_operator(
- coil_data=self.backward_operator(masked_kspace, dim=self._spatial_dims),
- sensitivity_map=sensitivity_map,
- dim=self._coil_dim,
- )
- else:
- x = self.backward_operator(masked_kspace, dim=self._spatial_dims).sum(self._coil_dim)
-
- z = x.clone()
-
- u = self.initializer(x.permute(0, 4, 1, 2, 3)).permute(0, 2, 3, 4, 1)
-
- for admm_step in range(self.num_steps):
- z = self.denoiser_blocks[admm_step if self.no_parameter_sharing else 0](
- torch.cat(
- [z, x, u / self.rho[admm_step]],
- dim=self._complex_dim,
- ).permute(0, 4, 1, 2, 3)
- ).permute(0, 2, 3, 4, 1)
-
- for dc_gd_step in range(self.num_steps_dc_gd):
- dc = apply_mask(
- self.forward_operator(expand_operator(x, sensitivity_map, self._coil_dim), dim=self._spatial_dims)
- - masked_kspace,
- sampling_mask,
- return_mask=False,
- )
- dc = self.backward_operator(dc, dim=self._spatial_dims)
- dc = reduce_operator(dc, sensitivity_map, self._coil_dim)
-
- x = x - self.learning_rate_eta[dc_gd_step] * (dc + self.rho[admm_step] * (x - z) + u)
-
- if admm_step in self.auxiliary_steps:
- out.append(x)
-
- u = u + self.rho[admm_step] * (x - z)
-
- return out
+# Copyright (c) DIRECT Contributors
+
+"""This module provides the implementation of vSHARP model.
+
+Most specifically, vSHARP is the variable Splitting Half-quadratic ADMM algorithm for Reconstruction
+of inverse-Problems (vSHARPP) model as presented in [1]_.
+
+
+References
+----------
+.. [1] Yiasemis, G., Moriakov, N., Sonke, J.-J., Teuwen, J.: vSHARP: Variable Splitting Half-quadratic ADMM algorithm
+ for reconstruction of inverse-problems. Magnetic Resonance Imaging. 110266 (2024).
+ https://doi.org/10.1016/j.mri.2024.110266.
+
+"""
+
+
+from __future__ import annotations
+
+from typing import Any, Callable
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from direct.constants import COMPLEX_SIZE
+from direct.data.transforms import apply_mask, expand_operator, reduce_operator
+from direct.nn.get_nn_model_config import ModelName, _get_model_config, _get_relu_activation
+from direct.nn.types import ActivationType, DirectEnum, InitType
+from direct.nn.unet.unet_3d import NormUnetModel3d, UnetModel3d
+
+
+class LagrangeMultipliersInitialization(DirectEnum):
+ LEARNED = "learned"
+ ZEROS = "zeros"
+
+
+class LagrangeMultipliersInitializer(nn.Module):
+ """A convolutional neural network model that initializers the Lagrange multiplier of the :class:`vSHARPNet` [1]_.
+
+ More specifically, it produces an initial value for the Lagrange Multiplier based on the zero-filled image:
+
+ .. math::
+
+ u^0 = \mathcal{G}_{\psi}(x^0).
+
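+    Examples
+    --------
+    A minimal shape sketch (hypothetical channel/dilation values, not from the paper):
+
+    .. code-block:: python
+
+        import torch
+
+        initializer = LagrangeMultipliersInitializer(
+            in_channels=2, out_channels=2, channels=(8, 8), dilations=(1, 2)
+        )
+        # Replication padding matches each dilation, so spatial size is preserved: (1, 2, 32, 32)
+        u0 = initializer(torch.rand(1, 2, 32, 32))
+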
+ References
+ ----------
+ .. [1] George Yiasemis et al., "VSHARP: Variable Splitting Half-quadratic ADMM Algorithm for Reconstruction
+ of Inverse Problems" (2023). https://arxiv.org/abs/2309.09954.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ channels: tuple[int, ...],
+ dilations: tuple[int, ...],
+ multiscale_depth: int = 1,
+ activation: ActivationType = ActivationType.PRELU,
+ ) -> None:
+ """Inits :class:`LagrangeMultipliersInitializer`.
+
+ Parameters
+ ----------
+ in_channels : int
+ Number of input channels.
+ out_channels : int
+ Number of output channels.
+ channels : tuple of ints
+ Tuple of integers specifying the number of output channels for each convolutional layer in the network.
+ dilations : tuple of ints
+ Tuple of integers specifying the dilation factor for each convolutional layer in the network.
+ multiscale_depth : int
+ Number of multiscale features to include in the output. Default: 1.
+ activation : ActivationType
+ Activation function to use on the output. Default: ActivationType.PRELU.
+ """
+ super().__init__()
+
+ # Define convolutional blocks
+ self.conv_blocks = nn.ModuleList()
+ tch = in_channels
+ for curr_channels, curr_dilations in zip(channels, dilations):
+ block = nn.Sequential(
+ nn.ReplicationPad2d(curr_dilations),
+ nn.Conv2d(tch, curr_channels, 3, padding=0, dilation=curr_dilations),
+ )
+ tch = curr_channels
+ self.conv_blocks.append(block)
+
+ # Define output block
+ tch = np.sum(channels[-multiscale_depth:])
+ block = nn.Conv2d(tch, out_channels, 1, padding=0)
+ self.out_block = nn.Sequential(block)
+
+ self.multiscale_depth = multiscale_depth
+
+ self.activation = _get_relu_activation(activation)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Forward pass of :class:`LagrangeMultipliersInitializer`.
+
+ Parameters
+ ----------
+ x : torch.Tensor
+ Input tensor of shape (batch_size, in_channels, height, width).
+
+ Returns
+ -------
+ torch.Tensor
+ Output tensor of shape (batch_size, out_channels, height, width).
+ """
+
+ features = []
+ for block in self.conv_blocks:
+ x = F.relu(block(x), inplace=True)
+ if self.multiscale_depth > 1:
+ features.append(x)
+
+ if self.multiscale_depth > 1:
+ x = torch.cat(features[-self.multiscale_depth :], dim=1)
+
+ return self.activation(self.out_block(x))
+
+
+class VSharpNet(nn.Module):
+ """
+ Variable Splitting Half-quadratic ADMM algorithm for Reconstruction of Parallel MRI [1]_.
+
+    Variable Splitting Half-quadratic VSharpNet is a deep learning model that solves
+    the augmented Lagrangian formulation of the variable half-quadratic splitting problem
+    using ADMM (Alternating Direction Method of Multipliers). It is specifically designed
+    for solving inverse problems in magnetic resonance imaging (MRI).
+
+ The VSharpNet model incorporates an iterative optimization algorithm that consists of
+ three steps: z-step, x-step, and u-step. These steps are detailed mathematically as follows:
+
+ .. math::
+
+ z^{t+1} = \mathrm{argmin}_{z} \\lambda \mathcal{G}(z) + \\frac{\\rho}{2} || x^{t} - z +
+ \\frac{u^t}{\\rho} ||_2^2 \\quad \mathrm{[z-step]}
+
+ .. math::
+
+ x^{t+1} = \mathrm{argmin}_{x} \\frac{1}{2} || \mathcal{A}_{\mathbf{U},\mathbf{S}}(x) - \\tilde{y} ||_2^2 +
+ \\frac{\\rho}{2} || x - z^{t+1} + \\frac{u^t}{\\rho} ||_2^2 \\quad \mathrm{[x-step]}
+
+ .. math::
+
+ u^{t+1} = u^t + \\rho (x^{t+1} - z^{t+1}) \\quad \mathrm{[u-step]}
+
+ During the z-step, the model minimizes the augmented Lagrangian function with respect to z, utilizing
+ DL-based denoisers. In the x-step, it optimizes x by minimizing the data consistency term through
+ unrolling a gradient descent scheme (DC-GD). The u-step involves updating the Lagrange multiplier u.
+ These steps are iterated for a specified number of cycles.
+
+ The model includes an initializer for Lagrange multipliers.
+
+ It also allows for outputting auxiliary steps.
+
+ :class:`VSharpNet` is tailored for 2D MRI data reconstruction.
+
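+    Examples
+    --------
+    A minimal usage sketch, assuming FFT operators following the ``direct.data.transforms``
+    conventions; shapes and step counts are illustrative:
+
+    .. code-block:: python
+
+        import torch
+        from direct.data.transforms import fft2, ifft2
+
+        model = VSharpNet(
+            forward_operator=fft2,
+            backward_operator=ifft2,
+            num_steps=4,
+            num_steps_dc_gd=2,
+            auxiliary_steps=-1,  # collect the output of every ADMM step
+        )
+        masked_kspace = torch.rand(1, 8, 64, 64, 2)
+        sensitivity_map = torch.rand(1, 8, 64, 64, 2)
+        sampling_mask = torch.rand(1, 1, 64, 64, 1) > 0.5
+        out = model(masked_kspace, sensitivity_map, sampling_mask)  # list of 4 images
+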
+ References
+ ----------
+ .. [1] George Yiasemis et al., "VSHARP: Variable Splitting Half-quadratic ADMM Algorithm for Reconstruction
+ of Inverse Problems" (2023). https://arxiv.org/abs/2309.09954.
+ """
+
+ def __init__(
+ self,
+ forward_operator: Callable[[tuple[Any, ...]], torch.Tensor],
+ backward_operator: Callable[[tuple[Any, ...]], torch.Tensor],
+ num_steps: int,
+ num_steps_dc_gd: int,
+ image_init: InitType = InitType.SENSE,
+ no_parameter_sharing: bool = True,
+ image_model_architecture: ModelName = ModelName.UNET,
+ initializer_channels: tuple[int, ...] = (32, 32, 64, 64),
+ initializer_dilations: tuple[int, ...] = (1, 1, 2, 4),
+ initializer_multiscale: int = 1,
+ initializer_activation: ActivationType = ActivationType.PRELU,
+ auxiliary_steps: int = 0,
+ lagrange_initialization: LagrangeMultipliersInitialization = LagrangeMultipliersInitialization.LEARNED,
+ **kwargs,
+ ) -> None:
+ """Inits :class:`VSharpNet`.
+
+ Parameters
+ ----------
+ forward_operator : Callable[[tuple[Any, ...]], torch.Tensor]
+ Forward operator function.
+ backward_operator : Callable[[tuple[Any, ...]], torch.Tensor]
+ Backward operator function.
+ num_steps : int
+ Number of steps in the ADMM algorithm.
+ num_steps_dc_gd : int
+ Number of steps in the Data Consistency using Gradient Descent step of ADMM.
+ image_init : str
+ Image initialization method. Default: 'sense'.
+ no_parameter_sharing : bool
+            If True (default), each ADMM step uses its own denoiser block; if False, a single
+            denoiser is shared across all steps.
+ image_model_architecture : ModelName
+ Image model architecture. Default: ModelName.UNET.
+ initializer_channels : tuple[int, ...]
+ Tuple of integers specifying the number of output channels for each convolutional layer in the
+ Lagrange multiplier initializer. Default: (32, 32, 64, 64).
+ initializer_dilations : tuple[int, ...]
+ Tuple of integers specifying the dilation factor for each convolutional layer in the Lagrange multiplier
+ initializer. Default: (1, 1, 2, 4).
+ initializer_multiscale : int
+ Number of multiscale features to include in the Lagrange multiplier initializer output. Default: 1.
+ initializer_activation : ActivationType
+ Activation type for the Lagrange multiplier initializer. Default: ActivationType.PRELU.
+ auxiliary_steps : int
+ Number of auxiliary steps to output. Can be -1 or a positive integer lower or equal to `num_steps`.
+            If -1, it uses all steps. If a positive integer `I`, the last `I` steps will be used.
+ lagrange_initialization : LagrangeMultipliersInitialization
+ Lagrange multiplier initialization method. Can be LagrangeMultipliersInitialization.LEARNED or
+ LagrangeMultipliersInitialization.ZEROS, corresponding to learned initialization or zero initialization.
+ Default: LagrangeMultipliersInitialization.LEARNED.
+ **kwargs: Additional keyword arguments.
+            Can be `model_name` or `image_model_<parameter>` where `<parameter>` represents parameters
+            of the selected image model architecture beyond the standard parameters.
+ Depending on the `image_model_architecture` chosen, different kwargs will be applicable.
+ """
+ # pylint: disable=too-many-locals
+ super().__init__()
+ for extra_key in kwargs:
+ if extra_key != "model_name" and not extra_key.startswith("image_"):
+ raise ValueError(f"{type(self).__name__} got key `{extra_key}` which is not supported.")
+ self.num_steps = num_steps
+ self.num_steps_dc_gd = num_steps_dc_gd
+
+ self.no_parameter_sharing = no_parameter_sharing
+
+ image_model, image_model_kwargs = _get_model_config(
+ image_model_architecture,
+ in_channels=COMPLEX_SIZE * 3,
+ out_channels=COMPLEX_SIZE,
+ **{k.replace("image_", ""): v for (k, v) in kwargs.items() if "image_" in k},
+ )
+
+ self.denoiser_blocks = nn.ModuleList()
+ for _ in range(num_steps if self.no_parameter_sharing else 1):
+ self.denoiser_blocks.append(image_model(**image_model_kwargs))
+
+ self.lagrange_initialization = lagrange_initialization
+ if lagrange_initialization == LagrangeMultipliersInitialization.LEARNED:
+ self.initializer = LagrangeMultipliersInitializer(
+ in_channels=COMPLEX_SIZE,
+ out_channels=COMPLEX_SIZE,
+ channels=initializer_channels,
+ dilations=initializer_dilations,
+ multiscale_depth=initializer_multiscale,
+ activation=initializer_activation,
+ )
+
+ self.learning_rate_eta = nn.Parameter(torch.ones(num_steps_dc_gd, requires_grad=True))
+ nn.init.trunc_normal_(self.learning_rate_eta, 0.0, 1.0, 0.0)
+
+ self.rho = nn.Parameter(torch.ones(num_steps, requires_grad=True))
+ nn.init.trunc_normal_(self.rho, 0, 0.1, 0.0)
+
+ self.forward_operator = forward_operator
+ self.backward_operator = backward_operator
+
+ if image_init not in [InitType.SENSE, InitType.ZERO_FILLED]:
+ raise ValueError(
+ f"Unknown image_initialization. Expected `InitType.SENSE` or `InitType.ZERO_FILLED`. "
+ f"Got {image_init}."
+ )
+
+ self.image_init = image_init
+
+ if not ((auxiliary_steps == -1) or (0 < auxiliary_steps <= num_steps)):
+ raise ValueError(
+ f"Number of auxiliary steps should be -1 to use all steps or a positive"
+ f" integer <= than `num_steps`. Received {auxiliary_steps}."
+ )
+ if auxiliary_steps == -1:
+ self.auxiliary_steps = list(range(num_steps))
+ else:
+ self.auxiliary_steps = list(range(num_steps - min(auxiliary_steps, num_steps), num_steps))
+
+ self._coil_dim = 1
+ self._complex_dim = -1
+ self._spatial_dims = (2, 3)
+
+ def forward(
+ self,
+ masked_kspace: torch.Tensor,
+ sensitivity_map: torch.Tensor,
+ sampling_mask: torch.Tensor,
+ ) -> list[torch.Tensor]:
+ """Computes forward pass of :class:`VSharpNet`.
+
+ Parameters
+ ----------
+ masked_kspace: torch.Tensor
+ Masked k-space of shape (N, coil, height, width, complex=2).
+ sensitivity_map: torch.Tensor
+ Sensitivity map of shape (N, coil, height, width, complex=2).
+ sampling_mask: torch.Tensor
+ Sampling mask of shape (N, 1, height, width, 1).
+
+ Returns
+ -------
+ out : list of torch.Tensors
+ List of output images of shape (N, height, width, complex=2).
+ """
+ out = []
+ if self.image_init == InitType.SENSE:
+ x = reduce_operator(
+ coil_data=self.backward_operator(masked_kspace, dim=self._spatial_dims),
+ sensitivity_map=sensitivity_map,
+ dim=self._coil_dim,
+ )
+ else:
+ x = self.backward_operator(masked_kspace, dim=self._spatial_dims).sum(self._coil_dim)
+
+ z = x.clone()
+
+ if self.lagrange_initialization == LagrangeMultipliersInitialization.LEARNED:
+ u = self.initializer(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
+ else:
+            u = torch.zeros_like(x)
+
+ for admm_step in range(self.num_steps):
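+            # [z-step] Denoise: refine z from the current image x and the scaled multiplier u / rho.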
+ z = self.denoiser_blocks[admm_step if self.no_parameter_sharing else 0](
+ torch.cat(
+ [z, x, u / self.rho[admm_step]],
+ dim=self._complex_dim,
+ ).permute(0, 3, 1, 2)
+ ).permute(0, 2, 3, 1)
+
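+            # [x-step] Data consistency: unrolled gradient descent on the augmented Lagrangian.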
+ for dc_gd_step in range(self.num_steps_dc_gd):
+ dc = apply_mask(
+ self.forward_operator(expand_operator(x, sensitivity_map, self._coil_dim), dim=self._spatial_dims)
+ - masked_kspace,
+ sampling_mask,
+ return_mask=False,
+ )
+ dc = self.backward_operator(dc, dim=self._spatial_dims)
+ dc = reduce_operator(dc, sensitivity_map, self._coil_dim)
+
+ x = x - self.learning_rate_eta[dc_gd_step] * (dc + self.rho[admm_step] * (x - z) + u)
+
+ if admm_step in self.auxiliary_steps:
+ out.append(x)
+
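+            # [u-step] Update the Lagrange multiplier with the current primal residual x - z.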
+ u = u + self.rho[admm_step] * (x - z)
+
+ return out
+
+
+class LagrangeMultipliersInitializer3D(nn.Module):
+ """A convolutional neural network model that initializes the Lagrange multiplier of :class:`VSharpNet3D`.
+
+ This is an extension to 3D data of :class:`LagrangeMultipliersInitializer`.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ channels: tuple[int, ...],
+ dilations: tuple[int, ...],
+ multiscale_depth: int = 1,
+ activation: ActivationType = ActivationType.PRELU,
+ ):
+ """Initializes LagrangeMultipliersInitializer3D.
+
+ Parameters
+ ----------
+ in_channels : int
+ Number of input channels.
+ out_channels : int
+ Number of output channels.
+ channels : tuple of ints
+ Tuple of integers specifying the number of output channels for each convolutional layer in the network.
+ dilations : tuple of ints
+ Tuple of integers specifying the dilation factor for each convolutional layer in the network.
+ multiscale_depth : int
+ Number of multiscale features to include in the output. Default: 1.
+ activation : ActivationType
+ Activation function to use on the output. Default: ActivationType.PRELU.
+ """
+ super().__init__()
+
+ # Define convolutional blocks
+ self.conv_blocks = nn.ModuleList()
+ tch = in_channels
+ for curr_channels, curr_dilations in zip(channels, dilations):
+ block = nn.Sequential(
+ nn.ReplicationPad3d(curr_dilations),
+ nn.Conv3d(tch, curr_channels, 3, padding=0, dilation=curr_dilations),
+ )
+ tch = curr_channels
+ self.conv_blocks.append(block)
+
+ # Define output block
+ tch = np.sum(channels[-multiscale_depth:])
+ block = nn.Conv3d(tch, out_channels, 1, padding=0)
+ self.out_block = nn.Sequential(block)
+
+ self.multiscale_depth = multiscale_depth
+ self.activation = _get_relu_activation(activation)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Forward pass of :class:`LagrangeMultipliersInitializer3D`.
+
+ Parameters
+ ----------
+ x : torch.Tensor
+ Input tensor of shape (batch_size, in_channels, z, x, y).
+
+ Returns
+ -------
+ torch.Tensor
+ Output tensor of shape (batch_size, out_channels, z, x, y).
+ """
+
+ features = []
+ for block in self.conv_blocks:
+ x = F.relu(block(x), inplace=True)
+ if self.multiscale_depth > 1:
+ features.append(x)
+
+ if self.multiscale_depth > 1:
+ x = torch.cat(features[-self.multiscale_depth :], dim=1)
+
+ return self.activation(self.out_block(x))
+
+
+class VSharpNet3D(nn.Module):
+ """VharpNet 3D version using 3D U-Nets as denoisers.
+
+ This is an extension to 3D of :class:`VSharpNet`. For the original paper refer to [1]_.
+
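+    Examples
+    --------
+    A minimal usage sketch with illustrative 3D shapes (the extra dimension is the slice axis):
+
+    .. code-block:: python
+
+        import torch
+        from direct.data.transforms import fft2, ifft2
+
+        model = VSharpNet3D(
+            forward_operator=fft2,
+            backward_operator=ifft2,
+            num_steps=2,
+            num_steps_dc_gd=2,
+        )
+        masked_kspace = torch.rand(1, 8, 16, 64, 64, 2)
+        sensitivity_map = torch.rand(1, 8, 16, 64, 64, 2)
+        sampling_mask = torch.rand(1, 1, 16, 64, 64, 1) > 0.5
+        out = model(masked_kspace, sensitivity_map, sampling_mask)  # list of (1, 16, 64, 64, 2)
+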
+ References
+ ----------
+ .. [1] George Yiasemis et al., "VSHARP: Variable Splitting Half-quadratic ADMM Algorithm for Reconstruction
+ of Inverse Problems" (2023). https://arxiv.org/abs/2309.09954.
+ """
+
+ def __init__(
+ self,
+ forward_operator: Callable[[tuple[Any, ...]], torch.Tensor],
+ backward_operator: Callable[[tuple[Any, ...]], torch.Tensor],
+ num_steps: int,
+ num_steps_dc_gd: int,
+ image_init: InitType = InitType.SENSE,
+ no_parameter_sharing: bool = True,
+ initializer_channels: tuple[int, ...] = (32, 32, 64, 64),
+ initializer_dilations: tuple[int, ...] = (1, 1, 2, 4),
+ initializer_multiscale: int = 1,
+ initializer_activation: ActivationType = ActivationType.PRELU,
+ auxiliary_steps: int = -1,
+ lagrange_initialization: LagrangeMultipliersInitialization = LagrangeMultipliersInitialization.LEARNED,
+ unet_num_filters: int = 32,
+ unet_num_pool_layers: int = 4,
+ unet_dropout: float = 0.0,
+ unet_norm: bool = False,
+ **kwargs,
+ ):
+ """Inits :class:`VSharpNet3D`.
+
+ Parameters
+ ----------
+ forward_operator : Callable[[tuple[Any, ...]], torch.Tensor]
+ Forward operator function.
+ backward_operator : Callable[[tuple[Any, ...]], torch.Tensor]
+ Backward operator function.
+ num_steps : int
+ Number of steps in the ADMM algorithm.
+ num_steps_dc_gd : int
+ Number of steps in the Data Consistency using Gradient Descent step of ADMM.
+ image_init : str
+ Image initialization method. Default: 'sense'.
+ no_parameter_sharing : bool
+            If True (default), each ADMM step uses its own denoiser block; if False, a single
+            denoiser is shared across all steps.
+ initializer_channels : tuple[int, ...]
+ Tuple of integers specifying the number of output channels for each convolutional layer in the
+ Lagrange multiplier initializer. Default: (32, 32, 64, 64).
+ initializer_dilations : tuple[int, ...]
+ Tuple of integers specifying the dilation factor for each convolutional layer in the Lagrange multiplier
+ initializer. Default: (1, 1, 2, 4).
+ initializer_multiscale : int
+ Number of multiscale features to include in the Lagrange multiplier initializer output. Default: 1.
+ initializer_activation : ActivationType
+            Activation type for the Lagrange multiplier initializer. Default: ActivationType.PRELU.
+ auxiliary_steps : int
+ Number of auxiliary steps to output. Can be -1 or a positive integer lower or equal to `num_steps`.
+            If -1, it uses all steps. If a positive integer `I`, the last `I` steps will be used.
+ lagrange_initialization : LagrangeMultipliersInitialization
+ Lagrange multiplier initialization method. Can be LagrangeMultipliersInitialization.LEARNED or
+ LagrangeMultipliersInitialization.ZEROS, corresponding to learned initialization or zero initialization.
+ Default: LagrangeMultipliersInitialization.LEARNED.
+ unet_num_filters : int
+ U-Net denoisers number of output channels of the first convolutional layer. Default: 32.
+ unet_num_pool_layers : int
+ U-Net denoisers number of down-sampling and up-sampling layers (depth). Default: 4.
+ unet_dropout : float
+ U-Net denoisers dropout probability. Default: 0.0
+ unet_norm : bool
+ Whether to use normalized U-Net as denoiser or not. Default: False.
+ **kwargs: Additional keyword arguments.
+ Can be `model_name`.
+ """
+ # pylint: disable=too-many-locals
+ super().__init__()
+ for extra_key in kwargs:
+ if extra_key != "model_name":
+ raise ValueError(f"{type(self).__name__} got key `{extra_key}` which is not supported.")
+ self.num_steps = num_steps
+ self.num_steps_dc_gd = num_steps_dc_gd
+
+ self.no_parameter_sharing = no_parameter_sharing
+
+ self.denoiser_blocks = nn.ModuleList()
+ for _ in range(num_steps if self.no_parameter_sharing else 1):
+ self.denoiser_blocks.append(
+ (UnetModel3d if not unet_norm else NormUnetModel3d)(
+ in_channels=COMPLEX_SIZE * 3,
+ out_channels=COMPLEX_SIZE,
+ num_filters=unet_num_filters,
+ num_pool_layers=unet_num_pool_layers,
+ dropout_probability=unet_dropout,
+ )
+ )
+
+ self.lagrange_initialization = lagrange_initialization
+ if lagrange_initialization == LagrangeMultipliersInitialization.LEARNED:
+ self.initializer = LagrangeMultipliersInitializer3D(
+ in_channels=COMPLEX_SIZE,
+ out_channels=COMPLEX_SIZE,
+ channels=initializer_channels,
+ dilations=initializer_dilations,
+ multiscale_depth=initializer_multiscale,
+ activation=initializer_activation,
+ )
+
+ self.learning_rate_eta = nn.Parameter(torch.ones(num_steps_dc_gd, requires_grad=True))
+ nn.init.trunc_normal_(self.learning_rate_eta, 0.0, 1.0, 0.0)
+
+ self.rho = nn.Parameter(torch.ones(num_steps, requires_grad=True))
+ nn.init.trunc_normal_(self.rho, 0, 0.1, 0.0)
+
+ self.forward_operator = forward_operator
+ self.backward_operator = backward_operator
+
+ if image_init not in [InitType.SENSE, InitType.ZERO_FILLED]:
+ raise ValueError(
+ f"Unknown image_initialization. Expected `InitType.SENSE` or `InitType.ZERO_FILLED`. "
+ f"Got {image_init}."
+ )
+
+ self.image_init = image_init
+
+ if not (auxiliary_steps == -1 or 0 < auxiliary_steps <= num_steps):
+ raise ValueError(
+ f"Number of auxiliary steps should be -1 to use all steps or a positive"
+ f" integer <= than `num_steps`. Received {auxiliary_steps}."
+ )
+ if auxiliary_steps == -1:
+ self.auxiliary_steps = list(range(num_steps))
+ else:
+ self.auxiliary_steps = list(range(num_steps - min(auxiliary_steps, num_steps), num_steps))
+
+ self._coil_dim = 1
+ self._complex_dim = -1
+ self._spatial_dims = (3, 4)
+
+ def forward(
+ self,
+ masked_kspace: torch.Tensor,
+ sensitivity_map: torch.Tensor,
+ sampling_mask: torch.Tensor,
+ ) -> list[torch.Tensor]:
+ """Computes forward pass of :class:`VSharpNet3D`.
+
+ Parameters
+ ----------
+ masked_kspace : torch.Tensor
+ Masked k-space of shape (N, coil, slice, height, width, complex=2).
+ sensitivity_map : torch.Tensor
+ Sensitivity map of shape (N, coil, slice, height, width, complex=2).
+ sampling_mask : torch.Tensor
+ Sampling mask of shape (N, 1, 1 or slice, height, width, 1).
+
+ Returns
+ -------
+ out : list of torch.Tensors
+ List of output images each of shape (N, slice, height, width, complex=2).
+ """
+ out = []
+ if self.image_init == InitType.SENSE:
+ x = reduce_operator(
+ coil_data=self.backward_operator(masked_kspace, dim=self._spatial_dims),
+ sensitivity_map=sensitivity_map,
+ dim=self._coil_dim,
+ )
+ else:
+ x = self.backward_operator(masked_kspace, dim=self._spatial_dims).sum(self._coil_dim)
+
+ z = x.clone()
+
+ if self.lagrange_initialization == LagrangeMultipliersInitialization.LEARNED:
+ u = self.initializer(x.permute(0, 4, 1, 2, 3)).permute(0, 2, 3, 4, 1)
+ else:
+            u = torch.zeros_like(x)
+
+ for admm_step in range(self.num_steps):
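+            # [z-step] Denoise: refine z from the current image x and the scaled multiplier u / rho.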
+ z = self.denoiser_blocks[admm_step if self.no_parameter_sharing else 0](
+ torch.cat(
+ [z, x, u / self.rho[admm_step]],
+ dim=self._complex_dim,
+ ).permute(0, 4, 1, 2, 3)
+ ).permute(0, 2, 3, 4, 1)
+
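+            # [x-step] Data consistency: unrolled gradient descent on the augmented Lagrangian.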
+ for dc_gd_step in range(self.num_steps_dc_gd):
+ dc = apply_mask(
+ self.forward_operator(expand_operator(x, sensitivity_map, self._coil_dim), dim=self._spatial_dims)
+ - masked_kspace,
+ sampling_mask,
+ return_mask=False,
+ )
+ dc = self.backward_operator(dc, dim=self._spatial_dims)
+ dc = reduce_operator(dc, sensitivity_map, self._coil_dim)
+
+ x = x - self.learning_rate_eta[dc_gd_step] * (dc + self.rho[admm_step] * (x - z) + u)
+
+ if admm_step in self.auxiliary_steps:
+ out.append(x)
+
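+            # [u-step] Update the Lagrange multiplier with the current primal residual x - z.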
+ u = u + self.rho[admm_step] * (x - z)
+
+ return out
diff --git a/direct/nn/vsharp/vsharp_engine.py b/direct/nn/vsharp/vsharp_engine.py
index 9efb16e4..51578f93 100644
--- a/direct/nn/vsharp/vsharp_engine.py
+++ b/direct/nn/vsharp/vsharp_engine.py
@@ -6,9 +6,9 @@
References
----------
-.. [1] Yiasemis, G., Moriakov, N., Sánchez, C.I., Sonke, J.-J., Teuwen, J.: JSSL: Joint Supervised and
- Self-supervised Learning for MRI Reconstruction, http://arxiv.org/abs/2311.15856, (2023).
- https://doi.org/10.48550/arXiv.2311.15856.
+.. [1] Yiasemis, G., Moriakov, N., Sonke, J.-J., Teuwen, J.: vSHARP: Variable Splitting Half-quadratic ADMM algorithm
+ for reconstruction of inverse-problems. Magnetic Resonance Imaging. 110266 (2024).
+ https://doi.org/10.1016/j.mri.2024.110266.
.. [2] Yiasemis, G., Moriakov, N., Sánchez, C.I., Sonke, J.-J., Teuwen, J.: JSSL: Joint Supervised and
Self-supervised Learning for MRI Reconstruction, http://arxiv.org/abs/2311.15856, (2023).
https://doi.org/10.48550/arXiv.2311.15856.
@@ -165,13 +165,21 @@ def _do_iteration(
if shape == registered_image.shape
else data["reference_image"].tile((1, registered_image.shape[1], *([1] * len(shape[1:]))))
),
+ weight=self.cfg.additional_models.registration_model.reg_loss_factor,
)
+
+ if "displacement_field" in data:
+ target_displacement_field = data["displacement_field"]
+ else:
+ target_displacement_field = None
+
loss_dict = self.compute_loss_on_data(
loss_dict,
loss_fns,
data,
output_displacement_field=displacement_field,
- target_displacement_field=data["displacement_field"],
+ target_displacement_field=target_displacement_field,
+ weight=self.cfg.additional_models.registration_model.reg_loss_factor,
)
loss = sum(loss_dict.values()) # type: ignore
@@ -183,7 +191,11 @@ def _do_iteration(
output_image = output_images[-1]
return DoIterationOutput(
- output_image=(output_image, registered_image) if "registration_model" in self.models else output_image,
+ output_image=(
+ (output_image, registered_image, displacement_field)
+ if "registration_model" in self.models
+ else output_image
+ ),
sensitivity_map=data["sensitivity_map"],
sampling_mask=data["sampling_mask"],
data_dict={**loss_dict},
diff --git a/direct/train.py b/direct/train.py
index 31207ba0..b3400c1e 100644
--- a/direct/train.py
+++ b/direct/train.py
@@ -1,332 +1,332 @@
-# coding=utf-8
-# Copyright (c) DIRECT Contributors
-import argparse
-import functools
-import logging
-import os
-import pathlib
-import sys
-import urllib.parse
-from collections import defaultdict
-from typing import Callable, Dict, List, Optional, Union
-
-import numpy as np
-import torch
-from omegaconf import DictConfig
-
-from direct.cli.utils import check_train_val
-from direct.common.subsample import build_masking_function
-from direct.data.datasets import build_dataset_from_input
-from direct.data.lr_scheduler import WarmupMultiStepLR
-from direct.data.mri_transforms import build_mri_transforms
-from direct.environment import setup_training_environment
-from direct.launch import launch
-from direct.types import PathOrString
-from direct.utils import dict_flatten, remove_keys, set_all_seeds, str_to_class
-from direct.utils.dataset import get_filenames_for_datasets_from_config
-from direct.utils.io import check_is_valid_url, read_json
-
-logger = logging.getLogger(__name__)
-
-
-def parse_noise_dict(noise_dict: dict, percentile: float = 1.0, multiplier: float = 1.0):
- logger.info("Parsing noise dictionary...")
- output: Dict = defaultdict(dict)
- for filename in noise_dict:
- data_per_volume = noise_dict[filename]
- for slice_no in data_per_volume:
- curr_data = data_per_volume[slice_no]
- if percentile != 1.0:
- lower_clip = np.percentile(curr_data, 100 * (1 - percentile))
- upper_clip = np.percentile(curr_data, 100 * percentile)
- curr_data = np.clip(curr_data, lower_clip, upper_clip)
-
- output[filename][int(slice_no)] = (
- curr_data * multiplier
- ) ** 2 # np.asarray(curr_data) * multiplier# (np.clip(curr_data, lower_clip, upper_clip) * multiplier) ** 2
-
- return output
-
-
-def get_root_of_file(filename: PathOrString):
- """Get the root directory of the file or URL to file.
-
- Examples
- --------
- >>> get_root_of_file('/mnt/archive/data.txt')
- >>> /mnt/archive
- >>> get_root_of_file('https://aiforoncology.nl/people')
- >>> https://aiforoncology.nl/
-
- Parameters
- ----------
- filename: pathlib.Path or str
-
- Returns
- -------
- pathlib.Path or str
- """
- if check_is_valid_url(str(filename)):
- filename = urllib.parse.urljoin(str(filename), ".")
- else:
- filename = pathlib.Path(filename).parents[0]
-
- return filename
-
-
-def build_transforms_from_environment(env, dataset_config: DictConfig) -> Callable:
- masking = dataset_config.transforms.masking # Masking func can be None
- mask_func = None if masking is None else build_masking_function(**masking)
- mri_transforms_func = functools.partial(
- build_mri_transforms,
- forward_operator=env.engine.forward_operator,
- backward_operator=env.engine.backward_operator,
- mask_func=mask_func,
- )
- return mri_transforms_func(**dict_flatten(dict(remove_keys(dataset_config.transforms, "masking")))) # type: ignore
-
-
-def build_training_datasets_from_environment(
- env,
- datasets_config: List[DictConfig],
- lists_root: Optional[PathOrString] = None,
- data_root: Optional[PathOrString] = None,
- initial_images: Optional[Union[List[pathlib.Path], None]] = None,
- initial_kspaces: Optional[Union[List[pathlib.Path], None]] = None,
- pass_text_description: bool = True,
- pass_dictionaries: Optional[Dict[str, Dict]] = None,
-):
- datasets = []
- for idx, dataset_config in enumerate(datasets_config):
- if pass_text_description:
- if not "text_description" in dataset_config:
- dataset_config.text_description = f"ds{idx}" if len(datasets_config) > 1 else None
- else:
- dataset_config.text_description = None
- if dataset_config.transforms.masking is None: # type: ignore
- logger.info(
- "Masking function set to None for %s.",
- dataset_config.text_description, # type: ignore
- )
- transforms = build_transforms_from_environment(env, dataset_config)
- dataset_args = {"transforms": transforms, "dataset_config": dataset_config}
- if initial_images is not None:
- dataset_args.update({"initial_images": initial_images})
- if initial_kspaces is not None:
- dataset_args.update({"initial_kspaces": initial_kspaces})
- if data_root is not None:
- dataset_args.update({"data_root": data_root})
- filenames_filter = get_filenames_for_datasets_from_config(dataset_config, lists_root, data_root)
- dataset_args.update({"filenames_filter": filenames_filter})
- if pass_dictionaries is not None:
- dataset_args.update({"pass_dictionaries": pass_dictionaries})
- dataset = build_dataset_from_input(**dataset_args)
-
- logger.debug("Transforms %s / %s :\n%s", idx + 1, len(datasets_config), transforms)
- datasets.append(dataset)
- logger.info(
- "Data size for %s (%s/%s): %s.",
- dataset_config.text_description, # type: ignore
- idx + 1,
- len(datasets_config),
- len(dataset),
- )
-
- return datasets
-
-
-def setup_train(
- run_name: str,
- training_root: Union[pathlib.Path, None],
- validation_root: Union[pathlib.Path, None],
- base_directory: pathlib.Path,
- cfg_filename: PathOrString,
- force_validation: bool,
- initialization_checkpoint: PathOrString,
- initial_images: Optional[Union[List[pathlib.Path], None]],
- initial_kspace: Optional[Union[List[pathlib.Path], None]],
- noise: Optional[Union[List[pathlib.Path], None]],
- device: str,
- num_workers: int,
- resume: bool,
- machine_rank: int,
- mixed_precision: bool,
- debug: bool,
-):
- env = setup_training_environment(
- run_name,
- base_directory,
- cfg_filename,
- device,
- machine_rank,
- mixed_precision,
- debug=debug,
- )
-
- # Trigger cudnn benchmark and remove the associated cache
- torch.backends.cudnn.benchmark = True
- torch.cuda.empty_cache()
-
- if initial_kspace is not None and initial_images is not None:
- raise ValueError("Cannot both provide initial kspace or initial images.")
- # Create training data
- training_dataset_args = {"env": env, "datasets_config": env.cfg.training.datasets, "pass_text_description": True}
- pass_dictionaries = {}
- if noise is not None:
- if not env.cfg.physics.use_noise_matrix:
- raise ValueError("cfg.physics.use_noise_matrix is null, yet command line passed noise files.")
-
- noise = [read_json(fn) for fn in noise]
- pass_dictionaries["loglikelihood_scaling"] = [
- parse_noise_dict(_, percentile=0.999, multiplier=env.cfg.physics.noise_matrix_scaling) for _ in noise
- ]
- training_dataset_args.update({"pass_dictionaries": pass_dictionaries})
-
- if training_root is not None:
- training_dataset_args.update({"data_root": training_root})
- # Get the lists_root. Assume now the given path is with respect to the config file.
- lists_root = get_root_of_file(cfg_filename)
- if lists_root is not None:
- training_dataset_args.update({"lists_root": lists_root})
- if initial_images is not None:
- training_dataset_args.update({"initial_images": initial_images[0]})
- if initial_kspace is not None:
- training_dataset_args.update({"initial_kspaces": initial_kspace[0]})
-
- # Build training datasets
- training_datasets = build_training_datasets_from_environment(**training_dataset_args)
- training_data_sizes = [len(_) for _ in training_datasets]
- logger.info("Training data sizes: %s (sum=%s).", training_data_sizes, sum(training_data_sizes))
-
- # Create validation data
- if "validation" in env.cfg:
- validation_dataset_args = {
- "env": env,
- "datasets_config": env.cfg.validation.datasets,
- "pass_text_description": True,
- }
- if validation_root is not None:
- validation_dataset_args.update({"data_root": validation_root})
- lists_root = get_root_of_file(cfg_filename)
- if lists_root is not None:
- validation_dataset_args.update({"lists_root": lists_root})
- if initial_images is not None:
- validation_dataset_args.update({"initial_images": initial_images[1]})
- if initial_kspace is not None:
- validation_dataset_args.update({"initial_kspaces": initial_kspace[1]})
-
- # Build validation datasets
- validation_data = build_training_datasets_from_environment(**validation_dataset_args)
- else:
- logger.info("No validation data.")
- validation_data = None
-
- # Create the optimizers
- logger.info("Building optimizers.")
- optimizer_params = [{"params": env.engine.model.parameters()}]
- for curr_model_name in env.engine.models:
- # TODO(jt): Can get learning rate from the config per additional model too.
- curr_learning_rate = env.cfg.training.lr
- logger.info("Adding model parameters of %s with learning rate %s.", curr_model_name, curr_learning_rate)
- optimizer_params.append(
- {
- "params": env.engine.models[curr_model_name].parameters(),
- "lr": curr_learning_rate,
- }
- )
-
- optimizer: torch.optim.Optimizer = str_to_class("torch.optim", env.cfg.training.optimizer)( # noqa
- optimizer_params,
- lr=env.cfg.training.lr,
- weight_decay=env.cfg.training.weight_decay,
- ) # noqa
-
- # Build the LR scheduler, we use a fixed LR schedule step size, no adaptive training schedule.
- solver_steps = list(
- range(
- env.cfg.training.lr_step_size,
- env.cfg.training.num_iterations,
- env.cfg.training.lr_step_size,
- )
- )
- lr_scheduler = WarmupMultiStepLR(
- optimizer,
- solver_steps,
- env.cfg.training.lr_gamma,
- warmup_factor=1 / 3.0,
- warmup_iterations=env.cfg.training.lr_warmup_iter,
- warmup_method="linear",
- )
-
- # Just to make sure.
- torch.cuda.empty_cache()
-
- # Check the initialization checkpoint
- if env.cfg.training.model_checkpoint:
- if initialization_checkpoint:
- logger.warning(
- "`--initialization-checkpoint is set, and config has a set `training.model_checkpoint`: %s. "
- "Will overwrite config variable with the command line: %s.",
- env.cfg.training.model_checkpoint,
- initialization_checkpoint,
- )
- # Now overwrite this in the configuration, so the correct value is dumped.
- env.cfg.training.model_checkpoint = str(initialization_checkpoint)
- else:
- initialization_checkpoint = env.cfg.training.model_checkpoint
-
- env.engine.train(
- optimizer,
- lr_scheduler,
- training_datasets,
- env.experiment_dir,
- validation_datasets=validation_data,
- resume=resume,
- initialization=initialization_checkpoint,
- start_with_validation=force_validation,
- num_workers=num_workers,
- )
-
-
-def train_from_argparse(args: argparse.Namespace):
- # This sets MKL threads to 1.
- # DataLoader can otherwise bring a lot of difficulties when computing CPU FFTs in the transforms.
- torch.set_num_threads(1)
- os.environ["OMP_NUM_THREADS"] = "1"
- # Disable Tensorboard warnings.
- os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
-
- if args.initialization_images is not None and args.initialization_kspace is not None:
- sys.exit("--initialization-images and --initialization-kspace are mutually exclusive.")
- check_train_val(args.initialization_images, "initialization-images")
- check_train_val(args.initialization_kspace, "initialization-kspace")
- check_train_val(args.noise, "noise")
-
- set_all_seeds(args.seed)
-
- run_name = args.name if args.name is not None else os.path.basename(args.cfg_file)[:-5]
-
- # TODO(jt): Duplicate params
- launch(
- setup_train,
- args.num_machines,
- args.num_gpus,
- args.machine_rank,
- args.dist_url,
- run_name,
- args.training_root,
- args.validation_root,
- args.experiment_dir,
- args.cfg_file,
- args.force_validation,
- args.initialization_checkpoint,
- args.initialization_images,
- args.initialization_kspace,
- args.noise,
- args.device,
- args.num_workers,
- args.resume,
- args.machine_rank,
- args.mixed_precision,
- args.debug,
- )
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
+import argparse
+import functools
+import logging
+import os
+import pathlib
+import sys
+import urllib.parse
+from collections import defaultdict
+from typing import Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from omegaconf import DictConfig
+
+from direct.cli.utils import check_train_val
+from direct.common.subsample import build_masking_function
+from direct.data.datasets import build_dataset_from_input
+from direct.data.lr_scheduler import WarmupMultiStepLR
+from direct.data.mri_transforms import build_mri_transforms
+from direct.environment import setup_training_environment
+from direct.launch import launch
+from direct.types import PathOrString
+from direct.utils import dict_flatten, remove_keys, set_all_seeds, str_to_class
+from direct.utils.dataset import get_filenames_for_datasets_from_config
+from direct.utils.io import check_is_valid_url, read_json
+
+logger = logging.getLogger(__name__)
+
+
+def parse_noise_dict(noise_dict: dict, percentile: float = 1.0, multiplier: float = 1.0):
+ logger.info("Parsing noise dictionary...")
+ output: Dict = defaultdict(dict)
+ for filename in noise_dict:
+ data_per_volume = noise_dict[filename]
+ for slice_no in data_per_volume:
+ curr_data = data_per_volume[slice_no]
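+            # When percentile < 1.0, clip extreme noise estimates to the [1 - percentile, percentile] band.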
+ if percentile != 1.0:
+ lower_clip = np.percentile(curr_data, 100 * (1 - percentile))
+ upper_clip = np.percentile(curr_data, 100 * percentile)
+ curr_data = np.clip(curr_data, lower_clip, upper_clip)
+
+ output[filename][int(slice_no)] = (
+ curr_data * multiplier
+            ) ** 2
+
+ return output
+
+
+def get_root_of_file(filename: PathOrString):
+ """Get the root directory of the file or URL to file.
+
+ Examples
+ --------
+ >>> get_root_of_file('/mnt/archive/data.txt')
+ >>> /mnt/archive
+ >>> get_root_of_file('https://aiforoncology.nl/people')
+ >>> https://aiforoncology.nl/
+
+ Parameters
+ ----------
+ filename: pathlib.Path or str
+
+ Returns
+ -------
+ pathlib.Path or str
+ """
+ if check_is_valid_url(str(filename)):
+ filename = urllib.parse.urljoin(str(filename), ".")
+ else:
+ filename = pathlib.Path(filename).parents[0]
+
+ return filename
+
+
+def build_transforms_from_environment(env, dataset_config: DictConfig) -> Callable:
+ masking = dataset_config.transforms.masking # Masking func can be None
+ mask_func = None if masking is None else build_masking_function(**masking)
+ mri_transforms_func = functools.partial(
+ build_mri_transforms,
+ forward_operator=env.engine.forward_operator,
+ backward_operator=env.engine.backward_operator,
+ mask_func=mask_func,
+ )
+ return mri_transforms_func(**dict_flatten(dict(remove_keys(dataset_config.transforms, "masking")))) # type: ignore
+
+
+def build_training_datasets_from_environment(
+ env,
+ datasets_config: List[DictConfig],
+ lists_root: Optional[PathOrString] = None,
+ data_root: Optional[PathOrString] = None,
+ initial_images: Optional[Union[List[pathlib.Path], None]] = None,
+ initial_kspaces: Optional[Union[List[pathlib.Path], None]] = None,
+ pass_text_description: bool = True,
+ pass_dictionaries: Optional[Dict[str, Dict]] = None,
+):
+ datasets = []
+ for idx, dataset_config in enumerate(datasets_config):
+ if pass_text_description:
+ if not "text_description" in dataset_config:
+ dataset_config.text_description = f"ds{idx}" if len(datasets_config) > 1 else None
+ else:
+ dataset_config.text_description = None
+ if dataset_config.transforms.masking is None: # type: ignore
+ logger.info(
+ "Masking function set to None for %s.",
+ dataset_config.text_description, # type: ignore
+ )
+ transforms = build_transforms_from_environment(env, dataset_config)
+ dataset_args = {"transforms": transforms, "dataset_config": dataset_config}
+ if initial_images is not None:
+ dataset_args.update({"initial_images": initial_images})
+ if initial_kspaces is not None:
+ dataset_args.update({"initial_kspaces": initial_kspaces})
+ if data_root is not None:
+ dataset_args.update({"data_root": data_root})
+ filenames_filter = get_filenames_for_datasets_from_config(dataset_config, lists_root, data_root)
+ dataset_args.update({"filenames_filter": filenames_filter})
+ if pass_dictionaries is not None:
+ dataset_args.update({"pass_dictionaries": pass_dictionaries})
+ dataset = build_dataset_from_input(**dataset_args)
+
+ logger.debug("Transforms %s / %s :\n%s", idx + 1, len(datasets_config), transforms)
+ datasets.append(dataset)
+ logger.info(
+ "Data size for %s (%s/%s): %s.",
+ dataset_config.text_description, # type: ignore
+ idx + 1,
+ len(datasets_config),
+ len(dataset),
+ )
+
+ return datasets
+
+
+def setup_train(
+ run_name: str,
+ training_root: Union[pathlib.Path, None],
+ validation_root: Union[pathlib.Path, None],
+ base_directory: pathlib.Path,
+ cfg_filename: PathOrString,
+ force_validation: bool,
+ initialization_checkpoint: PathOrString,
+ initial_images: Optional[Union[List[pathlib.Path], None]],
+ initial_kspace: Optional[Union[List[pathlib.Path], None]],
+ noise: Optional[Union[List[pathlib.Path], None]],
+ device: str,
+ num_workers: int,
+ resume: bool,
+ machine_rank: int,
+ mixed_precision: bool,
+ debug: bool,
+):
+ env = setup_training_environment(
+ run_name,
+ base_directory,
+ cfg_filename,
+ device,
+ machine_rank,
+ mixed_precision,
+ debug=debug,
+ )
+
+ # Trigger cudnn benchmark and remove the associated cache
+ torch.backends.cudnn.benchmark = True
+ torch.cuda.empty_cache()
+
+ if initial_kspace is not None and initial_images is not None:
+ raise ValueError("Cannot both provide initial kspace or initial images.")
+ # Create training data
+ training_dataset_args = {"env": env, "datasets_config": env.cfg.training.datasets, "pass_text_description": True}
+ pass_dictionaries = {}
+ if noise is not None:
+ if not env.cfg.physics.use_noise_matrix:
+ raise ValueError("cfg.physics.use_noise_matrix is null, yet command line passed noise files.")
+
+ noise = [read_json(fn) for fn in noise]
+ pass_dictionaries["loglikelihood_scaling"] = [
+ parse_noise_dict(_, percentile=0.999, multiplier=env.cfg.physics.noise_matrix_scaling) for _ in noise
+ ]
+ training_dataset_args.update({"pass_dictionaries": pass_dictionaries})
+
+ if training_root is not None:
+ training_dataset_args.update({"data_root": training_root})
+ # Get the lists_root. Assume now the given path is with respect to the config file.
+ lists_root = get_root_of_file(cfg_filename)
+ if lists_root is not None:
+ training_dataset_args.update({"lists_root": lists_root})
+ if initial_images is not None:
+ training_dataset_args.update({"initial_images": initial_images[0]})
+ if initial_kspace is not None:
+ training_dataset_args.update({"initial_kspaces": initial_kspace[0]})
+
+ # Build training datasets
+ training_datasets = build_training_datasets_from_environment(**training_dataset_args)
+ training_data_sizes = [len(_) for _ in training_datasets]
+ logger.info("Training data sizes: %s (sum=%s).", training_data_sizes, sum(training_data_sizes))
+
+ # Create validation data
+ if "validation" in env.cfg:
+ validation_dataset_args = {
+ "env": env,
+ "datasets_config": env.cfg.validation.datasets,
+ "pass_text_description": True,
+ }
+ if validation_root is not None:
+ validation_dataset_args.update({"data_root": validation_root})
+ lists_root = get_root_of_file(cfg_filename)
+ if lists_root is not None:
+ validation_dataset_args.update({"lists_root": lists_root})
+ if initial_images is not None:
+ validation_dataset_args.update({"initial_images": initial_images[1]})
+ if initial_kspace is not None:
+ validation_dataset_args.update({"initial_kspaces": initial_kspace[1]})
+
+ # Build validation datasets
+ validation_data = build_training_datasets_from_environment(**validation_dataset_args)
+ else:
+ logger.info("No validation data.")
+ validation_data = None
+
+ # Create the optimizers
+ logger.info("Building optimizers.")
+ optimizer_params = [{"params": env.engine.model.parameters()}]
+ for curr_model_name in env.engine.models:
+ # TODO(jt): Can get learning rate from the config per additional model too.
+ curr_learning_rate = env.cfg.training.lr
+ logger.info("Adding model parameters of %s with learning rate %s.", curr_model_name, curr_learning_rate)
+ optimizer_params.append(
+ {
+ "params": env.engine.models[curr_model_name].parameters(),
+ "lr": curr_learning_rate,
+ }
+ )
+
+ optimizer: torch.optim.Optimizer = str_to_class("torch.optim", env.cfg.training.optimizer)( # noqa
+ optimizer_params,
+ lr=env.cfg.training.lr,
+ weight_decay=env.cfg.training.weight_decay,
+ ) # noqa
+
+ # Build the LR scheduler, we use a fixed LR schedule step size, no adaptive training schedule.
+ solver_steps = list(
+ range(
+ env.cfg.training.lr_step_size,
+ env.cfg.training.num_iterations,
+ env.cfg.training.lr_step_size,
+ )
+ )
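+    # E.g. (illustrative values) lr_step_size=50000 and num_iterations=200000 give decay
+    # milestones at [50000, 100000, 150000].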
+ lr_scheduler = WarmupMultiStepLR(
+ optimizer,
+ solver_steps,
+ env.cfg.training.lr_gamma,
+ warmup_factor=1 / 3.0,
+ warmup_iterations=env.cfg.training.lr_warmup_iter,
+ warmup_method="linear",
+ )
+
+ # Just to make sure.
+ torch.cuda.empty_cache()
+
+ # Check the initialization checkpoint
+ if env.cfg.training.model_checkpoint:
+ if initialization_checkpoint:
+ logger.warning(
+ "`--initialization-checkpoint is set, and config has a set `training.model_checkpoint`: %s. "
+ "Will overwrite config variable with the command line: %s.",
+ env.cfg.training.model_checkpoint,
+ initialization_checkpoint,
+ )
+ # Now overwrite this in the configuration, so the correct value is dumped.
+ env.cfg.training.model_checkpoint = str(initialization_checkpoint)
+ else:
+ initialization_checkpoint = env.cfg.training.model_checkpoint
+
+ env.engine.train(
+ optimizer,
+ lr_scheduler,
+ training_datasets,
+ env.experiment_dir,
+ validation_datasets=validation_data,
+ resume=resume,
+ initialization=initialization_checkpoint,
+ start_with_validation=force_validation,
+ num_workers=num_workers,
+ )
+
+
+def train_from_argparse(args: argparse.Namespace):
+ # This sets MKL threads to 1.
+ # DataLoader can otherwise bring a lot of difficulties when computing CPU FFTs in the transforms.
+ torch.set_num_threads(1)
+ os.environ["OMP_NUM_THREADS"] = "1"
+ # Disable Tensorboard warnings.
+ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
+
+ if args.initialization_images is not None and args.initialization_kspace is not None:
+ sys.exit("--initialization-images and --initialization-kspace are mutually exclusive.")
+ check_train_val(args.initialization_images, "initialization-images")
+ check_train_val(args.initialization_kspace, "initialization-kspace")
+ check_train_val(args.noise, "noise")
+
+ set_all_seeds(args.seed)
+
+ run_name = args.name if args.name is not None else os.path.basename(args.cfg_file)[:-5]
+
+ # TODO(jt): Duplicate params
+ launch(
+ setup_train,
+ args.num_machines,
+ args.num_gpus,
+ args.machine_rank,
+ args.dist_url,
+ run_name,
+ args.training_root,
+ args.validation_root,
+ args.experiment_dir,
+ args.cfg_file,
+ args.force_validation,
+ args.initialization_checkpoint,
+ args.initialization_images,
+ args.initialization_kspace,
+ args.noise,
+ args.device,
+ args.num_workers,
+ args.resume,
+ args.machine_rank,
+ args.mixed_precision,
+ args.debug,
+ )
diff --git a/direct/types.py b/direct/types.py
index a10d0d78..16040d7c 100644
--- a/direct/types.py
+++ b/direct/types.py
@@ -1,163 +1,166 @@
-# Copyright (c) DIRECT Contributors
-
-"""direct.types module."""
-
-from __future__ import annotations
-
-import pathlib
-from enum import Enum
-from typing import NewType, Union
-
-import numpy as np
-import torch
-from omegaconf.omegaconf import DictConfig
-from torch import nn as nn
-from torch.cuda.amp import GradScaler
-
-DictOrDictConfig = Union[dict, DictConfig]
-Number = Union[float, int]
-PathOrString = Union[pathlib.Path, str]
-FileOrUrl = NewType("FileOrUrl", PathOrString)
-HasStateDict = Union[nn.Module, torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler, GradScaler]
-TensorOrNone = Union[None, torch.Tensor]
-TensorOrNdarray = Union[torch.Tensor, np.ndarray]
-
-
-class DirectEnum(str, Enum):
- """Type of any enumerator with allowed comparison to string invariant to cases."""
-
- @classmethod
- def from_str(cls, value: str) -> DirectEnum | None:
- statuses = cls.__members__.keys()
- for st in statuses:
- if st.lower() == value.lower():
- return cls[st]
- return None
-
- def __eq__(self, other: object) -> bool:
- if isinstance(other, Enum):
- _other = str(other.value)
- else:
- _other = str(other)
- return bool(self.value.lower() == _other.lower())
-
- def __hash__(self) -> int:
- # re-enable hashtable so it can be used as a dict key or in a set
- return hash(self.value.lower())
-
-
-class KspaceKey(DirectEnum):
- ACS_KSPACE = "acs_kspace"
- KSPACE = "kspace"
- MASKED_KSPACE = "masked_kspace"
- REFERENCE_KSPACE = "reference_kspace"
-
-
-class TransformKey(DirectEnum):
- # K-space keys
- ACS_KSPACE = "acs_kspace"
- KSPACE = "kspace"
- MASKED_KSPACE = "masked_kspace"
- # Mask keys
- SAMPLING_MASK = "sampling_mask"
- ACS_MASK = "acs_mask"
- PADDING = "padding"
- # Image keys
- TARGET = "target"
- # Other keys
- SENSITIVITY_MAP = "sensitivity_map"
- SCALING_FACTOR = "scaling_factor"
- # Registration keys
- DISPLACEMENT_FIELD = "displacement_field"
- REFERENCE_IMAGE = "reference_image"
- REFERENCE_KSPACE = "reference_kspace"
- MOVING_IMAGE = "moving_image"
- WARPED_IMAGE = "warped_image"
-
-
-class MaskFuncMode(DirectEnum):
- STATIC = "static"
- DYNAMIC = "dynamic"
- MULTISLICE = "multislice"
-
-
-class IntegerListOrTupleStringMeta(type):
- """Metaclass for the :class:`IntegerListOrTupleString` class.
-
- Returns
- -------
- bool
- True if the instance is a valid representation of IntegerListOrTupleString, False otherwise.
- """
-
- def __instancecheck__(cls, instance):
- """Check if the given instance is a valid representation of an IntegerListOrTupleString.
-
- Parameters
- ----------
- cls : type
- The class being checked, i.e., IntegerListOrTupleStringMeta.
- instance : object
- The instance being checked.
-
- Returns
- -------
- bool
- True if the instance is a valid representation of IntegerListOrTupleString, False otherwise.
- """
- if isinstance(instance, str):
- try:
- assert (instance.startswith("[") and instance.endswith("]")) or (
- instance.startswith("(") and instance.endswith(")")
- )
- elements = instance.strip()[1:-1].split(",")
- integers = [int(element) for element in elements]
- return all(isinstance(num, int) for num in integers)
- except (AssertionError, ValueError, AttributeError):
- pass
- return False
-
-
-class IntegerListOrTupleString(metaclass=IntegerListOrTupleStringMeta):
- """IntegerListOrTupleString class represents a list or tuple of integers based on a string representation.
-
- Examples
- --------
- s1 = "[1, 2, 45, -1, 0]"
- print(isinstance(s1, IntegerListOrTupleString)) # True
- print(IntegerListOrTupleString(s1)) # [1, 2, 45, -1, 0]
- print(type(IntegerListOrTupleString(s1))) #
- print(type(IntegerListOrTupleString(s1)[0])) #
-
- s2 = "(10, -9, 20)"
- print(isinstance(s2, IntegerListOrTupleString)) # True
- print(IntegerListOrTupleString(s2)) # (10, -9, 20)
- print(type(IntegerListOrTupleString(s2))) #
- print(type(IntegerListOrTupleString(s2)[0])) #
-
- s3 = "[a, we, 2]"
- print(isinstance(s3, IntegerListOrTupleString)) # False
-
- s4 = "(1, 2, 3]"
- print(isinstance(s4 IntegerListOrTupleString)) # False
- """
-
- def __new__(cls, string):
- """
- Create a new instance of IntegerListOrTupleString based on the given string representation.
-
- Parameters
- ----------
- string : str
- The string representation of the integer list or tuple.
-
- Returns
- -------
- list or tuple
- A new instance of IntegerListOrTupleString.
- """
- list_or_tuple = list if string.startswith("[") else tuple
- string = string.strip()[1:-1] # Remove outer brackets
- elements = string.split(",")
- integers = [int(element) for element in elements]
- return list_or_tuple(integers)
+# Copyright (c) DIRECT Contributors
+
+"""direct.types module."""
+
+from __future__ import annotations
+
+import pathlib
+from enum import Enum
+from typing import NewType, Union
+
+import numpy as np
+import torch
+from omegaconf.omegaconf import DictConfig
+from torch import nn as nn
+from torch.cuda.amp import GradScaler
+
+DictOrDictConfig = Union[dict, DictConfig]
+Number = Union[float, int]
+PathOrString = Union[pathlib.Path, str]
+FileOrUrl = NewType("FileOrUrl", PathOrString)
+HasStateDict = Union[nn.Module, torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler, GradScaler]
+TensorOrNone = Union[None, torch.Tensor]
+TensorOrNdarray = Union[torch.Tensor, np.ndarray]
+
+
+class DirectEnum(str, Enum):
+ """Type of any enumerator with allowed comparison to string invariant to cases."""
+
+ @classmethod
+ def from_str(cls, value: str) -> DirectEnum | None:
+ statuses = cls.__members__.keys()
+ for st in statuses:
+ if st.lower() == value.lower():
+ return cls[st]
+ return None
+
+ def __eq__(self, other: object) -> bool:
+ if isinstance(other, Enum):
+ _other = str(other.value)
+ else:
+ _other = str(other)
+ return bool(self.value.lower() == _other.lower())
+
+ def __hash__(self) -> int:
+ # re-enable hashtable so it can be used as a dict key or in a set
+ return hash(self.value.lower())
+
+
+class KspaceKey(DirectEnum):
+ ACS_KSPACE = "acs_kspace"
+ KSPACE = "kspace"
+ MASKED_KSPACE = "masked_kspace"
+ REFERENCE_KSPACE = "reference_kspace"
+
+
+class TransformKey(DirectEnum):
+ # K-space keys
+ ACS_KSPACE = "acs_kspace"
+ KSPACE = "kspace"
+ MASKED_KSPACE = "masked_kspace"
+ # Mask keys
+ SAMPLING_MASK = "sampling_mask"
+ ACS_MASK = "acs_mask"
+ PADDING = "padding"
+ # Image keys
+ TARGET = "target"
+ # Other keys
+ SENSITIVITY_MAP = "sensitivity_map"
+ SCALING_FACTOR = "scaling_factor"
+ # Registration keys
+ DISPLACEMENT_FIELD = "displacement_field"
+ REFERENCE_IMAGE = "reference_image"
+ REFERENCE_KSPACE = "reference_kspace"
+ MOVING_IMAGE = "moving_image"
+ WARPED_IMAGE = "warped_image"
+    # Subsampling keys
+ ACCELERATION = "acceleration"
+ CENTER_FRACTION = "center_fraction"
+
+
+class MaskFuncMode(DirectEnum):
+ STATIC = "static"
+ DYNAMIC = "dynamic"
+ MULTISLICE = "multislice"
+
+
+class IntegerListOrTupleStringMeta(type):
+ """Metaclass for the :class:`IntegerListOrTupleString` class.
+
+ Returns
+ -------
+ bool
+ True if the instance is a valid representation of IntegerListOrTupleString, False otherwise.
+ """
+
+ def __instancecheck__(cls, instance):
+ """Check if the given instance is a valid representation of an IntegerListOrTupleString.
+
+ Parameters
+ ----------
+ cls : type
+ The class being checked, i.e., IntegerListOrTupleStringMeta.
+ instance : object
+ The instance being checked.
+
+ Returns
+ -------
+ bool
+ True if the instance is a valid representation of IntegerListOrTupleString, False otherwise.
+ """
+ if isinstance(instance, str):
+ try:
+ assert (instance.startswith("[") and instance.endswith("]")) or (
+ instance.startswith("(") and instance.endswith(")")
+ )
+ elements = instance.strip()[1:-1].split(",")
+ integers = [int(element) for element in elements]
+ return all(isinstance(num, int) for num in integers)
+ except (AssertionError, ValueError, AttributeError):
+ pass
+ return False
+
+
+class IntegerListOrTupleString(metaclass=IntegerListOrTupleStringMeta):
+ """IntegerListOrTupleString class represents a list or tuple of integers based on a string representation.
+
+ Examples
+ --------
+ s1 = "[1, 2, 45, -1, 0]"
+ print(isinstance(s1, IntegerListOrTupleString)) # True
+ print(IntegerListOrTupleString(s1)) # [1, 2, 45, -1, 0]
+    print(type(IntegerListOrTupleString(s1))) # <class 'list'>
+    print(type(IntegerListOrTupleString(s1)[0])) # <class 'int'>
+
+ s2 = "(10, -9, 20)"
+ print(isinstance(s2, IntegerListOrTupleString)) # True
+ print(IntegerListOrTupleString(s2)) # (10, -9, 20)
+    print(type(IntegerListOrTupleString(s2))) # <class 'tuple'>
+    print(type(IntegerListOrTupleString(s2)[0])) # <class 'int'>
+
+ s3 = "[a, we, 2]"
+ print(isinstance(s3, IntegerListOrTupleString)) # False
+
+ s4 = "(1, 2, 3]"
+    print(isinstance(s4, IntegerListOrTupleString)) # False
+ """
+
+ def __new__(cls, string):
+ """
+ Create a new instance of IntegerListOrTupleString based on the given string representation.
+
+ Parameters
+ ----------
+ string : str
+ The string representation of the integer list or tuple.
+
+ Returns
+ -------
+ list or tuple
+ A new instance of IntegerListOrTupleString.
+ """
+ list_or_tuple = list if string.startswith("[") else tuple
+ string = string.strip()[1:-1] # Remove outer brackets
+ elements = string.split(",")
+ integers = [int(element) for element in elements]
+ return list_or_tuple(integers)
diff --git a/direct/utils/writers.py b/direct/utils/writers.py
index d5720c10..d1fcd42f 100644
--- a/direct/utils/writers.py
+++ b/direct/utils/writers.py
@@ -48,13 +48,22 @@ def write_output_to_h5(
with open(output_directory / "metrics_inference.json", "w") as f:
f.write(json.dumps(metrics, indent=4))
- for idx, (volume, sampling_mask, filename) in enumerate(output[0]):
+ for idx, (data, sampling_mask, filename) in enumerate(output[0]):
if isinstance(filename, pathlib.PosixPath):
filename = filename.name
logger.info(f"({idx + 1}/{len(output[0])}): Writing {output_directory / filename}...")
+ if isinstance(data, tuple):
+ volume, registration_volume, displacement_field = data
+ else:
+ volume = data
+            registration_volume = None
+            displacement_field = None
+
reconstruction = volume.numpy()[:, 0, ...].astype(np.float32)
+ if registration_volume is not None:
+ registration_volume = registration_volume.numpy()[:, 0, ...].astype(np.float32)
+ displacement_field = displacement_field.numpy().astype(np.float32)
if sampling_mask is not None:
sampling_mask = sampling_mask.numpy()[:, 0, ...].astype(np.float32)
@@ -66,3 +75,6 @@ def write_output_to_h5(
f.create_dataset(output_key, data=reconstruction)
if sampling_mask is not None:
f.create_dataset("sampling_mask", data=sampling_mask)
+ if registration_volume is not None:
+ f.create_dataset("registration_volume", data=registration_volume)
+ f.create_dataset("displacement_field", data=displacement_field)
\ No newline at end of file