diff --git a/direct/common/subsample_config.py b/direct/common/subsample_config.py index 89cccbb3..2640c6d4 100644 --- a/direct/common/subsample_config.py +++ b/direct/common/subsample_config.py @@ -1,23 +1,23 @@ -# Copyright (c) DIRECT Contributors - -from __future__ import annotations - -from dataclasses import dataclass -from typing import Optional - -from omegaconf import MISSING - -from direct.config.defaults import BaseConfig -from direct.types import MaskFuncMode - - -@dataclass -class MaskingConfig(BaseConfig): - name: str = MISSING - accelerations: tuple[float, ...] = (5.0,) - center_fractions: Optional[tuple[float, ...]] = (0.1,) - uniform_range: bool = False - mode: MaskFuncMode = MaskFuncMode.STATIC - - val_accelerations: tuple[float, ...] = (5.0, 10.0) - val_center_fractions: Optional[tuple[float, ...]] = (0.1, 0.05) +# Copyright (c) DIRECT Contributors + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional + +from omegaconf import MISSING + +from direct.config import BaseConfig +from direct.types import MaskFuncMode + + +@dataclass +class MaskingConfig(BaseConfig): + name: str = MISSING + accelerations: tuple[float, ...] = (5.0,) + center_fractions: Optional[tuple[float, ...]] = (0.1,) + uniform_range: bool = False + mode: MaskFuncMode = MaskFuncMode.STATIC + + val_accelerations: tuple[float, ...] = (5.0, 10.0) + val_center_fractions: Optional[tuple[float, ...]] = (0.1, 0.05) diff --git a/direct/data/mri_transforms.py b/direct/data/mri_transforms.py index ee915dec..68dc4487 100644 --- a/direct/data/mri_transforms.py +++ b/direct/data/mri_transforms.py @@ -1,2793 +1,2795 @@ -# Copyright (c) DIRECT Contributors - -"""The `direct.data.mri_transforms` module contains mri transformations utilized to transform or augment k-space data, -used for DIRECT's training pipeline. They can be also used individually by importing them into python scripts.""" - -from __future__ import annotations - -import contextlib -import copy -import functools -import logging -import random -import warnings -from typing import Any, Callable, Dict, Iterable, Optional, Sequence, Union - -import numpy as np -import torch - -from direct.algorithms.mri_algorithms import EspiritCalibration -from direct.data import transforms as T -from direct.exceptions import ItemNotFoundException -from direct.registration.elastic_deformation import RandomElasticDeformationModule -from direct.registration.registration import DemonsFilterType, DisplacementModule, DisplacementTransformType -from direct.ssl.ssl import ( - GaussianMaskSplitterModule, - HalfMaskSplitterModule, - HalfSplitType, - MaskSplitterType, - SSLTransformMaskPrefixes, - UniformMaskSplitterModule, -) -from direct.types import DirectEnum, IntegerListOrTupleString, KspaceKey, TransformKey -from direct.utils import DirectModule, DirectTransform -from direct.utils.asserts import assert_complex - -logger = logging.getLogger(__name__) - - -@contextlib.contextmanager -def temp_seed(rng, seed): - state = rng.get_state() - rng.seed(seed) - try: - yield - finally: - rng.set_state(state) - - -class Compose(DirectTransform): - """Compose several transformations together, for instance ClipAndScale and a flip. - - Code based on torchvision: https://github.com/pytorch/vision, but got forked from there as torchvision has some - additional dependencies. - """ - - def __init__(self, transforms: Iterable[Callable]) -> None: - """Inits :class:`Compose`. 
- - Parameters - ---------- - transforms: Iterable[Callable] - List of transforms. - """ - super().__init__() - self.transforms = transforms - - def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: - """Calls :class:`Compose`. - - Parameters - ---------- - sample: dict[str, Any] - Dict sample. - - Returns - ------- - dict[str, Any] - Dict sample transformed by `transforms`. - """ - for transform in self.transforms: - sample = transform(sample) - - return sample - - def __repr__(self): - """Representation of :class:`Compose`.""" - repr_string = self.__class__.__name__ + "(" - for transform in self.transforms: - repr_string += "\n" - repr_string += f" {transform}," - repr_string = repr_string[:-1] + "\n)" - return repr_string - - -class RandomRotation(DirectTransform): - r"""Random :math:`k`-space rotation. - - Performs a random rotation with probability :math:`p`. Rotation degrees must be multiples of 90. - """ - - def __init__( - self, - degrees: Sequence[int] = (-90, 90), - p: float = 0.5, - keys_to_rotate: tuple[TransformKey, ...] = (TransformKey.KSPACE,), - ) -> None: - r"""Inits :class:`RandomRotation`. - - Parameters - ---------- - degrees: sequence of ints - Degrees of rotation. Must be a multiple of 90. If len(degrees) > 1, then a degree will be chosen at random. - Default: (-90, 90). - p: float - Probability of rotation. Default: 0.5 - keys_to_rotate : tuple of TransformKeys - Keys to rotate. Default: "kspace". - """ - super().__init__() - - assert all(degree % 90 == 0 for degree in degrees) - - self.degrees = degrees - self.p = p - self.keys_to_rotate = keys_to_rotate - - def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: - """Calls :class:`RandomRotation`. - - Parameters - ---------- - sample: dict[str, Any] - Dict sample. - - Returns - ------- - dict[str, Any] - Sample with rotated values of `keys_to_rotate`. - """ - if random.SystemRandom().random() <= self.p: - degree = random.SystemRandom().choice(self.degrees) - k = degree // 90 - for key in self.keys_to_rotate: - if key in sample: - value = T.view_as_complex(sample[key].clone()) - sample[key] = T.view_as_real(torch.rot90(value, k=k, dims=(-2, -1))) - - # If rotated by multiples of (n + 1) * 90 degrees, reconstruction size also needs to change - reconstruction_size = sample.get("reconstruction_size", None) - if reconstruction_size and (k % 2) == 1: - sample["reconstruction_size"] = ( - reconstruction_size[:-3] + reconstruction_size[-3:-1][::-1] + reconstruction_size[-1:] - ) - - return sample - - -class RandomFlipType(DirectEnum): - HORIZONTAL = "horizontal" - VERTICAL = "vertical" - RANDOM = "random" - BOTH = "both" - - -class RandomFlip(DirectTransform): - r"""Random k-space flip transform. - - Performs a random flip with probability :math:`p`. Flip can be horizontal, vertical, or a random choice of the two. - """ - - def __init__( - self, - flip: RandomFlipType = RandomFlipType.RANDOM, - p: float = 0.5, - keys_to_flip: tuple[TransformKey, ...] = (TransformKey.KSPACE,), - ) -> None: - r"""Inits :class:`RandomFlip`. - - Parameters - ---------- - flip : RandomFlipType - Horizontal, vertical, or random choice of the two. Default: RandomFlipType.RANDOM. - p : float - Probability of flip. Default: 0.5 - keys_to_flip : tuple of TransformKeys - Keys to flip. Default: "kspace". - """ - super().__init__() - - self.flip = flip - self.p = p - self.keys_to_flip = keys_to_flip - - def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: - """Calls :class:`RandomFlip`. 
-
-        Parameters
-        ----------
-        sample: dict[str, Any]
-            Dict sample.
-
-        Returns
-        -------
-        dict[str, Any]
-            Sample with flipped values of `keys_to_flip`.
-        """
-        if random.SystemRandom().random() <= self.p:
-            dims = (
-                (-2,)
-                if self.flip == "horizontal"
-                else (
-                    (-1,)
-                    if self.flip == "vertical"
-                    else (-2, -1) if self.flip == "both" else (random.SystemRandom().choice([-2, -1]),)
-                )
-            )
-
-            for key in self.keys_to_flip:
-                if key in sample:
-                    value = T.view_as_complex(sample[key].clone())
-                    value = torch.flip(value, dims=dims)
-                    sample[key] = T.view_as_real(value)
-
-        return sample
-
-
-class RandomReverse(DirectTransform):
-    r"""Random reverse of the order along a given dimension of a PyTorch tensor."""
-
-    def __init__(
-        self,
-        dim: int = 1,
-        p: float = 0.5,
-        keys_to_reverse: tuple[TransformKey, ...] = (TransformKey.KSPACE,),
-    ) -> None:
-        r"""Inits :class:`RandomReverse`.
-
-        Parameters
-        ----------
-        dim : int
-            Dimension along which to reverse the order. Typically this is the time or slice dimension. Default: 1.
-        p : float
-            Probability of reversal. Default: 0.5
-        keys_to_reverse : tuple of TransformKeys
-            Keys to reverse. Default: "kspace".
-        """
-        super().__init__()
-
-        self.dim = dim
-        self.p = p
-        self.keys_to_reverse = keys_to_reverse
-
-    def __call__(self, sample: dict[str, Any]) -> dict[str, Any]:
-        """Calls :class:`RandomReverse`.
-
-        Parameters
-        ----------
-        sample: dict[str, Any]
-            Dict sample.
-
-        Returns
-        -------
-        dict[str, Any]
-            Sample with reversed values of `keys_to_reverse`.
-        """
-        if random.SystemRandom().random() <= self.p:
-            dim = self.dim
-            for key in self.keys_to_reverse:
-                if key in sample:
-                    tensor = sample[key].clone()
-
-                    if dim < 0:
-                        dim += tensor.dim()
-
-                    tensor = T.view_as_complex(tensor)
-
-                    index = [slice(None)] * tensor.dim()
-                    index[dim] = torch.arange(tensor.size(dim) - 1, -1, -1, dtype=torch.long)
-
-                    tensor = tensor[tuple(index)]
-
-                    sample[key] = T.view_as_real(tensor)
-
-        return sample
-
-
-class CreateSamplingMask(DirectTransform):
-    """Data Transformer for training MRI reconstruction models.
-
-    Creates sampling mask.
-    """
-
-    def __init__(
-        self,
-        mask_func: Callable,
-        shape: Optional[tuple[int, ...]] = None,
-        use_seed: bool = True,
-        return_acs: bool = False,
-    ) -> None:
-        """Inits :class:`CreateSamplingMask`.
-
-        Parameters
-        ----------
-        mask_func: Callable
-            A function which creates a sampling mask of the appropriate shape.
-        shape: tuple, optional
-            Sampling mask shape. Default: None.
-        use_seed: bool
-            If true, a pseudo-random number based on the filename is computed so that every slice of the volume gets
-            the same mask every time. Default: True.
-        return_acs: bool
-            If True, it will generate an ACS mask. Default: False.
-        """
-        super().__init__()
-        self.mask_func = mask_func
-        self.shape = shape
-        self.use_seed = use_seed
-        self.return_acs = return_acs
-
-    def __call__(self, sample: dict[str, Any]) -> dict[str, Any]:
-        """Calls :class:`CreateSamplingMask`.
-
-        Parameters
-        ----------
-        sample: dict[str, Any]
-            Dict sample.
-
-        Returns
-        -------
-        dict[str, Any]
-            Sample with `sampling_mask` key.
-        """
-        if not self.shape:
-            shape = sample["kspace"].shape[1:]
-        elif any(_ is None for _ in self.shape):  # Allow None as values.
- kspace_shape = list(sample["kspace"].shape[1:-1]) - shape = tuple(_ if _ else kspace_shape[idx] for idx, _ in enumerate(self.shape)) + (2,) - else: - shape = self.shape + (2,) - - seed = None if not self.use_seed else tuple(map(ord, str(sample["filename"]))) - - sampling_mask = self.mask_func(shape=shape, seed=seed, return_acs=False) - - if sampling_mask.ndim == 5: - acceleration = [ - np.prod(sampling_mask[0, _].shape) / sampling_mask[0, _].sum() for _ in range(sampling_mask.shape[1]) - ] - sample["acceleration"] = torch.tensor(acceleration, dtype=torch.float32).unsqueeze(0) - else: - sample["acceleration"] = (np.prod(sampling_mask.shape) / sampling_mask.sum()).unsqueeze(0) - - if "padding" in sample: - sampling_mask = T.apply_padding(sampling_mask, sample["padding"]) - - # Shape 3D: (1, 1, height, width, 1), 2D: (1, height, width, 1) - sample["sampling_mask"] = sampling_mask - - if self.return_acs: - sample["acs_mask"] = self.mask_func(shape=shape, seed=seed, return_acs=True) - if sampling_mask.ndim == 5: - center_fraction = [ - sample["acs_mask"][0, _].sum() / np.prod(sample["acs_mask"][0, _].shape) - for _ in range(sample["acs_mask"].shape[1]) - ] - sample["center_fraction"] = torch.tensor(center_fraction, dtype=torch.float32).unsqueeze(0) - else: - sample["center_fraction"] = (sample["acs_mask"].sum() / np.prod(sample["acs_mask"].shape)).unsqueeze(0) - return sample - - -class ApplyMaskModule(DirectModule): - """Data Transformer for training MRI reconstruction models. - - Masks the input k-space (with key `input_kspace_key`) using a sampling mask with key `sampling_mask_key` onto - a new masked k-space with key `target_kspace_key`. - """ - - def __init__( - self, - sampling_mask_key: str = "sampling_mask", - input_kspace_key: KspaceKey = KspaceKey.KSPACE, - target_kspace_key: KspaceKey = KspaceKey.MASKED_KSPACE, - ) -> None: - """Inits :class:`ApplyMaskModule`. - - Parameters - ---------- - sampling_mask_key: str - Default: "sampling_mask". - input_kspace_key: KspaceKey - Default: KspaceKey.KSPACE. - target_kspace_key: KspaceKey - Default KspaceKey.MASKED_KSPACE. - """ - super().__init__() - self.logger = logging.getLogger(type(self).__name__) - - self.sampling_mask_key = sampling_mask_key - self.input_kspace_key = input_kspace_key - self.target_kspace_key = target_kspace_key - - def forward(self, sample: dict[str, Any]) -> dict[str, Any]: - """Forward pass of :class:`ApplyMaskModule`. - - Applies mask with key `sampling_mask_key` onto kspace `input_kspace_key`. Result is stored as a tensor with - key `target_kspace_key`. - - Parameters - ---------- - sample: dict[str, Any] - Dict sample containing keys `sampling_mask_key` and `input_kspace_key`. - - Returns - ------- - dict[str, Any] - Sample with (new) key `target_kspace_key`. - """ - if self.input_kspace_key not in sample: - raise ValueError(f"Key {self.input_kspace_key} corresponding to `input_kspace_key` not found in sample.") - input_kspace = sample[self.input_kspace_key] - - if self.sampling_mask_key not in sample: - raise ValueError(f"Key {self.sampling_mask_key} corresponding to `sampling_mask_key` not found in sample.") - sampling_mask = sample[self.sampling_mask_key] - - target_kspace, _ = T.apply_mask(input_kspace, sampling_mask) - sample[self.target_kspace_key] = target_kspace - return sample - - -class CropKspace(DirectTransform): - """Data Transformer for training MRI reconstruction models. 
- - Crops the k-space by: - * It first projects the k-space to the image-domain via the backward operator, - * It crops the back-projected k-space to specified shape or key, - * It transforms the cropped back-projected k-space to the k-space domain via the forward operator. - """ - - def __init__( - self, - crop: Union[str, tuple[int, ...], list[int]], - forward_operator: Callable = T.fft2, - backward_operator: Callable = T.ifft2, - image_space_center_crop: bool = False, - random_crop_sampler_type: Optional[str] = "uniform", - random_crop_sampler_use_seed: Optional[bool] = True, - random_crop_sampler_gaussian_sigma: Optional[list[float]] = None, - ) -> None: - """Inits :class:`CropKspace`. - - Parameters - ---------- - crop: tuple of ints or str - Shape to crop the input to or a string pointing to a crop key (e.g. `reconstruction_size`). - forward_operator: Callable - The forward operator, e.g. some form of FFT (centered or uncentered). - Default: :class:`direct.data.transforms.fft2`. - backward_operator: Callable - The backward operator, e.g. some form of inverse FFT (centered or uncentered). - Default: :class:`direct.data.transforms.ifft2`. - image_space_center_crop: bool - If set, the crop in the data will be taken in the center - random_crop_sampler_type: Optional[str] - If "uniform" the random cropping will be done by uniformly sampling `crop`, as opposed to `gaussian` which - will sample from a gaussian distribution. If `image_space_center_crop` is True, then this is ignored. - Default: "uniform". - random_crop_sampler_use_seed: bool - If true, a pseudo-random number based on the filename is computed so that every slice of the volume - is cropped the same way. Default: True. - random_crop_sampler_gaussian_sigma: Optional[list[float]] - Standard variance of the gaussian when `random_crop_sampler_type` is `gaussian`. - If `image_space_center_crop` is True, then this is ignored. Default: None. - """ - super().__init__() - self.logger = logging.getLogger(type(self).__name__) - - self.image_space_center_crop = image_space_center_crop - - if not (isinstance(crop, (Iterable, str))): - raise ValueError( - f"Invalid input for `crop`. Received {crop}. Can be a list of tuple of integers or a string." - ) - self.crop = crop - - if image_space_center_crop: - self.crop_func = T.complex_center_crop - else: - self.crop_func = functools.partial( - T.complex_random_crop, - sampler=random_crop_sampler_type, - sigma=random_crop_sampler_gaussian_sigma, - ) - self.random_crop_sampler_use_seed = random_crop_sampler_use_seed - - self.forward_operator = forward_operator - self.backward_operator = backward_operator - - def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: - """Calls :class:`CropKspace`. - - Parameters - ---------- - sample: dict[str, Any] - Dict sample containing key `kspace`. - - Returns - ------- - dict[str, Any] - Cropped and masked sample. - """ - - kspace = sample["kspace"] # shape (coil, [slice/time], height, width, complex=2) - - dim = self.spatial_dims.TWO_D if kspace.ndim == 4 else self.spatial_dims.THREE_D - - backprojected_kspace = self.backward_operator(kspace, dim=dim) # shape (coil, height, width, complex=2) - - if isinstance(self.crop, IntegerListOrTupleString): - crop_shape = IntegerListOrTupleString(self.crop) - elif isinstance(self.crop, str): - assert self.crop in sample, f"Not found {self.crop} key in sample." 
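-            # Illustration (hypothetical key and shape, not taken from the codebase): with
-            # crop == "reconstruction_size" and sample["reconstruction_size"] == (320, 320, 1),
-            # dropping the trailing element below yields an image-domain crop shape of (320, 320).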
- crop_shape = sample[self.crop][:-1] - else: - if kspace.ndim == 5 and len(self.crop) == 2: - crop_shape = (kspace.shape[1],) + tuple(self.crop) - else: - crop_shape = tuple(self.crop) - - cropper_args = { - "data_list": [backprojected_kspace], - "crop_shape": crop_shape, - "contiguous": False, - } - if not self.image_space_center_crop: - cropper_args["seed"] = ( - None if not self.random_crop_sampler_use_seed else tuple(map(ord, str(sample["filename"]))) - ) - cropped_backprojected_kspace = self.crop_func(**cropper_args) - - if "sampling_mask" in sample: - sample["sampling_mask"] = T.complex_center_crop( - sample["sampling_mask"], (1,) + tuple(crop_shape)[1:] if kspace.ndim == 5 else crop_shape - ) - sample["acs_mask"] = T.complex_center_crop( - sample["acs_mask"], (1,) + tuple(crop_shape)[1:] if kspace.ndim == 5 else crop_shape - ) - - # Compute new k-space for the cropped_backprojected_kspace - # shape (coil, [slice/time], new_height, new_width, complex=2) - sample["kspace"] = self.forward_operator(cropped_backprojected_kspace, dim=dim) # The cropped kspace - - return sample - - -class RescaleMode(DirectEnum): - AREA = "area" - BICUBIC = "bicubic" - BILINEAR = "bilinear" - NEAREST = "nearest" - NEAREST_EXACT = "nearest-exact" - TRILINEAR = "trilinear" - - -class RescaleKspace(DirectTransform): - """Rescale k-space (downsample/upsample) module. - - Rescales the k-space: - * It first projects the k-space to the image-domain via the backward operator, - * It rescales the back-projected k-space to specified shape, - * It transforms the rescaled back-projected k-space to the k-space domain via the forward operator. - - Parameters - ---------- - shape : tuple or list of ints - Shape to rescale the input. Must be correspond to (height, width). - forward_operator : Callable - The forward operator, e.g. some form of FFT (centered or uncentered). - Default: :class:`direct.data.transforms.fft2`. - backward_operator : Callable - The backward operator, e.g. some form of inverse FFT (centered or uncentered). - Default: :class:`direct.data.transforms.ifft2`. - rescale_mode : RescaleMode - Mode to be used for rescaling. Can be RescaleMode.AREA, RescaleMode.BICUBIC, RescaleMode.BILINEAR, - RescaleMode.NEAREST, RescaleMode.NEAREST_EXACT, or RescaleMode.TRILINEAR. Note that not all modes are - supported for 2D or 3D data. Default: RescaleMode.NEAREST. - kspace_key : KspaceKey - K-space key. Default: KspaceKey.KSPACE. - rescale_2d_if_3d : bool, optional - If True and input k-space data is 3D, rescaling will be done only on the height and width dimensions. - Default: False. - - Note - ---- - If the input k-space data is 3D, rescaling will be done only on the height and width dimensions if - `rescale_2d_if_3d` is set to True. - """ - - def __init__( - self, - shape: Union[tuple[int, int], list[int]], - forward_operator: Callable = T.fft2, - backward_operator: Callable = T.ifft2, - rescale_mode: RescaleMode = RescaleMode.NEAREST, - kspace_key: KspaceKey = KspaceKey.KSPACE, - rescale_2d_if_3d: Optional[bool] = None, - ) -> None: - """Inits :class:`RescaleKspace`. - - Parameters - ---------- - shape : tuple or list of ints - Shape to rescale the input. Must be correspond to (height, width). - forward_operator : Callable - The forward operator, e.g. some form of FFT (centered or uncentered). - Default: :class:`direct.data.transforms.fft2`. - backward_operator : Callable - The backward operator, e.g. some form of inverse FFT (centered or uncentered). - Default: :class:`direct.data.transforms.ifft2`. 
-        rescale_mode : RescaleMode
-            Mode to be used for rescaling. Can be RescaleMode.AREA, RescaleMode.BICUBIC, RescaleMode.BILINEAR,
-            RescaleMode.NEAREST, RescaleMode.NEAREST_EXACT, or RescaleMode.TRILINEAR. Note that not all modes are
-            supported for 2D or 3D data. Default: RescaleMode.NEAREST.
-        kspace_key : KspaceKey
-            K-space key. Default: KspaceKey.KSPACE.
-        rescale_2d_if_3d : bool, optional
-            If True and input k-space data is 3D, rescaling will be done only on the height and width dimensions,
-            by combining the slice/time dimension with the batch dimension.
-            Default: None (treated as False).
-        """
-        super().__init__()
-        self.logger = logging.getLogger(type(self).__name__)
-
-        if len(shape) not in [2, 3]:
-            raise ValueError(
-                f"Shape should be a list or tuple of two integers if 2D or three integers if 3D. "
-                f"Received: {shape}."
-            )
-        self.shape = shape
-        self.forward_operator = forward_operator
-        self.backward_operator = backward_operator
-        self.rescale_mode = rescale_mode
-        self.kspace_key = kspace_key
-
-        self.rescale_2d_if_3d = rescale_2d_if_3d
-        if rescale_2d_if_3d and len(shape) == 3:
-            raise ValueError("Shape cannot have a length of 3 when rescale_2d_if_3d is set to True.")
-
-    def __call__(self, sample: dict[str, Any]) -> dict[str, Any]:
-        """Calls :class:`RescaleKspace`.
-
-        Parameters
-        ----------
-        sample: dict[str, Any]
-            Dict sample containing key `kspace`.
-
-        Returns
-        -------
-        dict[str, Any]
-            Sample with rescaled k-space.
-        """
-        kspace = sample[self.kspace_key]  # shape (coil, [slice/time], height, width, complex=2)
-
-        dim = self.spatial_dims.TWO_D if kspace.ndim == 4 else self.spatial_dims.THREE_D
-
-        backprojected_kspace = self.backward_operator(kspace, dim=dim)
-
-        if kspace.ndim == 5 and self.rescale_2d_if_3d:
-            backprojected_kspace = backprojected_kspace.permute(1, 0, 2, 3, 4)
-
-        if (kspace.ndim == 4) or (kspace.ndim == 5 and not self.rescale_2d_if_3d):
-            backprojected_kspace = backprojected_kspace.unsqueeze(0)
-
-        rescaled_backprojected_kspace = T.complex_image_resize(backprojected_kspace, self.shape, self.rescale_mode)
-
-        if (kspace.ndim == 4) or (kspace.ndim == 5 and not self.rescale_2d_if_3d):
-            rescaled_backprojected_kspace = rescaled_backprojected_kspace.squeeze(0)
-
-        if kspace.ndim == 5 and self.rescale_2d_if_3d:
-            rescaled_backprojected_kspace = rescaled_backprojected_kspace.permute(1, 0, 2, 3, 4)
-
-        # Compute new k-space from rescaled_backprojected_kspace
-        # shape (coil, [slice/time if rescale_2d_if_3d else new_slc_or_time], new_height, new_width, complex=2)
-        sample[self.kspace_key] = self.forward_operator(rescaled_backprojected_kspace, dim=dim)  # The rescaled kspace
-
-        return sample
-
-
-class PadKspace(DirectTransform):
-    """Pad k-space with zeros to desired shape module.
-
-    Pads the k-space by:
-    * It first projects the k-space to the image-domain via the backward operator,
-    * It pads the back-projected k-space to the specified shape,
-    * It transforms the padded back-projected k-space to the k-space domain via the forward operator.
-
-    Parameters
-    ----------
-    pad_shape : tuple or list of ints
-        Shape to zero-pad the input. Must correspond to (height, width) or (slice/time, height, width).
-    forward_operator : Callable
-        The forward operator, e.g. some form of FFT (centered or uncentered).
-        Default: :class:`direct.data.transforms.fft2`.
-    backward_operator : Callable
-        The backward operator, e.g. some form of inverse FFT (centered or uncentered).
-        Default: :class:`direct.data.transforms.ifft2`.
-    kspace_key : KspaceKey
-        K-space key. Default: KspaceKey.KSPACE.
-    """
-
-    def __init__(
-        self,
-        pad_shape: Union[tuple[int, ...], list[int]],
-        forward_operator: Callable = T.fft2,
-        backward_operator: Callable = T.ifft2,
-        kspace_key: KspaceKey = KspaceKey.KSPACE,
-    ) -> None:
-        """Inits :class:`PadKspace`.
-
-        Parameters
-        ----------
-        pad_shape : tuple or list of ints
-            Shape to zero-pad the input. Must correspond to (height, width) or (slice/time, height, width).
-        forward_operator : Callable
-            The forward operator, e.g. some form of FFT (centered or uncentered).
-            Default: :class:`direct.data.transforms.fft2`.
-        backward_operator : Callable
-            The backward operator, e.g. some form of inverse FFT (centered or uncentered).
-            Default: :class:`direct.data.transforms.ifft2`.
-        kspace_key : KspaceKey
-            K-space key. Default: KspaceKey.KSPACE.
-        """
-        super().__init__()
-        self.logger = logging.getLogger(type(self).__name__)
-
-        if len(pad_shape) not in [2, 3]:
-            raise ValueError(f"Shape should be a list or tuple of two or three integers. Received: {pad_shape}.")
-
-        self.shape = pad_shape
-        self.forward_operator = forward_operator
-        self.backward_operator = backward_operator
-        self.kspace_key = kspace_key
-
-    def __call__(self, sample: dict[str, Any]) -> dict[str, Any]:
-        """Calls :class:`PadKspace`.
-
-        Parameters
-        ----------
-        sample: dict[str, Any]
-            Dict sample containing key `kspace`.
-
-        Returns
-        -------
-        dict[str, Any]
-            Sample with padded k-space.
-        """
-        kspace = sample[self.kspace_key]  # shape (coil, [slice or time], height, width, complex=2)
-        shape = kspace.shape
-
-        sample["original_size"] = shape[1:-1]
-
-        dim = self.spatial_dims.TWO_D if kspace.ndim == 4 else self.spatial_dims.THREE_D
-
-        backprojected_kspace = self.backward_operator(kspace, dim=dim)
-        backprojected_kspace = T.view_as_complex(backprojected_kspace)
-
-        padded_backprojected_kspace = T.pad_tensor(backprojected_kspace, self.shape)
-        padded_backprojected_kspace = T.view_as_real(padded_backprojected_kspace)
-
-        # shape (coil, [slice or time], height, width, complex=2)
-        sample[self.kspace_key] = self.forward_operator(padded_backprojected_kspace, dim=dim)  # The padded kspace
-
-        return sample
-
-
-class ComputeZeroPadding(DirectTransform):
-    r"""Computes zero padding present in multi-coil kspace input.
-
-    Zero-padding is computed from multi-coil kspace with no signal contribution, i.e. its magnitude
-    is really close to zero:
-
-    .. math ::
-
-        \text{padding} = \sum_{i=1}^{n_c} |y_i| < \frac{1}{n_x \cdot n_y}
-        \sum_{j=1}^{n_x \cdot n_y} \big\{\sum_{i=1}^{n_c} |y_i|\big\} * \epsilon.
-    """
-
-    def __init__(
-        self,
-        kspace_key: KspaceKey = KspaceKey.KSPACE,
-        padding_key: str = "padding",
-        eps: Optional[float] = 0.0001,
-    ) -> None:
-        """Inits :class:`ComputeZeroPadding`.
-
-        Parameters
-        ----------
-        kspace_key: KspaceKey
-            K-space key. Default: KspaceKey.KSPACE.
-        padding_key: str
-            Target key. Default: "padding".
-        eps: float
-            Epsilon to multiply sum of signals. If really high, probably no padding will be produced. Default: 0.0001.
-        """
-        super().__init__()
-        self.kspace_key = kspace_key
-        self.padding_key = padding_key
-        self.eps = eps
-
-    def __call__(self, sample: dict[str, Any], coil_dim: int = 0) -> dict[str, Any]:
-        """Updates sample with a key `padding_key` with value a binary tensor.
-
-        Non-zero entries indicate samples in kspace with key `kspace_key` which have minor contribution, i.e. padding.
-
-        Parameters
-        ----------
-        sample : dict[str, Any]
-            Dict sample containing key `kspace_key`.
-        coil_dim : int
-            Coil dimension.
Default: 0. - - Returns - ------- - sample : dict[str, Any] - Dict sample containing key `padding_key`. - """ - if self.eps is None: - return sample - shape = sample[self.kspace_key].shape - - kspace = T.modulus(sample[self.kspace_key].clone()).sum(coil_dim) - - if len(shape) == 5: # Check if 3D data - # Assumes that slice dim is 0 - kspace = kspace.sum(0) - - padding = (kspace < (torch.mean(kspace) * self.eps)).to(kspace.device) - - if len(shape) == 5: - padding = padding.unsqueeze(0) - - padding = padding.unsqueeze(coil_dim).unsqueeze(-1) - sample[self.padding_key] = padding - - return sample - - -class ApplyZeroPadding(DirectTransform): - """Applies zero padding present in multi-coil kspace input.""" - - def __init__(self, kspace_key: KspaceKey = KspaceKey.KSPACE, padding_key: str = "padding") -> None: - """Inits :class:`ApplyZeroPadding`. - - Parameters - ---------- - kspace_key: KspaceKey - K-space key. Default: KspaceKey.KSPACE. - padding_key: str - Target key. Default: "padding". - """ - super().__init__() - self.kspace_key = kspace_key - self.padding_key = padding_key - - def __call__(self, sample: dict[str, Any], coil_dim: int = 0) -> dict[str, Any]: - """Applies zero padding on `kspace_key` with value a binary tensor. - - Parameters - ---------- - sample : dict[str, Any] - Dict sample containing key `kspace_key`. - coil_dim : int - Coil dimension. Default: 0. - - Returns - ------- - sample : dict[str, Any] - Dict sample containing key `padding_key`. - """ - - sample[self.kspace_key] = T.apply_padding(sample[self.kspace_key], sample[self.padding_key]) - - return sample - - -class ReconstructionType(DirectEnum): - """Reconstruction method for :class:`ComputeImage` transform.""" - - IFFT = "ifft" - RSS = "rss" - COMPLEX = "complex" - COMPLEX_MOD = "complex_mod" - SENSE = "sense" - SENSE_MOD = "sense_mod" - - -class ComputeImageModule(DirectModule): - """Compute Image transform.""" - - def __init__( - self, - kspace_key: KspaceKey, - target_key: str, - backward_operator: Callable, - type_reconstruction: ReconstructionType = ReconstructionType.RSS, - ) -> None: - """Inits :class:`ComputeImageModule`. - - Parameters - ---------- - kspace_key: KspaceKey - K-space key. - target_key: str - Target key. - backward_operator: callable - The backward operator, e.g. some form of inverse FFT (centered or uncentered). - type_reconstruction: ReconstructionType - Type of reconstruction. Can be ReconstructionType.RSS, ReconstructionType.COMPLEX, - ReconstructionType.COMPLEX_MOD, ReconstructionType.SENSE, ReconstructionType.SENSE_MOD or - ReconstructionType.IFFT. Default: ReconstructionType.RSS. - """ - super().__init__() - self.backward_operator = backward_operator - self.kspace_key = kspace_key - self.target_key = target_key - self.type_reconstruction = type_reconstruction - - def forward(self, sample: dict[str, Any]) -> dict[str, Any]: - """Forward pass of :class:`ComputeImageModule`. - - Parameters - ---------- - sample: dict[str, Any] - Contains key kspace_key with value a torch.Tensor of shape (coil,\*spatial_dims, complex=2). - - Returns - ------- - sample: dict - Contains key target_key with value a torch.Tensor of shape (\*spatial_dims) if `type_reconstruction` is - ReconstructionType.RSS, ReconstructionType.COMPLEX_MOD, ReconstructionType.SENSE_MOD, - and of shape (\*spatial_dims, complex_dim=2) otherwise. 
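-
-        Example
-        -------
-        A minimal sketch; the keys, shapes and RSS choice below are illustrative assumptions, not requirements:
-
-        >>> import torch
-        >>> from direct.data import transforms as T
-        >>> from direct.types import KspaceKey
-        >>> module = ComputeImageModule(
-        ...     kspace_key=KspaceKey.MASKED_KSPACE,
-        ...     target_key="target",
-        ...     backward_operator=T.ifft2,
-        ...     type_reconstruction=ReconstructionType.RSS,
-        ... )
-        >>> sample = {"masked_kspace": torch.randn(1, 8, 320, 320, 2)}  # (batch, coil, height, width, complex)
-        >>> sample = module(sample)  # adds sample["target"] with the RSS image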
- """ - kspace_data = sample[self.kspace_key] - dim = self.spatial_dims.TWO_D if kspace_data.ndim == 5 else self.spatial_dims.THREE_D - # Get complex-valued data solution - image = self.backward_operator(kspace_data, dim=dim) - if self.type_reconstruction == ReconstructionType.IFFT: - sample[self.target_key] = image - elif self.type_reconstruction in [ - ReconstructionType.COMPLEX, - ReconstructionType.COMPLEX_MOD, - ]: - sample[self.target_key] = image.sum(self.coil_dim) - elif self.type_reconstruction == ReconstructionType.RSS: - sample[self.target_key] = T.root_sum_of_squares(image, dim=self.coil_dim) - else: - if "sensitivity_map" not in sample: - raise ItemNotFoundException( - "sensitivity map", - "Sensitivity map is required for SENSE reconstruction.", - ) - sample[self.target_key] = T.complex_multiplication(T.conjugate(sample["sensitivity_map"]), image).sum( - self.coil_dim - ) - if self.type_reconstruction in [ - ReconstructionType.COMPLEX_MOD, - ReconstructionType.SENSE_MOD, - ]: - sample[self.target_key] = T.modulus(sample[self.target_key], self.complex_dim) - return sample - - -class EstimateBodyCoilImage(DirectTransform): - """Estimates body coil image.""" - - def __init__(self, mask_func: Callable, backward_operator: Callable, use_seed: bool = True) -> None: - """Inits :class:`EstimateBodyCoilImage'. - - Parameters - ---------- - mask_func: Callable - A function which creates a sampling mask of the appropriate shape. - backward_operator: callable - The backward operator, e.g. some form of inverse FFT (centered or uncentered). - use_seed: bool - If true, a pseudo-random number based on the filename is computed so that every slice of the volume get - the same mask every time. Default: True. - """ - super().__init__() - self.mask_func = mask_func - self.use_seed = use_seed - self.backward_operator = backward_operator - - def __call__(self, sample: dict[str, Any], coil_dim: int = 0) -> dict[str, Any]: - """Calls :class:`EstimateBodyCoilImage`. - - Parameters - ---------- - sample: dict[str, Any] - Contains key kspace_key with value a torch.Tensor of shape (coil, ..., complex=2). - coil_dim: int - Coil dimension. Default: 0. - - Returns - ---------- - sample: dict[str, Any] - Contains key `"body_coil_image`. - """ - kspace = sample["kspace"] - - # We need to create an ACS mask based on the shape of this kspace, as it can be cropped. - seed = None if not self.use_seed else tuple(map(ord, str(sample["filename"]))) - kspace_shape = tuple(sample["kspace"].shape[-3:]) - acs_mask = self.mask_func(shape=kspace_shape, seed=seed, return_acs=True) - - kspace = acs_mask * kspace + 0.0 - dim = self.spatial_dims.TWO_D if kspace.ndim == 4 else self.spatial_dims.THREE_D - acs_image = self.backward_operator(kspace, dim=dim) - - sample["body_coil_image"] = T.root_sum_of_squares(acs_image, dim=coil_dim) - return sample - - -class SensitivityMapType(DirectEnum): - ESPIRIT = "espirit" - RSS_ESTIMATE = "rss_estimate" - UNIT = "unit" - - -class EstimateSensitivityMapModule(DirectModule): - """Data Transformer for training MRI reconstruction models. - - Estimates sensitivity maps given masked k-space data using one of three methods: - - * Unit: unit sensitivity map in case of single coil acquisition. - * RSS-estimate: sensitivity maps estimated by using the root-sum-of-squares of the autocalibration-signal. - * ESPIRIT: sensitivity maps estimated with the ESPIRIT method [1]_. 
Note that this is currently not
-        implemented for 3D data, and attempting to use it in such cases will result in a NotImplementedError.
-
-    References
-    ----------
-
-    .. [1] Uecker M, Lai P, Murphy MJ, Virtue P, Elad M, Pauly JM, Vasanawala SS, Lustig M. ESPIRiT--an eigenvalue
-        approach to autocalibrating parallel MRI: where SENSE meets GRAPPA. Magn Reson Med. 2014 Mar;71(3):990-1001.
-        doi: 10.1002/mrm.24751. PMID: 23649942; PMCID: PMC4142121.
-    """
-
-    def __init__(
-        self,
-        kspace_key: KspaceKey = KspaceKey.ACS_KSPACE,
-        backward_operator: Callable = T.ifft2,
-        type_of_map: Optional[SensitivityMapType] = SensitivityMapType.RSS_ESTIMATE,
-        gaussian_sigma: Optional[float] = None,
-        espirit_threshold: Optional[float] = 0.05,
-        espirit_kernel_size: Optional[int] = 6,
-        espirit_crop: Optional[float] = 0.95,
-        espirit_max_iters: Optional[int] = 30,
-    ) -> None:
-        """Inits :class:`EstimateSensitivityMapModule`.
-
-        Parameters
-        ----------
-        kspace_key: KspaceKey
-            K-space key to compute the ACS image from. If `kspace_key` is not `KspaceKey.ACS_KSPACE`,
-            the ACS mask should be provided in the sample. Default: KspaceKey.ACS_KSPACE.
-        backward_operator: callable
-            The backward operator, e.g. some form of inverse FFT (centered or uncentered).
-        type_of_map: SensitivityMapType, optional
-            Type of map to estimate. Can be SensitivityMapType.RSS_ESTIMATE, SensitivityMapType.UNIT or
-            SensitivityMapType.ESPIRIT. Default: SensitivityMapType.RSS_ESTIMATE.
-        gaussian_sigma: float, optional
-            If non-zero, the ACS image will be computed from k-space weighted by a Gaussian mask with this
-            standard deviation along the width dimension. Default: None.
-        espirit_threshold: float, optional
-            Threshold for the calibration matrix when `type_of_map` is set to `SensitivityMapType.ESPIRIT`.
-            Default: 0.05.
-        espirit_kernel_size: int, optional
-            Kernel size for the calibration matrix when `type_of_map` is set to `SensitivityMapType.ESPIRIT`.
-            Default: 6.
-        espirit_crop: float, optional
-            Output eigenvalue cropping threshold when `type_of_map` is set to `SensitivityMapType.ESPIRIT`.
-            Default: 0.95.
-        espirit_max_iters: int, optional
-            Power method iterations when `type_of_map` is set to `SensitivityMapType.ESPIRIT`. Default: 30.
-        """
-        super().__init__()
-        self.backward_operator = backward_operator
-        self.kspace_key = kspace_key
-        self.type_of_map = type_of_map
-
-        # RSS estimate attributes
-        self.gaussian_sigma = gaussian_sigma
-        # Espirit attributes
-        if type_of_map == SensitivityMapType.ESPIRIT:
-            self.espirit_calibrator = EspiritCalibration(
-                backward_operator,
-                espirit_threshold,
-                espirit_kernel_size,
-                espirit_crop,
-                espirit_max_iters,
-                kspace_key,
-            )
-        self.espirit_threshold = espirit_threshold
-        self.espirit_kernel_size = espirit_kernel_size
-        self.espirit_crop = espirit_crop
-        self.espirit_max_iters = espirit_max_iters
-
-    def estimate_acs_image(self, sample: dict[str, Any], width_dim: int = -2) -> torch.Tensor:
-        """Estimates the autocalibration (ACS) image by sampling the k-space using the ACS mask.
-
-        Parameters
-        ----------
-        sample: dict[str, Any]
-            Sample dictionary.
-        width_dim: int
-            Dimension corresponding to width. Default: -2.
-
-        Returns
-        -------
-        acs_image: torch.Tensor
-            Estimate of the ACS image.
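-
-        Example
-        -------
-        Sketch of the Gaussian weighting applied when `gaussian_sigma` is set; the width and sigma values
-        are illustrative assumptions:
-
-        >>> import torch
-        >>> width, sigma = 10, 0.3
-        >>> coords = torch.linspace(-1, 1, width)
-        >>> gaussian_mask = torch.exp(-((coords / sigma) ** 2))  # peaks at the k-space center, as in the code below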
- """ - kspace_data = sample[self.kspace_key] - - if self.kspace_key != KspaceKey.ACS_KSPACE: - if TransformKey.ACS_MASK not in sample: - raise ValueError("ACS mask is required for estimating ACS image from k-space but not found.") - kspace_data = kspace_data * sample[TransformKey.ACS_MASK] - - if self.gaussian_sigma == 0 or not self.gaussian_sigma: - kspace_acs = kspace_data + 0.0 # + 0.0 removes the sign of zeros. - else: - gaussian_mask = torch.linspace(-1, 1, kspace_data.size(width_dim), dtype=kspace_data.dtype) - gaussian_mask = torch.exp(-((gaussian_mask / self.gaussian_sigma) ** 2)) - gaussian_mask_shape = torch.ones(len(kspace_data.shape)).int() - gaussian_mask_shape[width_dim] = kspace_data.size(width_dim) - gaussian_mask = gaussian_mask.reshape(tuple(gaussian_mask_shape)) - kspace_acs = kspace_data * gaussian_mask + 0.0 - - # Get complex-valued data solution - # Shape (batch, [slice/time], coil, height, width, complex=2) - dim = self.spatial_dims.TWO_D if kspace_data.ndim == 5 else self.spatial_dims.THREE_D - acs_image = self.backward_operator(kspace_acs, dim=dim) - - return acs_image - - def forward(self, sample: dict[str, Any]) -> dict[str, Any]: - """Calculates sensitivity maps for the input sample. - - Parameters - ---------- - sample: dict[str, Any] - Must contain key matching kspace_key with value a (complex) torch.Tensor - of shape (coil, height, width, complex=2). - - Returns - ------- - sample: dict[str, Any] - Sample with key "sensitivity_map" with value the estimated sensitivity map. - """ - kspace = sample[self.kspace_key] # shape (batch, coil, [slice/time], height, width, complex=2) - - if kspace.shape[self.coil_dim] == 1: - warnings.warn( - "Estimation of sensitivity map of Single-coil data. This warning will be displayed only once." - ) - if "sensitivity_map" in sample: - warnings.warn( - "`sensitivity_map` is given, but will be overwritten. This warning will be displayed only once." - ) - - if self.type_of_map == SensitivityMapType.UNIT: - sensitivity_map = torch.zeros(kspace.shape).float() - # Assumes complex channel is last - assert_complex(kspace, complex_last=True) - sensitivity_map[..., 0] = 1.0 - # Shape (batch, coil, [slice/time], height, width, complex=2) - sensitivity_map = sensitivity_map.to(kspace.device) - - elif self.type_of_map == SensitivityMapType.RSS_ESTIMATE: - # Shape (batch, coil, [slice/time], height, width, complex=2) - acs_image = self.estimate_acs_image(sample) - # Shape (batch, [slice/time], height, width) - acs_image_rss = T.root_sum_of_squares(acs_image, dim=self.coil_dim) - # Shape (batch, 1, [slice/time], height, width, 1) - acs_image_rss = acs_image_rss.unsqueeze(self.coil_dim).unsqueeze(self.complex_dim) - # Shape (batch, coil, [slice/time], height, width, complex=2) - sensitivity_map = T.safe_divide(acs_image, acs_image_rss) - else: - if sample[self.kspace_key].ndim > 5: - raise NotImplementedError( - "EstimateSensitivityMapModule is not yet implemented for " - "Espirit sensitivity map estimation for 3D data." 
- ) - sensitivity_map = self.espirit_calibrator(sample) - - sensitivity_map_norm = torch.sqrt( - (sensitivity_map**2).sum(self.complex_dim).sum(self.coil_dim) - ) # shape (batch, [slice/time], height, width) - sensitivity_map_norm = sensitivity_map_norm.unsqueeze(self.coil_dim).unsqueeze(self.complex_dim) - - sample[TransformKey.SENSITIVITY_MAP] = T.safe_divide(sensitivity_map, sensitivity_map_norm) - return sample - - -class AddBooleanKeysModule(DirectModule): - """Adds keys with boolean values to sample.""" - - def __init__(self, keys: list[str], values: list[bool]) -> None: - """Inits :class:`AddBooleanKeysModule`. - - Parameters - ---------- - keys : list[str] - A list of keys to be added. - values : list[bool] - A list of values corresponding to the keys. - """ - super().__init__() - self.keys = keys - self.values = values - - def forward(self, sample: dict[str, Any]) -> dict[str, Any]: - """Adds boolean keys to the input sample dictionary. - - Parameters - ---------- - sample : dict[str, Any] - The input sample dictionary. - - Returns - ------- - dict[str, Any] - The modified sample with added boolean keys. - """ - for key, value in zip(self.keys, self.values): - sample[key] = value - - return sample - - -class CopyKeysModule(DirectModule): - """Copy keys to a new name from the sample if present.""" - - def __init__(self, keys: list[str], new_keys: list[str]) -> None: - """Inits :class:`CopyKeysModule`. - - Parameters - ---------- - keys: List[str] - Key(s) to copy. - new_keys: List[str] - Key(s) to create. - """ - super().__init__() - self.keys = keys - self.new_keys = new_keys - - def forward(self, sample: dict[str, Any]) -> dict[str, Any]: - """Forward pass of :class:`CopyKeysModule`. - - Parameters - ---------- - sample: Dict[str, Any] - Dictionary to look for keys and copy them with a new name. - - Returns - ------- - Dict[str, Any] - Dictionary with copied specified keys. - """ - for key, new_key in zip(self.keys, self.new_keys): - if key in sample: - if isinstance(sample[key], np.ndarray): - sample[new_key] = sample[key].copy() # Copy NumPy array - elif isinstance(sample[key], torch.Tensor): - sample[new_key] = sample[key].detach().clone() # Copy Torch tensor - else: - sample[new_key] = copy.deepcopy(sample[key]) - return sample - - -class CompressCoilModule(DirectModule): - """Compresses k-space coils using SVD.""" - - def __init__(self, kspace_key: KspaceKey, num_coils: int) -> None: - """Inits :class:`CompressCoilModule`. - - Parameters - ---------- - kspace_key : KspaceKey - K-space key. - num_coils : int - Number of coils to compress. - """ - super().__init__() - self.kspace_key = kspace_key - self.num_coils = num_coils - - def forward(self, sample: dict[str, Any]) -> dict[str, Any]: - """Performs coil compression to input k-space. - - Parameters - ---------- - sample : dict[str, Any] - Dict sample containing key `kspace_key`. Assumes coil dimension is first axis. - - Returns - ------- - sample : dict[str, Any] - Dict sample with `kspace_key` compressed to num_coils. 
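-
-        Example
-        -------
-        A minimal sketch; the coil counts and spatial sizes are illustrative assumptions:
-
-        >>> import torch
-        >>> from direct.types import KspaceKey
-        >>> module = CompressCoilModule(KspaceKey.KSPACE, num_coils=4)
-        >>> sample = {"kspace": torch.randn(1, 16, 320, 320, 2)}  # (batch, coil, height, width, complex)
-        >>> sample = module(sample)
-        >>> tuple(sample["kspace"].shape)  # 16 coils compressed to 4 via SVD
-        (1, 4, 320, 320, 2)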
- """ - k_space = sample[self.kspace_key].clone() # shape (batch, coil, [slice/time], height, width, complex=2) - - if k_space.shape[1] <= self.num_coils: - return sample - - ndim = k_space.ndim - - k_space = torch.view_as_complex(k_space) - - if ndim == 6: # If 3D sample reshape slice into batch dimension as sensitivities are computed 2D - num_slice_or_time = k_space.shape[2] - k_space = k_space.permute(0, 2, 1, 3, 4) - k_space = k_space.reshape(k_space.shape[0] * num_slice_or_time, *k_space.shape[2:]) - - shape = k_space.shape - - # Reshape the k-space data to combine spatial dimensions - k_space_reshaped = k_space.reshape(shape[0], shape[1], -1) - - # Compute the coil combination matrix using Singular Value Decomposition (SVD) - U, _, _ = torch.linalg.svd(k_space_reshaped, full_matrices=False) - - # Select the top ncoils_new singular vectors from the decomposition - U_new = U[:, :, : self.num_coils] - - # Perform coil compression - compressed_k_space = torch.matmul(U_new.transpose(1, 2), k_space_reshaped) - - # Reshape the compressed k-space back to its original shape - compressed_k_space = compressed_k_space.reshape(shape[0], self.num_coils, *shape[2:]) - - if ndim == 6: - compressed_k_space = compressed_k_space.reshape( - shape[0] // num_slice_or_time, num_slice_or_time, self.num_coils, *shape[2:] - ).permute(0, 2, 1, 3, 4) - - compressed_k_space = torch.view_as_real(compressed_k_space) - sample[self.kspace_key] = compressed_k_space # shape (batch, new coil, [slice/time], height, width, complex=2) - - return sample - - -class DeleteKeysModule(DirectModule): - """Remove keys from the sample if present.""" - - def __init__(self, keys: list[str]) -> None: - """Inits :class:`DeleteKeys`. - - Parameters - ---------- - keys: list[str] - Key(s) to delete. - """ - super().__init__() - self.keys = keys - - def forward(self, sample: dict[str, Any]) -> dict[str, Any]: - """Forward pass of :class:`DeleteKeys`. - - Parameters - ---------- - sample: dict[str, Any] - Dictionary to look for keys and remove them. - - Returns - ------- - dict[str, Any] - Dictionary with deleted specified keys. - """ - for key in self.keys: - if key in sample: - del sample[key] - - return sample - - -class IndexSelectionMode(DirectEnum): - RANDOM = "random" - CUSTOM = "custom" - RANGE = "range" - - -class IndexSelectionModule(DirectModule): - """Randomly selects indices from the sample. - - Parameters - ---------- - key: TransformKey - Key to select indices from. - mode: IndexSelectionMode - Mode of index selection. Can be IndexSelectionMode.RANDOM, IndexSelectionMode.CUSTOM or - IndexSelectionMode.RANGE. Default: IndexSelectionMode.CUSTOM. - num_indices: int - Number of indices to select. - out_key: TransformKey, optional - Key to store the selected indices. If None, the indices are stored in the same key. - Default: None. - index_dim: int - Dimension along which to select indices. Default: 1. - use_seed: bool - If true, a pseudo-random number based on the filename is computed so that every slice of the volume get - the same mask every time. Default: True - """ - - def __init__( - self, - key: TransformKey, - mode: IndexSelectionMode = IndexSelectionMode.CUSTOM, - indices: Optional[list[int]] = None, - num_indices: Optional[int] = None, - out_key: Optional[TransformKey] = None, - index_dim: int = 0, - use_seed: bool = True, - ) -> None: - """Inits :class:`IndexSelection`. - - Parameters - ---------- - key: TransformKey - Key to select indices from. - mode: IndexSelectionMode - Mode of index selection. 
Can be IndexSelectionMode.RANDOM, IndexSelectionMode.CUSTOM or - IndexSelectionMode.RANGE. Default: IndexSelectionMode.CUSTOM. - indices: list[int], optional - List of indices to select if mode is IndexSelectionMode.CUSTOM or range if mode is - IndexSelectionMode.RANGE. Default: None. - num_indices: int - Number of indices to select if mode is IndexSelectionMode.RANDOM. Default: None. - out_key: TransformKey, optional - Key to store the selected indices. If None, the indices are stored in the same key. - Default: None. - index_dim: int - Dimension along which to select indices. Default: 1. - use_seed: bool - If true, a pseudo-random number based on the filename is computed so that every slice of the volume get - the same mask every time. Default: True - """ - super().__init__() - self.key = key - self.out_key = out_key if out_key is not None else key - self.mode = mode - self.indices = indices - self.num_indices = num_indices - self.index_dim = index_dim - self.use_seed = use_seed - self.rng = np.random.RandomState() - - def forward(self, sample: dict[str, Any]) -> dict[str, Any]: - """Forward pass of :class:`IndexSelection`. - - Parameters - ---------- - sample: dict[str, Any] - Dictionary to look for key and select indices from. - - Returns - ------- - dict[str, Any] - Dictionary with randomly selected indices. - """ - if self.key not in sample: - return sample - - if self.mode == IndexSelectionMode.RANDOM: - seed = None if not self.use_seed else tuple(map(ord, str(sample["filename"]))) - with temp_seed(self.rng, seed): - num_to_keep = max(min(self.num_indices, sample[self.key].shape[self.index_dim]), 1) - start = self.rng.randint(0, sample[self.key].shape[self.index_dim] - num_to_keep) - keep_indices = torch.arange(start, start + num_to_keep, device=sample[self.key].device) - else: - if self.mode == IndexSelectionMode.CUSTOM: - keep_indices = torch.tensor( - [idx for idx in self.indices if np.abs(idx) < sample[self.key].shape[self.index_dim]], - device=sample[self.key].device, - ) - else: - keep_indices = torch.arange(self.indices[0], self.indices[1], device=sample[self.key].device) - num_to_keep = len(keep_indices) - - sample[self.out_key] = sample[self.key].index_select(self.index_dim, keep_indices) - - if num_to_keep == 1: - sample[self.out_key] = sample[self.out_key].squeeze(self.index_dim) - - return sample - - -class DropIndexModule(DirectModule): - """Drop indices from the sample. - - Parameters - ---------- - keys: list[TransformKey] - Key(s) to drop indices from. - index: int - Index to drop. - index_dim: int, list[int] - Dimension(s) along which to drop indices. If a list, must have the same length as `keys`. Default: 1. - store_deleted_keys: list[TransformKey], optional - Key(s) to store the deleted indices. If None, the deleted indices are not stored. If the length does not - match `keys`, the remaining keys are set to None. Default: None. - """ - - def __init__( - self, - keys: list[TransformKey], - index: int, - index_dim: int | list[int] = 1, - store_deleted_keys: Optional[list[TransformKey]] = None, - ) -> None: - """Inits :class:`DropIndexModule`. - - Parameters - ---------- - keys: list[TransformKey] - Key(s) to drop indices from. - index: int - Index to drop. - index_dim: int, list[int] - Dimension(s) along which to drop indices. If a list, must have the same length as `keys`. Default: 1. - store_deleted_keys: list[TransformKey], optional - Key(s) to store the deleted indices. If None, the deleted indices are not stored. 
If the length does not - match `keys`, the remaining keys are set to None. Default: None. - """ - super().__init__() - self.keys = keys - self.index = index - self.index_dim = [index_dim] * len(keys) if isinstance(index_dim, int) else index_dim - self.store_deleted_keys = store_deleted_keys - if self.store_deleted_keys is not None and len(keys) > len(self.store_deleted_keys): - self.store_deleted_keys = store_deleted_keys + [None] * (len(keys) - len(self.store_deleted_keys)) - - def forward(self, sample: dict[str, Any]) -> dict[str, Any]: - """Forward pass of :class:`DropIndexModule`. - - Parameters - ---------- - sample: dict[str, Any] - Dictionary to look for key and drop indices from. - - Returns - ------- - dict[str, Any] - Dictionary with dropped index. - """ - - for i, key in enumerate(self.keys): - if key not in sample: - continue - # This might be helpful, for instance, in case a single mask is used for all time frames - if sample[key].shape[self.index_dim[i]] == 1: - continue - if self.store_deleted_keys is not None: - deleted_key = self.store_deleted_keys[i] - if deleted_key: - sample[deleted_key] = sample[key].index_select( - self.index_dim[i], - torch.tensor( - [idx for idx in range(sample[key].shape[self.index_dim[i]]) if idx == self.index], - device=sample[key].device, - ), - ) - sample[key] = sample[key].index_select( - self.index_dim[i], - torch.tensor( - [idx for idx in range(sample[key].shape[self.index_dim[i]]) if idx != self.index], - device=sample[key].device, - ), - ) - - return sample - - -class SqueezeKeyModule(DirectModule): - """Squeeze the specified key(s) in the sample. - - Parameters - ---------- - keys: TransformKey - Key(s) to squeeze. - dim: int - Dimension to squeeze. - """ - - def __init__(self, keys: TransformKey, dim: int) -> None: - """Inits :class:`SqueezeKeyModule`. - - Parameters - ---------- - keys: TransformKey - Key(s) to squeeze. - dim: int - Dimension to squeeze. - """ - super().__init__() - self.keys = keys - self.dim = dim - - def forward(self, sample: dict[str, Any]) -> dict[str, Any]: - """Forward pass of :class:`SqueezeKeyModule`. - - Parameters - ---------- - sample: dict[str, Any] - Dictionary to look for keys to squeeze. - - Returns - ------- - dict[str, Any] - Dictionary with squeezed specified keys. - """ - for key in self.keys: - if key in sample: - sample[key] = sample[key].squeeze(self.dim) - return sample - - -class RenameKeysModule(DirectModule): - """Rename keys from the sample if present.""" - - def __init__(self, old_keys: list[str], new_keys: list[str]) -> None: - """Inits :class:`RenameKeys`. - - Parameters - ---------- - old_keys: list[str] - Key(s) to rename. - new_keys: list[str] - Key(s) to replace old keys. - """ - super().__init__() - self.old_keys = old_keys - self.new_keys = new_keys - - def forward(self, sample: dict[str, Any]) -> dict[str, Any]: - """Forward pass of :class:`RenameKeys`. - - Parameters - ---------- - sample: dict[str, Any] - Dictionary to look for keys and rename them. - - Returns - ------- - dict[str, Any] - Dictionary with renamed specified keys. - """ - for old_key, new_key in zip(self.old_keys, self.new_keys): - if old_key in sample: - sample[new_key] = sample.pop(old_key) - - return sample - - -class PadCoilDimensionModule(DirectModule): - """Pad the coils by zeros to a given number of coils. - - Useful if you want to collate volumes with different coil dimension. 
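-
-    Example
-    -------
-    A minimal sketch; the coil counts and shapes are illustrative assumptions:
-
-    >>> import torch
-    >>> pad = PadCoilDimensionModule(pad_coils=16, key="masked_kspace", coil_dim=1)
-    >>> sample = {"masked_kspace": torch.randn(1, 12, 320, 320, 2)}
-    >>> sample = pad(sample)  # zero-valued coils are prepended until 16 coils in total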
- """ - - def __init__( - self, - pad_coils: Optional[int] = None, - key: str = "masked_kspace", - coil_dim: int = 1, - ) -> None: - """Inits :class:`PadCoilDimensionModule`. - - Parameters - ---------- - pad_coils: int, optional - Number of coils to pad to. Default: None. - key: str - Key to pad in sample. Default: "masked_kspace". - coil_dim: int - Coil dimension along which the pad will be done. Default: 0. - """ - super().__init__() - self.num_coils = pad_coils - self.key = key - self.coil_dim = coil_dim - - def forward(self, sample: dict[str, Any]) -> dict[str, Any]: - """Forward pass of :class:`PadCoilDimensionModule`. - - Parameters - ---------- - sample: dict[str, Any] - Dictionary with key `self.key`. - - Returns - ------- - sample: dict[str, Any] - Dictionary with padded coils of sample[self.key] if self.num_coils is not None. - """ - if not self.num_coils: - return sample - - if self.key not in sample: - return sample - - data = sample[self.key] - - curr_num_coils = data.shape[self.coil_dim] - if curr_num_coils > self.num_coils: - raise ValueError( - f"Tried to pad to {self.num_coils} coils, but already have {curr_num_coils} for " - f"{sample['filename']}." - ) - if curr_num_coils == self.num_coils: - return sample - - shape = data.shape - num_coils = shape[self.coil_dim] - padding_data_shape = list(shape).copy() - padding_data_shape[self.coil_dim] = max(self.num_coils - num_coils, 0) - zeros = torch.zeros(padding_data_shape, dtype=data.dtype, device=data.device) - sample[self.key] = torch.cat([zeros, data], dim=self.coil_dim) - - return sample - - -class ComputeScalingFactorModule(DirectModule): - """Calculates scaling factor. - - Scaling factor is for the input data based on either to the percentile or to the maximum of `normalize_key`. - """ - - def __init__( - self, - normalize_key: Union[None, TransformKey] = TransformKey.MASKED_KSPACE, - percentile: Union[None, float] = 0.99, - scaling_factor_key: TransformKey = TransformKey.SCALING_FACTOR, - ) -> None: - """Inits :class:`ComputeScalingFactorModule`. - - Parameters - ---------- - normalize_key : TransformKey or None - Key name to compute the data for. If the maximum has to be computed on the ACS, ensure the reconstruction - on the ACS is available (typically `body_coil_image`). Default: "masked_kspace". - percentile : float or None - Rescale data with the given percentile. If None, the division is done by the maximum. Default: 0.99. - scaling_factor_key : TransformKey - Name of how the scaling factor will be stored. Default: "scaling_factor". - """ - super().__init__() - self.normalize_key = normalize_key - self.percentile = percentile - self.scaling_factor_key = scaling_factor_key - - def forward(self, sample: dict[str, Any]) -> dict[str, Any]: - """Forward pass of :class:`ComputeScalingFactorModule`. - - Parameters - ---------- - sample: dict[str, Any] - Sample with key `normalize_key` to compute scaling_factor. - - Returns - ------- - sample: dict[str, Any] - Sample with key `scaling_factor_key`. - """ - if self.normalize_key == "scaling_factor": # This is a real-valued given number - scaling_factor = sample["scaling_factor"] - elif not self.normalize_key: - kspace = sample["masked_kspace"] - scaling_factor = torch.tensor([1.0] * kspace.size(0), device=kspace.device, dtype=kspace.dtype) - else: - data = sample[self.normalize_key] - scaling_factor: Union[list, torch.Tensor] = [] - # Compute the maximum and scale the input - if self.percentile: - for _ in range(data.size(0)): - # Used in case the k-space is padded (e.g. 
for batches) - non_padded_coil_data = data[_][data[_].sum(dim=tuple(range(1, data[_].ndim))).bool()] - tview = -1.0 * T.modulus(non_padded_coil_data).view(-1) - s, _ = torch.kthvalue(tview, int((1 - self.percentile) * tview.size()[0]) + 1) - scaling_factor += [-1.0 * s] - scaling_factor = torch.tensor(scaling_factor, dtype=data.dtype, device=data.device) - else: - scaling_factor = T.modulus(data).amax(dim=list(range(data.ndim))[1:-1]) - sample[self.scaling_factor_key] = scaling_factor - return sample - - -class NormalizeModule(DirectModule): - """Normalize the input data.""" - - def __init__( - self, - scaling_factor_key: TransformKey = TransformKey.SCALING_FACTOR, - keys_to_normalize: Optional[list[TransformKey]] = None, - ) -> None: - """Inits :class:`NormalizeModule`. - - Parameters - ---------- - scaling_factor_key : TransformKey - Name of scaling factor key expected in sample. Default: 'scaling_factor'. - """ - super().__init__() - self.scaling_factor_key = scaling_factor_key - - self.keys_to_normalize = ( - [ - "masked_kspace", - "target", - "kspace", - "body_coil_image", # sensitivity_map does not require normalization. - "initial_image", - "initial_kspace", - ] - if keys_to_normalize is None - else keys_to_normalize - ) - - def forward(self, sample: dict[str, Any]) -> dict[str, Any]: - """Forward pass of :class:`NormalizeModule`. - - Parameters - ---------- - sample: dict[str, Any] - Sample to normalize. - - Returns - ------- - sample: dict[str, Any] - Sample with normalized values if their respective key is in `keys_to_normalize` and key - `scaling_factor_key` exists in sample. - """ - scaling_factor = sample.get(self.scaling_factor_key, None) - # Normalize data - if scaling_factor is not None: - for key in sample.keys(): - if key not in self.keys_to_normalize: - continue - sample[key] = T.safe_divide( - sample[key], - scaling_factor.reshape(-1, *[1 for _ in range(sample[key].ndim - 1)]), - ) - - sample["scaling_diff"] = 0.0 - return sample - - -class WhitenDataModule(DirectModule): - """Whitens complex data Module.""" - - def __init__(self, epsilon: float = 1e-10, key: str = "complex_image") -> None: - """Inits :class:`WhitenDataModule`. - - Parameters - ---------- - epsilon: float - Default: 1e-10. - key: str - Key to whiten. Default: "complex_image". - """ - super().__init__() - self.epsilon = epsilon - self.key = key - - def complex_whiten(self, complex_image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Whiten complex image. - - Parameters - ---------- - complex_image: torch.Tensor - Complex image tensor to whiten. - - Returns - ------- - mean, std, whitened_image: tuple[torch.Tensor, torch.Tensor, torch.Tensor] - """ - # From: https://github.com/facebookresearch/fastMRI - # blob/da1528585061dfbe2e91ebbe99a5d4841a5c3f43/banding_removal/fastmri/data/transforms.py#L464 # noqa - real = complex_image[..., 0] - imag = complex_image[..., 1] - - # Center around mean. - mean = complex_image.mean() - centered_complex_image = complex_image - mean - - # Determine covariance between real and imaginary. - n_elements = real.nelement() - real_real = (real.mul(real).sum() - real.mean().mul(real.mean())) / n_elements - real_imag = (real.mul(imag).sum() - real.mean().mul(imag.mean())) / n_elements - imag_imag = (imag.mul(imag).sum() - imag.mean().mul(imag.mean())) / n_elements - eig_input = torch.Tensor([[real_real, real_imag], [real_imag, imag_imag]]) - - # Remove correlation by rotating around covariance eigenvectors. 
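- # `eig_input` is the 2 x 2 empirical covariance matrix of the real and imaginary channels; its
- # eigenvectors define the decorrelating rotation, and the square roots of its eigenvalues (plus a
- # small epsilon) are the standard deviations used below to scale each channel to roughly unit variance.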
- eig_values, eig_vecs = torch.linalg.eig(eig_input) - - # Scale by eigenvalues for unit variance. - std = (eig_values.real + self.epsilon).sqrt() - whitened_image = torch.matmul(centered_complex_image, eig_vecs.real) / std - - return mean, std, whitened_image - - def forward(self, sample: dict[str, Any]) -> dict[str, Any]: - """Forward pass of :class:`WhitenDataModule`. - - Parameters - ---------- - sample: dict[str, Any] - Sample with key `key`. - - Returns - ------- - sample: dict[str, Any] - Sample with value of `key` whitened. - """ - _, _, whitened_image = self.complex_whiten(sample[self.key]) - sample[self.key] = whitened_image - return sample - - -class AddTargetAcceleration(DirectTransform): - """This will replace the acceleration factor in the sample with the target acceleration factor.""" - - def __init__(self, target_acceleration: float): - super().__init__() - self.target_acceleration = target_acceleration - - def __call__(self, sample: dict[str, Any]): - sample["acceleration"][:] = self.target_acceleration - return sample - - -class ModuleWrapper: - class SubWrapper: - def __init__(self, transform: Callable, toggle_dims: bool) -> None: - self.toggle_dims = toggle_dims - self._transform = transform - - def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: - if self.toggle_dims: - for k, v in sample.items(): - if isinstance(v, (torch.Tensor, np.ndarray)): - sample[k] = v[None] - else: - sample[k] = [v] - - sample = self._transform.forward(sample) - - if self.toggle_dims: - for k, v in sample.items(): - if isinstance(v, (torch.Tensor, np.ndarray)): - sample[k] = v.squeeze(0) - else: - sample[k] = v[0] - - return sample - - def __repr__(self) -> str: - return self._transform.__repr__() - - def __init__(self, module: Callable, toggle_dims: bool) -> None: - self._module = module - self.toggle_dims = toggle_dims - - def __call__(self, *args, **kwargs) -> SubWrapper: - return self.SubWrapper(self._module(*args, **kwargs), toggle_dims=self.toggle_dims) - - -ApplyMask = ModuleWrapper(ApplyMaskModule, toggle_dims=False) -ComputeImage = ModuleWrapper(ComputeImageModule, toggle_dims=True) -EstimateSensitivityMap = ModuleWrapper(EstimateSensitivityMapModule, toggle_dims=True) -CopyKeys = ModuleWrapper(CopyKeysModule, toggle_dims=False) -DeleteKeys = ModuleWrapper(DeleteKeysModule, toggle_dims=False) -RenameKeys = ModuleWrapper(RenameKeysModule, toggle_dims=False) -IndexSelection = ModuleWrapper(IndexSelectionModule, toggle_dims=False) -DropIndex = ModuleWrapper(DropIndexModule, toggle_dims=False) -SqueezeKey = ModuleWrapper(SqueezeKeyModule, toggle_dims=False) -CompressCoil = ModuleWrapper(CompressCoilModule, toggle_dims=True) -PadCoilDimension = ModuleWrapper(PadCoilDimensionModule, toggle_dims=True) -ComputeScalingFactor = ModuleWrapper(ComputeScalingFactorModule, toggle_dims=True) -Normalize = ModuleWrapper(NormalizeModule, toggle_dims=False) -WhitenData = ModuleWrapper(WhitenDataModule, toggle_dims=False) -GaussianMaskSplitter = ModuleWrapper(GaussianMaskSplitterModule, toggle_dims=True) -UniformMaskSplitter = ModuleWrapper(UniformMaskSplitterModule, toggle_dims=True) -Displacement = ModuleWrapper(DisplacementModule, toggle_dims=True) -RandomElasticDeformation = ModuleWrapper(RandomElasticDeformationModule, toggle_dims=True) - - -class ToTensor(DirectTransform): - """Transforms all np.array-like values in sample to torch.tensors.""" - - def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: - """Calls :class:`ToTensor`. 
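-
- Complex-valued arrays are converted via :func:`direct.data.transforms.to_tensor` to real-valued tensors
- with a trailing dimension of size 2 holding the real and imaginary parts.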
- - Parameters - ---------- - sample: dict[str, Any] - Contains key 'kspace' with value a np.array of shape (coil, height, width) (2D) - or (coil, slice, height, width) (3D) - - Returns - ------- - sample: dict[str, Any] - Contains key 'kspace' with value a torch.Tensor of shape (coil, height, width) (2D) - or (coil, slice, height, width) (3D) - """ - - ndim = sample["kspace"].ndim - 1 - - if ndim not in [2, 3]: - raise ValueError(f"Can only cast 2D and 3D data (+coil) to tensor. Got {ndim}.") - - # Shape: 2D: (coil, height, width, complex=2), 3D: (coil, slice, height, width, complex=2) - sample["kspace"] = T.to_tensor(sample["kspace"]).float() - # Sensitivity maps are not necessarily available in the dataset. - if "initial_kspace" in sample: - # Shape: 2D: (coil, height, width, complex=2), 3D: (coil, slice, height, width, complex=2) - sample["initial_kspace"] = T.to_tensor(sample["initial_kspace"]).float() - if "initial_image" in sample: - # Shape: 2D: (height, width), 3D: (slice, height, width) - sample["initial_image"] = T.to_tensor(sample["initial_image"]).float() - - if "sensitivity_map" in sample: - # Shape: 2D: (coil, height, width, complex=2), 3D: (coil, slice, height, width, complex=2) - sample["sensitivity_map"] = T.to_tensor(sample["sensitivity_map"]).float() - if "target" in sample: - # Shape: 2D: (coil, height, width), 3D: (coil, slice, height, width) - sample["target"] = torch.from_numpy(sample["target"]).float() - if "sampling_mask" in sample: - sample["sampling_mask"] = torch.from_numpy(sample["sampling_mask"]).bool() - if "acs_mask" in sample: - sample["acs_mask"] = torch.from_numpy(sample["acs_mask"]).bool() - if "scaling_factor" in sample: - sample["scaling_factor"] = torch.tensor(sample["scaling_factor"]).float() - if "loglikelihood_scaling" in sample: - # Shape: (coil, ) - sample["loglikelihood_scaling"] = torch.from_numpy(np.asarray(sample["loglikelihood_scaling"])).float() - - return sample - - -class RegistrationSimulateReferenceType(DirectEnum): - FROM_KEY = "from_key" - ELASTIC = "elastic" - - -# pylint: disable=too-many-arguments -def build_supervised_mri_transforms( - forward_operator: Callable, - backward_operator: Callable, - mask_func: Optional[Callable], - target_acceleration: Optional[float] = None, - crop: Optional[Union[tuple[int, int], str]] = None, - crop_type: Optional[str] = "uniform", - rescale: Optional[Union[tuple[int, int], list[int]]] = None, - rescale_mode: Optional[RescaleMode] = RescaleMode.NEAREST, - rescale_2d_if_3d: Optional[bool] = False, - pad: Optional[Union[tuple[int, int], list[int]]] = None, - image_center_crop: bool = True, - random_rotation_degrees: Optional[Sequence[int]] = (-90, 90), - random_rotation_probability: float = 0.0, - random_flip_type: Optional[RandomFlipType] = RandomFlipType.RANDOM, - random_flip_probability: float = 0.0, - random_reverse_probability: float = 0.0, - padding_eps: float = 0.0001, - estimate_body_coil_image: bool = False, - estimate_sensitivity_maps: bool = True, - sensitivity_maps_type: SensitivityMapType = SensitivityMapType.RSS_ESTIMATE, - sensitivity_maps_gaussian: Optional[float] = None, - sensitivity_maps_espirit_threshold: Optional[float] = 0.05, - sensitivity_maps_espirit_kernel_size: Optional[int] = 6, - sensitivity_maps_espirit_crop: Optional[float] = 0.95, - sensitivity_maps_espirit_max_iters: Optional[int] = 30, - use_acs_as_mask: bool = False, - delete_acs: bool = True, - delete_kspace: bool = True, - image_recon_type: ReconstructionType = ReconstructionType.RSS, - compress_coils: 
Optional[int] = None,
- pad_coils: Optional[int] = None,
- scaling_key: TransformKey = TransformKey.MASKED_KSPACE,
- scale_percentile: Optional[float] = 0.99,
- registration: bool = False,
- registration_simulate_reference: Optional[RegistrationSimulateReferenceType] = None,
- registration_simulate_elastic_sigma: float = 3.0,
- registration_simulate_elastic_points: int = 3,
- registration_simulate_elastic_rotate: float = 0.0,
- registration_simulate_elastic_zoom: float = 0.0,
- registration_estimate_displacement: bool = True,
- registration_simulate_reference_from_key_index: int = 0,
- registration_moving_key: TransformKey = TransformKey.TARGET,
- demons_filter_type: DemonsFilterType = DemonsFilterType.SYMMETRIC_FORCES,
- demons_num_iterations: int = 100,
- demons_smooth_displacement_field: bool = True,
- demons_standard_deviations: float = 1.5,
- demons_intensity_difference_threshold: Optional[float] = None,
- demons_maximum_rms_error: Optional[float] = None,
- use_seed: bool = True,
-) -> DirectTransform:
- r"""Builds supervised MRI transforms.
-
- More specifically, the following transformations are applied:
-
- * Converts input to (complex-valued) tensor.
- * Applies k-space (center) crop if requested.
- * Applies k-space rescaling if requested.
- * Applies k-space padding if requested.
- * Applies random augmentations (rotation, flip, reverse) if requested.
- * Adds a sampling mask if `mask_func` is defined.
- * Compresses the coil dimension if requested.
- * Pads the coil dimension if requested.
- * Adds coil sensitivities and / or the body coil image.
- * Masks the fully sampled k-space, if there is a mask function or a mask in the sample.
- * Computes a scaling factor based on the masked k-space and normalizes data.
- * Computes a target (image).
- * Deletes the acs mask and the fully sampled k-space if requested.
-
- Parameters
- ----------
- forward_operator : Callable
- The forward operator, e.g. some form of FFT (centered or uncentered).
- backward_operator : Callable
- The backward operator, e.g. some form of inverse FFT (centered or uncentered).
- mask_func : Callable or None
- A function which creates a sampling mask of the appropriate shape.
- target_acceleration : float, optional
- Target acceleration factor. Default: None.
- crop : tuple[int, int] or str, optional
- If not None, this will transform the "kspace" to an image domain, crop it, and transform it back.
- If a tuple of integers is given then it will crop the backprojected kspace to that size. If
- "reconstruction_size" is given, then it will crop the backprojected kspace according to it, but
- a key "reconstruction_size" must be present in the sample. Default: None.
- crop_type : Optional[str]
- Type of cropping, either "gaussian" or "uniform". This will be ignored if `crop` is None. Default: "uniform".
- rescale : tuple or list, optional
- If not None, this will transform the "kspace" to the image domain, rescale it, and transform it back.
- Must correspond to (height, width). It is not recommended to use this in combination with `crop`.
- Default: None.
- rescale_mode : RescaleMode
- Mode to be used for rescaling. Can be RescaleMode.AREA, RescaleMode.BICUBIC, RescaleMode.BILINEAR,
- RescaleMode.NEAREST, RescaleMode.NEAREST_EXACT, or RescaleMode.TRILINEAR. Note that not all modes are
- supported for 2D or 3D data. Default: RescaleMode.NEAREST.
- rescale_2d_if_3d : bool, optional
- If True and k-space data is 3D, rescaling will be done only on the height
- and width dimensions, by combining the slice/time dimension with the batch dimension.
- This is ignored if `rescale` is None. Default: False.
- pad : tuple or list, optional
- If not None, this will zero-pad the "kspace" to the given size. Must correspond to (height, width)
- or (slice/time, height, width). Default: None.
- image_center_crop : bool
- If True, the backprojected kspace will be cropped around the center, otherwise randomly.
- This will be ignored if `crop` is None. Default: True.
- random_rotation_degrees : Sequence[int], optional
- Default: (-90, 90).
- random_rotation_probability : float, optional
- If greater than 0.0, random rotations of `random_rotation_degrees` degrees will be applied with probability
- `random_rotation_probability`. Default: 0.0.
- random_flip_type : RandomFlipType, optional
- Default: RandomFlipType.RANDOM.
- random_flip_probability : float, optional
- If greater than 0.0, random flips of `random_flip_type` type will be applied with probability
- `random_flip_probability`. Default: 0.0.
- random_reverse_probability : float
- If greater than 0.0, will perform random reversion along the time or slice dimension (2) with probability
- `random_reverse_probability`. Default: 0.0.
- padding_eps: float
- Padding epsilon. Default: 0.0001.
- estimate_body_coil_image : bool
- Estimate body coil image. Default: False.
- estimate_sensitivity_maps : bool
- Estimate sensitivity maps using the acs region. Default: True.
- sensitivity_maps_type : SensitivityMapType
- Can be SensitivityMapType.RSS_ESTIMATE, SensitivityMapType.UNIT or SensitivityMapType.ESPIRIT.
- Will be ignored if `estimate_sensitivity_maps` is False. Default: SensitivityMapType.RSS_ESTIMATE.
- sensitivity_maps_gaussian : float
- Optional sigma for gaussian weighting of sensitivity map.
- sensitivity_maps_espirit_threshold : float, optional
- Threshold for the calibration matrix when `type_of_map` is set to `SensitivityMapType.ESPIRIT`.
- Default: 0.05.
- sensitivity_maps_espirit_kernel_size : int, optional
- Kernel size for the calibration matrix when `type_of_map` is set to `SensitivityMapType.ESPIRIT`. Default: 6.
- sensitivity_maps_espirit_crop : float, optional
- Output eigenvalue cropping threshold when `type_of_map` is set to `SensitivityMapType.ESPIRIT`. Default: 0.95.
- sensitivity_maps_espirit_max_iters : int, optional
- Power method iterations when `type_of_map` is set to `SensitivityMapType.ESPIRIT`. Default: 30.
- use_acs_as_mask : bool
- If True, will use the acs region as the mask. Default: False.
- delete_acs : bool
- If True, will delete key `acs_mask`. Default: True.
- delete_kspace : bool
- If True, will delete key `kspace` (fully sampled k-space). Default: True.
- image_recon_type : ReconstructionType
- Type to reconstruct target image. Default: ReconstructionType.RSS.
- compress_coils : int, optional
- Number of coils to compress input k-space to. It is not recommended to use this in combination with
- `pad_coils`. Default: None.
- pad_coils : int
- Number of coils to pad data to. Default: None.
- scaling_key : TransformKey
- Key in sample to scale scalable items in sample. Default: TransformKey.MASKED_KSPACE.
- scale_percentile : float, optional
- Data will be rescaled with the given percentile. If None, the division is done by the maximum. Default: 0.99.
- registration : bool
- If True, will compute a displacement field between the target and the moving image. Default: False.
- registration_simulate_reference : RegistrationSimulateReferenceType
- If not None, will simulate a reference image for displacement field computation. Otherwise, this expects a key
- in the sample. Can be RegistrationSimulateReferenceType.FROM_KEY or RegistrationSimulateReferenceType.ELASTIC.
- Default: None.
- registration_simulate_elastic_sigma : float
- Standard deviation for the elastic simulation. Default: 3.0.
- registration_simulate_elastic_points : int
- Number of points for the elastic simulation. Default: 3.
- registration_simulate_elastic_rotate : float
- Rotation for the elastic simulation. Default: 0.0.
- registration_simulate_elastic_zoom : float
- Zoom for the elastic simulation. Default: 0.0.
- registration_estimate_displacement : bool
- If True, will estimate the displacement field between the target and the moving image using the
- demons algorithm. Default: True.
- registration_simulate_reference_from_key_index : int
- Index to drop from the key to simulate the reference image. Default: 0.
- registration_moving_key : TransformKey
- Key in sample to compute displacement field from. Default: TransformKey.TARGET.
- demons_filter_type : DemonsFilterType
- Type of filter to apply to the displacement field. Default: DemonsFilterType.SYMMETRIC_FORCES.
- demons_num_iterations : int
- Number of iterations for the demons algorithm. Default: 100.
- demons_smooth_displacement_field : bool
- If True, will smooth the displacement field. Default: True.
- demons_standard_deviations : float
- Standard deviation for the smoothing of the displacement field. Default: 1.5.
- demons_intensity_difference_threshold : float, optional
- Intensity difference threshold for the demons algorithm. Default: None.
- demons_maximum_rms_error : float, optional
- Maximum RMS error for the demons algorithm. Default: None.
- use_seed : bool
- If true, a pseudo-random number based on the filename is computed so that every slice of the volume gets
- the same mask every time. Default: True.
-
- Returns
- -------
- DirectTransform
- An MRI transformation object.
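-
- Examples
- --------
- A minimal sketch of intended usage, assuming the centered FFTs from `direct.data.transforms` as operators
- and no mask function (with `mask_func=None` the sample itself must already carry a `sampling_mask`, plus an
- `acs_mask` if sensitivity maps are estimated):
-
- >>> from direct.data import transforms as T
- >>> transform = build_supervised_mri_transforms(
- ... forward_operator=T.fft2, backward_operator=T.ifft2, mask_func=None
- ... )
- >>> sample = transform(sample) # `sample` is a dict with at least "kspace" and "filename" entries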
- """ - mri_transforms: list[Callable] = [ToTensor()] - if crop: - mri_transforms += [ - CropKspace( - crop=crop, - forward_operator=forward_operator, - backward_operator=backward_operator, - image_space_center_crop=image_center_crop, - random_crop_sampler_type=crop_type, - random_crop_sampler_use_seed=use_seed, - ) - ] - if rescale: - mri_transforms += [ - RescaleKspace( - shape=rescale, - forward_operator=forward_operator, - backward_operator=backward_operator, - rescale_mode=rescale_mode, - rescale_2d_if_3d=rescale_2d_if_3d, - kspace_key=KspaceKey.KSPACE, - ) - ] - if pad: - mri_transforms += [ - PadKspace( - pad_shape=pad, - forward_operator=forward_operator, - backward_operator=backward_operator, - kspace_key=KspaceKey.KSPACE, - ) - ] - if random_rotation_probability > 0.0: - mri_transforms += [ - RandomRotation( - degrees=random_rotation_degrees, - p=random_rotation_probability, - keys_to_rotate=(TransformKey.KSPACE, TransformKey.SENSITIVITY_MAP), - ) - ] - if random_flip_probability > 0.0: - mri_transforms += [ - RandomFlip( - flip=random_flip_type, - p=random_flip_probability, - keys_to_flip=(TransformKey.KSPACE, TransformKey.SENSITIVITY_MAP), - ) - ] - if random_reverse_probability > 0.0: - mri_transforms += [ - RandomReverse( - p=random_reverse_probability, - keys_to_reverse=(TransformKey.KSPACE, TransformKey.SENSITIVITY_MAP), - ) - ] - if padding_eps > 0.0: - mri_transforms += [ - ComputeZeroPadding(KspaceKey.KSPACE, "padding", padding_eps), - ApplyZeroPadding(KspaceKey.KSPACE, "padding"), - ] - if mask_func: - mri_transforms += [ - CreateSamplingMask( - mask_func, - shape=(None if (isinstance(crop, str)) else crop), - use_seed=use_seed, - return_acs=estimate_sensitivity_maps, - ), - ] - if use_acs_as_mask: - mri_transforms += [CopyKeys(keys=[TransformKey.ACS_MASK], new_keys=[TransformKey.SAMPLING_MASK])] - if target_acceleration: - mri_transforms += [AddTargetAcceleration(target_acceleration)] - if compress_coils: - mri_transforms += [CompressCoil(num_coils=compress_coils, kspace_key=KspaceKey.KSPACE)] - if pad_coils: - mri_transforms += [PadCoilDimension(pad_coils=pad_coils, key=KspaceKey.KSPACE)] - - if estimate_body_coil_image and mask_func is not None: - mri_transforms.append(EstimateBodyCoilImage(mask_func, backward_operator=backward_operator, use_seed=use_seed)) - mri_transforms += [ - ApplyMask( - sampling_mask_key=TransformKey.ACS_MASK, - input_kspace_key=KspaceKey.KSPACE, - target_kspace_key=KspaceKey.ACS_KSPACE, - ), - ] - if estimate_sensitivity_maps: - mri_transforms += [ - EstimateSensitivityMap( - kspace_key=KspaceKey.ACS_KSPACE, - backward_operator=backward_operator, - type_of_map=sensitivity_maps_type, - gaussian_sigma=sensitivity_maps_gaussian, - espirit_threshold=sensitivity_maps_espirit_threshold, - espirit_kernel_size=sensitivity_maps_espirit_kernel_size, - espirit_crop=sensitivity_maps_espirit_crop, - espirit_max_iters=sensitivity_maps_espirit_max_iters, - ) - ] - mri_transforms += [ - ApplyMask( - sampling_mask_key=TransformKey.SAMPLING_MASK, - input_kspace_key=KspaceKey.KSPACE, - target_kspace_key=KspaceKey.MASKED_KSPACE, - ), - ] - if registration: - if registration_simulate_reference is not None: - mri_transforms += [ - DropIndex( - keys=[ - TransformKey.KSPACE, - TransformKey.ACS_KSPACE, - TransformKey.MASKED_KSPACE, - TransformKey.ACS_MASK, - TransformKey.SAMPLING_MASK, - TransformKey.PADDING, - TransformKey.SENSITIVITY_MAP, - ], - index=registration_simulate_reference_from_key_index, - index_dim=1, - 
store_deleted_keys=[TransformKey.REFERENCE_KSPACE],
- )
- ]
- mri_transforms += [
- ComputeScalingFactor(
- normalize_key=scaling_key, percentile=scale_percentile, scaling_factor_key=TransformKey.SCALING_FACTOR
- ),
- Normalize(
- scaling_factor_key=TransformKey.SCALING_FACTOR,
- keys_to_normalize=[
- KspaceKey.ACS_KSPACE,
- KspaceKey.KSPACE,
- KspaceKey.MASKED_KSPACE,
- KspaceKey.REFERENCE_KSPACE,
- ], # The only k-space keys that can be present in the sample at this point
- ),
- ]
- mri_transforms += [
- ComputeImage(
- kspace_key=KspaceKey.KSPACE,
- target_key=TransformKey.TARGET,
- backward_operator=backward_operator,
- type_reconstruction=image_recon_type,
- )
- ]
- if registration:
- if registration_simulate_reference is not None:
- mri_transforms += [
- ComputeImage(
- kspace_key=KspaceKey.REFERENCE_KSPACE,
- target_key=TransformKey.REFERENCE_IMAGE,
- backward_operator=backward_operator,
- type_reconstruction=image_recon_type,
- ),
- SqueezeKey(keys=[TransformKey.REFERENCE_IMAGE], dim=0),
- ]
- if registration_simulate_reference == RegistrationSimulateReferenceType.ELASTIC:
- mri_transforms += [
- RandomElasticDeformation(
- image_key=TransformKey.REFERENCE_IMAGE,
- target_key=TransformKey.REFERENCE_IMAGE,
- use_seed=use_seed,
- sigma=registration_simulate_elastic_sigma,
- points=registration_simulate_elastic_points,
- rotate=registration_simulate_elastic_rotate,
- zoom=registration_simulate_elastic_zoom,
- )
- ]
- if registration_estimate_displacement:
- mri_transforms += [
- Displacement(
- transform_type=DisplacementTransformType.MULTISCALE_DEMONS,
- demons_filter_type=demons_filter_type,
- demons_num_iterations=demons_num_iterations,
- demons_smooth_displacement_field=demons_smooth_displacement_field,
- demons_standard_deviations=demons_standard_deviations,
- demons_intensity_difference_threshold=demons_intensity_difference_threshold,
- demons_maximum_rms_error=demons_maximum_rms_error,
- reference_image_key=TransformKey.REFERENCE_IMAGE,
- moving_image_key=registration_moving_key,
- )
- ]
- if delete_acs:
- mri_transforms += [DeleteKeys(keys=[TransformKey.ACS_MASK, KspaceKey.ACS_KSPACE])]
- if delete_kspace:
- mri_transforms += [DeleteKeys(keys=[KspaceKey.KSPACE])]
-
- return Compose(mri_transforms)
-
-
-class TransformsType(DirectEnum):
- SUPERVISED = "supervised"
- SSL_SSDU = "ssl_ssdu"
-
-
-# pylint: disable=too-many-arguments
-def build_mri_transforms(
- forward_operator: Callable,
- backward_operator: Callable,
- mask_func: Optional[Callable],
- target_acceleration: Optional[float] = None,
- crop: Optional[Union[tuple[int, int], str]] = None,
- crop_type: Optional[str] = "uniform",
- rescale: Optional[Union[tuple[int, int], list[int]]] = None,
- rescale_mode: Optional[RescaleMode] = RescaleMode.NEAREST,
- rescale_2d_if_3d: Optional[bool] = False,
- pad: Optional[Union[tuple[int, int], list[int]]] = None,
- image_center_crop: bool = True,
- random_rotation_degrees: Optional[Sequence[int]] = (-90, 90),
- random_rotation_probability: float = 0.0,
- random_flip_type: Optional[RandomFlipType] = RandomFlipType.RANDOM,
- random_flip_probability: float = 0.0,
- random_reverse_probability: float = 0.0,
- padding_eps: float = 0.0001,
- estimate_body_coil_image: bool = False,
- estimate_sensitivity_maps: bool = True,
- sensitivity_maps_type: SensitivityMapType = SensitivityMapType.RSS_ESTIMATE,
- sensitivity_maps_gaussian: Optional[float] = None,
- sensitivity_maps_espirit_threshold: Optional[float] = 0.05,
- sensitivity_maps_espirit_kernel_size: Optional[int] = 6,
- sensitivity_maps_espirit_crop: Optional[float] = 0.95,
- sensitivity_maps_espirit_max_iters: Optional[int] = 30,
- use_acs_as_mask: bool = False,
- delete_acs: bool = True,
- delete_kspace: bool = True,
- image_recon_type: ReconstructionType = ReconstructionType.RSS,
- compress_coils: Optional[int] = None,
- pad_coils: Optional[int] = None,
- scaling_key: TransformKey = TransformKey.MASKED_KSPACE,
- scale_percentile: Optional[float] = 0.99,
- registration: bool = False,
- registration_simulate_reference: Optional[RegistrationSimulateReferenceType] = None,
- registration_simulate_elastic_sigma: float = 3.0,
- registration_simulate_elastic_points: int = 3,
- registration_simulate_elastic_rotate: float = 0.0,
- registration_simulate_elastic_zoom: float = 0.0,
- registration_estimate_displacement: bool = True,
- registration_simulate_reference_from_key_index: int = 0,
- registration_moving_key: TransformKey = TransformKey.TARGET,
- demons_filter_type: DemonsFilterType = DemonsFilterType.SYMMETRIC_FORCES,
- demons_num_iterations: int = 100,
- demons_smooth_displacement_field: bool = True,
- demons_standard_deviations: float = 1.5,
- demons_intensity_difference_threshold: Optional[float] = None,
- demons_maximum_rms_error: Optional[float] = None,
- use_seed: bool = True,
- transforms_type: Optional[TransformsType] = TransformsType.SUPERVISED,
- mask_split_ratio: Union[float, list[float], tuple[float, ...]] = 0.4,
- mask_split_acs_region: Union[list[int], tuple[int, int]] = (0, 0),
- mask_split_keep_acs: Optional[bool] = False,
- mask_split_type: MaskSplitterType = MaskSplitterType.GAUSSIAN,
- mask_split_gaussian_std: float = 3.0,
- mask_split_half_direction: HalfSplitType = HalfSplitType.VERTICAL,
-) -> DirectTransform:
- r"""Build transforms for MRI.
-
- More specifically, the following transformations are applied:
-
- * Converts input to (complex-valued) tensor.
- * Applies k-space (center) crop if requested.
- * Applies k-space rescaling if requested.
- * Applies k-space padding if requested.
- * Applies random augmentations (rotation, flip, reverse) if requested.
- * Adds a sampling mask if `mask_func` is defined.
- * Compresses the coil dimension if requested.
- * Pads the coil dimension if requested.
- * Adds coil sensitivities and / or the body coil image.
- * Masks the fully sampled k-space, if there is a mask function or a mask in the sample.
- * Computes a scaling factor based on the masked k-space and normalizes data.
- * Computes a target (image).
- * Deletes the acs mask and the fully sampled k-space if requested.
- * Splits the mask if requested for self-supervised learning.
-
- Parameters
- ----------
- forward_operator : Callable
- The forward operator, e.g. some form of FFT (centered or uncentered).
- backward_operator : Callable
- The backward operator, e.g. some form of inverse FFT (centered or uncentered).
- mask_func : Callable or None
- A function which creates a sampling mask of the appropriate shape.
- target_acceleration : float, optional
- Target acceleration factor. Default: None.
- crop : tuple[int, int] or str, optional
- If not None, this will transform the "kspace" to an image domain, crop it, and transform it back.
- If a tuple of integers is given then it will crop the backprojected kspace to that size. If
- "reconstruction_size" is given, then it will crop the backprojected kspace according to it, but
- a key "reconstruction_size" must be present in the sample. Default: None.
- crop_type : Optional[str]
- Type of cropping, either "gaussian" or "uniform". This will be ignored if `crop` is None. Default: "uniform".
- rescale : tuple or list, optional
- If not None, this will transform the "kspace" to the image domain, rescale it, and transform it back.
- Must correspond to (height, width). It is not recommended to use this in combination with `crop`.
- Default: None.
- rescale_mode : RescaleMode
- Mode to be used for rescaling. Can be RescaleMode.AREA, RescaleMode.BICUBIC, RescaleMode.BILINEAR,
- RescaleMode.NEAREST, RescaleMode.NEAREST_EXACT, or RescaleMode.TRILINEAR. Note that not all modes are
- supported for 2D or 3D data. Default: RescaleMode.NEAREST.
- rescale_2d_if_3d : bool, optional
- If True and k-space data is 3D, rescaling will be done only on the height
- and width dimensions, by combining the slice/time dimension with the batch dimension.
- This is ignored if `rescale` is None. Default: False.
- pad : tuple or list, optional
- If not None, this will zero-pad the "kspace" to the given size. Must correspond to (height, width)
- or (slice/time, height, width). Default: None.
- image_center_crop : bool
- If True, the backprojected kspace will be cropped around the center, otherwise randomly.
- This will be ignored if `crop` is None. Default: True.
- random_rotation_degrees : Sequence[int], optional
- Default: (-90, 90).
- random_rotation_probability : float, optional
- If greater than 0.0, random rotations of `random_rotation_degrees` degrees will be applied with probability
- `random_rotation_probability`. Default: 0.0.
- random_flip_type : RandomFlipType, optional
- Default: RandomFlipType.RANDOM.
- random_flip_probability : float, optional
- If greater than 0.0, random flips of `random_flip_type` type will be applied with probability
- `random_flip_probability`. Default: 0.0.
- random_reverse_probability : float
- If greater than 0.0, will perform random reversion along the time or slice dimension (2) with probability
- `random_reverse_probability`. Default: 0.0.
- padding_eps: float
- Padding epsilon. Default: 0.0001.
- estimate_body_coil_image : bool
- Estimate body coil image. Default: False.
- estimate_sensitivity_maps : bool
- Estimate sensitivity maps using the acs region. Default: True.
- sensitivity_maps_type : SensitivityMapType
- Can be SensitivityMapType.RSS_ESTIMATE, SensitivityMapType.UNIT or SensitivityMapType.ESPIRIT.
- Will be ignored if `estimate_sensitivity_maps` is False. Default: SensitivityMapType.RSS_ESTIMATE.
- sensitivity_maps_gaussian : float
- Optional sigma for gaussian weighting of sensitivity map.
- sensitivity_maps_espirit_threshold : float, optional
- Threshold for the calibration matrix when `type_of_map` is set to `SensitivityMapType.ESPIRIT`.
- Default: 0.05.
- sensitivity_maps_espirit_kernel_size : int, optional
- Kernel size for the calibration matrix when `type_of_map` is set to `SensitivityMapType.ESPIRIT`. Default: 6.
- sensitivity_maps_espirit_crop : float, optional
- Output eigenvalue cropping threshold when `type_of_map` is set to `SensitivityMapType.ESPIRIT`. Default: 0.95.
- sensitivity_maps_espirit_max_iters : int, optional
- Power method iterations when `type_of_map` is set to `SensitivityMapType.ESPIRIT`. Default: 30.
- use_acs_as_mask : bool
- If True, will use the acs region as the mask. Default: False.
- delete_acs : bool
- If True, will delete key `acs_mask`. Default: True.
- delete_kspace : bool
- If True, will delete key `kspace` (fully sampled k-space). Default: True.
- image_recon_type : ReconstructionType
- Type to reconstruct target image. Default: ReconstructionType.RSS.
- compress_coils : int, optional
- Number of coils to compress input k-space to. It is not recommended to use this in combination with
- `pad_coils`. Default: None.
- pad_coils : int
- Number of coils to pad data to. Default: None.
- scaling_key : TransformKey
- Key in sample to scale scalable items in sample. Default: TransformKey.MASKED_KSPACE.
- scale_percentile : float, optional
- Data will be rescaled with the given percentile. If None, the division is done by the maximum. Default: 0.99.
- registration : bool
- If True, will compute a displacement field between the target and the moving image. Default: False.
- registration_simulate_reference : RegistrationSimulateReferenceType
- If not None, will simulate a reference image for displacement field computation. Otherwise, this expects a key
- in the sample. Can be RegistrationSimulateReferenceType.FROM_KEY or RegistrationSimulateReferenceType.ELASTIC.
- Default: None.
- registration_simulate_elastic_sigma : float
- Standard deviation for the elastic simulation. Default: 3.0.
- registration_simulate_elastic_points : int
- Number of points for the elastic simulation. Default: 3.
- registration_simulate_elastic_rotate : float
- Rotation for the elastic simulation. Default: 0.0.
- registration_simulate_elastic_zoom : float
- Zoom for the elastic simulation. Default: 0.0.
- registration_estimate_displacement : bool
- If True, will estimate the displacement field between the target and the moving image using the
- demons algorithm. Default: True.
- registration_simulate_reference_from_key_index : int
- Index to drop from the key to simulate the reference image. Default: 0.
- registration_moving_key : TransformKey
- Key in sample to compute displacement field from. Default: TransformKey.TARGET.
- demons_filter_type : DemonsFilterType
- Type of filter to apply to the displacement field. Default: DemonsFilterType.SYMMETRIC_FORCES.
- demons_num_iterations : int
- Number of iterations for the demons algorithm. Default: 100.
- demons_smooth_displacement_field : bool
- If True, will smooth the displacement field. Default: True.
- demons_standard_deviations : float
- Standard deviation for the smoothing of the displacement field. Default: 1.5.
- demons_intensity_difference_threshold : float, optional
- Intensity difference threshold for the demons algorithm. Default: None.
- demons_maximum_rms_error : float, optional
- Maximum RMS error for the demons algorithm. Default: None.
- use_seed : bool
- If true, a pseudo-random number based on the filename is computed so that every slice of the volume gets
- the same mask every time. Default: True.
- transforms_type : TransformsType, optional
- Can be `TransformsType.SUPERVISED` for supervised learning transforms or `TransformsType.SSL_SSDU` for
- self-supervised learning transforms. Default: `TransformsType.SUPERVISED`.
- mask_split_ratio : Union[float, list[float], tuple[float, ...]]
- The ratio(s) of the sampling mask splitting. If `transforms_type` is TransformsType.SUPERVISED, this is
- ignored. Default: 0.4.
- mask_split_acs_region : Union[list[int], tuple[int, int]]
- A rectangle for the acs region that will be used in the input mask. This applies only if `transforms_type` is
- set to TransformsType.SSL_SSDU. Default: (0, 0).
- mask_split_keep_acs : Optional[bool]
- If True, acs region according to the "acs_mask" of the sample will be used in both mask splits.
- This applies only if `transforms_type` is set to TransformsType.SSL_SSDU. Default: False.
- mask_split_type : MaskSplitterType
- How the sampling mask will be split. Can be MaskSplitterType.UNIFORM, MaskSplitterType.GAUSSIAN, or
- MaskSplitterType.HALF. This applies only if `transforms_type` is set to TransformsType.SSL_SSDU.
- Default: MaskSplitterType.GAUSSIAN.
- mask_split_gaussian_std : float
- Standard deviation of gaussian mask splitting. This applies only if `transforms_type` is
- set to TransformsType.SSL_SSDU. Ignored if `mask_split_type` is not set to MaskSplitterType.GAUSSIAN.
- Default: 3.0.
- mask_split_half_direction : HalfSplitType
- Split type if `mask_split_type` is `MaskSplitterType.HALF`. Can be `HalfSplitType.VERTICAL`,
- `HalfSplitType.HORIZONTAL`, `HalfSplitType.DIAGONAL_LEFT` or `HalfSplitType.DIAGONAL_RIGHT`.
- This applies only if `transforms_type` is set to `TransformsType.SSL_SSDU`. Ignored if `mask_split_type` is not
- set to `MaskSplitterType.HALF`. Default: `HalfSplitType.VERTICAL`.
-
- Returns
- -------
- DirectTransform
- An MRI transformation object.
- """
- logger = logging.getLogger(build_mri_transforms.__name__)
- logger.info("Creating %s MRI transforms.", transforms_type)
-
- if crop and rescale:
- logger.warning(
- "Rescale and crop are both given. Rescale will be applied after cropping. This is not recommended."
- )
-
- if compress_coils and pad_coils:
- logger.warning(
- "Compress coils and pad coils are both given. Compress coils will be applied before padding. "
- "This is not recommended."
- )
-
- mri_transforms = build_supervised_mri_transforms(
- forward_operator=forward_operator,
- backward_operator=backward_operator,
- mask_func=mask_func,
- target_acceleration=target_acceleration,
- crop=crop,
- crop_type=crop_type,
- rescale=rescale,
- rescale_mode=rescale_mode,
- rescale_2d_if_3d=rescale_2d_if_3d,
- pad=pad,
- image_center_crop=image_center_crop,
- random_rotation_degrees=random_rotation_degrees,
- random_rotation_probability=random_rotation_probability,
- random_flip_type=random_flip_type,
- random_flip_probability=random_flip_probability,
- random_reverse_probability=random_reverse_probability,
- padding_eps=padding_eps,
- estimate_sensitivity_maps=estimate_sensitivity_maps,
- sensitivity_maps_type=sensitivity_maps_type,
- estimate_body_coil_image=estimate_body_coil_image,
- sensitivity_maps_gaussian=sensitivity_maps_gaussian,
- sensitivity_maps_espirit_threshold=sensitivity_maps_espirit_threshold,
- sensitivity_maps_espirit_kernel_size=sensitivity_maps_espirit_kernel_size,
- sensitivity_maps_espirit_crop=sensitivity_maps_espirit_crop,
- sensitivity_maps_espirit_max_iters=sensitivity_maps_espirit_max_iters,
- use_acs_as_mask=use_acs_as_mask,
- delete_acs=delete_acs if transforms_type == TransformsType.SUPERVISED else False,
- delete_kspace=delete_kspace if transforms_type == TransformsType.SUPERVISED else False,
- image_recon_type=image_recon_type,
- compress_coils=compress_coils,
- pad_coils=pad_coils,
- scaling_key=scaling_key,
- scale_percentile=scale_percentile,
- registration=registration,
- registration_simulate_reference=registration_simulate_reference,
- registration_simulate_elastic_sigma=registration_simulate_elastic_sigma,
- registration_simulate_elastic_points=registration_simulate_elastic_points,
- registration_simulate_elastic_rotate=registration_simulate_elastic_rotate,
- registration_simulate_elastic_zoom=registration_simulate_elastic_zoom,
- registration_estimate_displacement=registration_estimate_displacement,
- registration_simulate_reference_from_key_index=registration_simulate_reference_from_key_index,
- 
registration_moving_key=registration_moving_key, - demons_filter_type=demons_filter_type, - demons_num_iterations=demons_num_iterations, - demons_smooth_displacement_field=demons_smooth_displacement_field, - demons_standard_deviations=demons_standard_deviations, - demons_intensity_difference_threshold=demons_intensity_difference_threshold, - demons_maximum_rms_error=demons_maximum_rms_error, - use_seed=use_seed, - ).transforms - - mri_transforms += [AddBooleanKeysModule(["is_ssl"], [transforms_type != TransformsType.SUPERVISED])] - - if transforms_type == TransformsType.SUPERVISED: - return Compose(mri_transforms) - - mask_splitter_kwargs = { - "ratio": mask_split_ratio, - "acs_region": mask_split_acs_region, - "keep_acs": mask_split_keep_acs, - "use_seed": use_seed, - "kspace_key": KspaceKey.MASKED_KSPACE, - } - mri_transforms += [ - ( - GaussianMaskSplitter(**mask_splitter_kwargs, std_scale=mask_split_gaussian_std) - if mask_split_type == MaskSplitterType.GAUSSIAN - else ( - UniformMaskSplitter(**mask_splitter_kwargs) - if mask_split_type == MaskSplitterType.UNIFORM - else HalfMaskSplitterModule( - **{k: v for k, v in mask_splitter_kwargs.items() if k != "ratio"}, - direction=mask_split_half_direction, - ) - ) - ), - DeleteKeys([TransformKey.ACS_MASK]), - ] - - mri_transforms += [ - RenameKeys( - [ - SSLTransformMaskPrefixes.INPUT_ + TransformKey.MASKED_KSPACE, - SSLTransformMaskPrefixes.TARGET_ + TransformKey.MASKED_KSPACE, - ], - ["input_kspace", "kspace"], - ), - DeleteKeys(["masked_kspace", "sampling_mask"]), - ] # Rename keys for SSL engine - - mri_transforms += [ - ComputeImage( - kspace_key=KspaceKey.KSPACE, - target_key=TransformKey.TARGET, - backward_operator=backward_operator, - type_reconstruction=image_recon_type, - ) - ] - - return Compose(mri_transforms) +# Copyright (c) DIRECT Contributors + +"""The `direct.data.mri_transforms` module contains mri transformations utilized to transform or augment k-space data, +used for DIRECT's training pipeline. They can be also used individually by importing them into python scripts.""" + +from __future__ import annotations + +import contextlib +import copy +import functools +import logging +import random +import warnings +from typing import Any, Callable, Dict, Iterable, Optional, Sequence, Union + +import numpy as np +import torch + +from direct.algorithms.mri_algorithms import EspiritCalibration +from direct.data import transforms as T +from direct.exceptions import ItemNotFoundException +from direct.registration.elastic_deformation import RandomElasticDeformationModule +from direct.registration.registration import DemonsFilterType, DisplacementModule, DisplacementTransformType +from direct.ssl.ssl import ( + GaussianMaskSplitterModule, + HalfMaskSplitterModule, + HalfSplitType, + MaskSplitterType, + SSLTransformMaskPrefixes, + UniformMaskSplitterModule, +) +from direct.types import DirectEnum, IntegerListOrTupleString, KspaceKey, TransformKey +from direct.utils import DirectModule, DirectTransform +from direct.utils.asserts import assert_complex + +logger = logging.getLogger(__name__) + + +@contextlib.contextmanager +def temp_seed(rng, seed): + state = rng.get_state() + rng.seed(seed) + try: + yield + finally: + rng.set_state(state) + + +class Compose(DirectTransform): + """Compose several transformations together, for instance ClipAndScale and a flip. + + Code based on torchvision: https://github.com/pytorch/vision, but got forked from there as torchvision has some + additional dependencies. 
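+
+ Example (a hypothetical two-step pipeline; `ToTensor` and `ComputeScalingFactor` are transforms defined
+ further below in this module):
+
+ >>> transform = Compose([ToTensor(), ComputeScalingFactor()])
+ >>> sample = transform(sample)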
+ """ + + def __init__(self, transforms: Iterable[Callable]) -> None: + """Inits :class:`Compose`. + + Parameters + ---------- + transforms: Iterable[Callable] + List of transforms. + """ + super().__init__() + self.transforms = transforms + + def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: + """Calls :class:`Compose`. + + Parameters + ---------- + sample: dict[str, Any] + Dict sample. + + Returns + ------- + dict[str, Any] + Dict sample transformed by `transforms`. + """ + for transform in self.transforms: + sample = transform(sample) + + return sample + + def __repr__(self): + """Representation of :class:`Compose`.""" + repr_string = self.__class__.__name__ + "(" + for transform in self.transforms: + repr_string += "\n" + repr_string += f" {transform}," + repr_string = repr_string[:-1] + "\n)" + return repr_string + + +class RandomRotation(DirectTransform): + r"""Random :math:`k`-space rotation. + + Performs a random rotation with probability :math:`p`. Rotation degrees must be multiples of 90. + """ + + def __init__( + self, + degrees: Sequence[int] = (-90, 90), + p: float = 0.5, + keys_to_rotate: tuple[TransformKey, ...] = (TransformKey.KSPACE,), + ) -> None: + r"""Inits :class:`RandomRotation`. + + Parameters + ---------- + degrees: sequence of ints + Degrees of rotation. Must be a multiple of 90. If len(degrees) > 1, then a degree will be chosen at random. + Default: (-90, 90). + p: float + Probability of rotation. Default: 0.5 + keys_to_rotate : tuple of TransformKeys + Keys to rotate. Default: "kspace". + """ + super().__init__() + + assert all(degree % 90 == 0 for degree in degrees) + + self.degrees = degrees + self.p = p + self.keys_to_rotate = keys_to_rotate + + def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: + """Calls :class:`RandomRotation`. + + Parameters + ---------- + sample: dict[str, Any] + Dict sample. + + Returns + ------- + dict[str, Any] + Sample with rotated values of `keys_to_rotate`. + """ + if random.SystemRandom().random() <= self.p: + degree = random.SystemRandom().choice(self.degrees) + k = degree // 90 + for key in self.keys_to_rotate: + if key in sample: + value = T.view_as_complex(sample[key].clone()) + sample[key] = T.view_as_real(torch.rot90(value, k=k, dims=(-2, -1))) + + # If rotated by multiples of (n + 1) * 90 degrees, reconstruction size also needs to change + reconstruction_size = sample.get("reconstruction_size", None) + if reconstruction_size and (k % 2) == 1: + sample["reconstruction_size"] = ( + reconstruction_size[:-3] + reconstruction_size[-3:-1][::-1] + reconstruction_size[-1:] + ) + + return sample + + +class RandomFlipType(DirectEnum): + HORIZONTAL = "horizontal" + VERTICAL = "vertical" + RANDOM = "random" + BOTH = "both" + + +class RandomFlip(DirectTransform): + r"""Random k-space flip transform. + + Performs a random flip with probability :math:`p`. Flip can be horizontal, vertical, or a random choice of the two. + """ + + def __init__( + self, + flip: RandomFlipType = RandomFlipType.RANDOM, + p: float = 0.5, + keys_to_flip: tuple[TransformKey, ...] = (TransformKey.KSPACE,), + ) -> None: + r"""Inits :class:`RandomFlip`. + + Parameters + ---------- + flip : RandomFlipType + Horizontal, vertical, or random choice of the two. Default: RandomFlipType.RANDOM. + p : float + Probability of flip. Default: 0.5 + keys_to_flip : tuple of TransformKeys + Keys to flip. Default: "kspace". 
+ """
+ super().__init__()
+
+ self.flip = flip
+ self.p = p
+ self.keys_to_flip = keys_to_flip
+
+ def __call__(self, sample: dict[str, Any]) -> dict[str, Any]:
+ """Calls :class:`RandomFlip`.
+
+ Parameters
+ ----------
+ sample: dict[str, Any]
+ Dict sample.
+
+ Returns
+ -------
+ dict[str, Any]
+ Sample with flipped values of `keys_to_flip`.
+ """
+ if random.SystemRandom().random() <= self.p:
+ dims = (
+ (-2,)
+ if self.flip == "horizontal"
+ else (
+ (-1,)
+ if self.flip == "vertical"
+ else (-2, -1) if self.flip == "both" else (random.SystemRandom().choice([-2, -1]),)
+ )
+ )
+
+ for key in self.keys_to_flip:
+ if key in sample:
+ value = T.view_as_complex(sample[key].clone())
+ value = torch.flip(value, dims=dims)
+ sample[key] = T.view_as_real(value)
+
+ return sample
+
+
+class RandomReverse(DirectTransform):
+ r"""Random reverse of the order along a given dimension of a PyTorch tensor."""
+
+ def __init__(
+ self,
+ dim: int = 1,
+ p: float = 0.5,
+ keys_to_reverse: tuple[TransformKey, ...] = (TransformKey.KSPACE,),
+ ) -> None:
+ r"""Inits :class:`RandomReverse`.
+
+ Parameters
+ ----------
+ dim : int
+ Dimension along which to reverse. Typically, this is the time or slice dimension. Default: 1.
+ p : float
+ Probability of reversal. Default: 0.5.
+ keys_to_reverse : tuple of TransformKeys
+ Keys to reverse. Default: "kspace".
+ """
+ super().__init__()
+
+ self.dim = dim
+ self.p = p
+ self.keys_to_reverse = keys_to_reverse
+
+ def __call__(self, sample: dict[str, Any]) -> dict[str, Any]:
+ """Calls :class:`RandomReverse`.
+
+ Parameters
+ ----------
+ sample: dict[str, Any]
+ Dict sample.
+
+ Returns
+ -------
+ dict[str, Any]
+ Sample with reversed values of `keys_to_reverse`.
+ """
+ if random.SystemRandom().random() <= self.p:
+ dim = self.dim
+ for key in self.keys_to_reverse:
+ if key in sample:
+ tensor = sample[key].clone()
+
+ if dim < 0:
+ dim += tensor.dim()
+
+ tensor = T.view_as_complex(tensor)
+
+ index = [slice(None)] * tensor.dim()
+ index[dim] = torch.arange(tensor.size(dim) - 1, -1, -1, dtype=torch.long)
+
+ tensor = tensor[tuple(index)]
+
+ sample[key] = T.view_as_real(tensor)
+
+ return sample
+
+
+class CreateSamplingMask(DirectTransform):
+ """Data Transformer for training MRI reconstruction models.
+
+ Creates sampling mask.
+ """
+
+ def __init__(
+ self,
+ mask_func: Callable,
+ shape: Optional[tuple[int, ...]] = None,
+ use_seed: bool = True,
+ return_acs: bool = False,
+ ) -> None:
+ """Inits :class:`CreateSamplingMask`.
+
+ Parameters
+ ----------
+ mask_func: Callable
+ A function which creates a sampling mask of the appropriate shape.
+ shape: tuple, optional
+ Sampling mask shape. Default: None.
+ use_seed: bool
+ If true, a pseudo-random number based on the filename is computed so that every slice of the volume gets
+ the same mask every time. Default: True.
+ return_acs: bool
+ If True, it will generate an ACS mask. Default: False.
+ """
+ super().__init__()
+ self.mask_func = mask_func
+ self.shape = shape
+ self.use_seed = use_seed
+ self.return_acs = return_acs
+
+ def __call__(self, sample: dict[str, Any]) -> dict[str, Any]:
+ """Calls :class:`CreateSamplingMask`.
+
+ Parameters
+ ----------
+ sample: dict[str, Any]
+ Dict sample.
+
+ Returns
+ -------
+ dict[str, Any]
+ Sample with `sampling_mask` key.
+ """
+ if not self.shape:
+ shape = sample["kspace"].shape[1:]
+ elif any(_ is None for _ in self.shape): # Allow None as values.
+ kspace_shape = list(sample["kspace"].shape[1:-1]) + shape = tuple(_ if _ else kspace_shape[idx] for idx, _ in enumerate(self.shape)) + (2,) + else: + shape = self.shape + (2,) + + seed = None if not self.use_seed else tuple(map(ord, str(sample["filename"]))) + + sampling_mask = self.mask_func(shape=shape, seed=seed, return_acs=False) + + if sampling_mask.ndim == 5: + acceleration = [ + np.prod(sampling_mask[0, _].shape) / sampling_mask[0, _].sum() for _ in range(sampling_mask.shape[1]) + ] + sample["acceleration"] = torch.tensor(acceleration, dtype=torch.float32).unsqueeze(0) + else: + sample["acceleration"] = (np.prod(sampling_mask.shape) / sampling_mask.sum()).unsqueeze(0) + + if "padding" in sample: + sampling_mask = T.apply_padding(sampling_mask, sample["padding"]) + + # Shape 3D: (1, 1, height, width, 1), 2D: (1, height, width, 1) + sample["sampling_mask"] = sampling_mask + + if self.return_acs: + sample["acs_mask"] = self.mask_func(shape=shape, seed=seed, return_acs=True) + if sampling_mask.ndim == 5: + center_fraction = [ + sample["acs_mask"][0, _].sum() / np.prod(sample["acs_mask"][0, _].shape) + for _ in range(sample["acs_mask"].shape[1]) + ] + sample["center_fraction"] = torch.tensor(center_fraction, dtype=torch.float32).unsqueeze(0) + else: + sample["center_fraction"] = (sample["acs_mask"].sum() / np.prod(sample["acs_mask"].shape)).unsqueeze(0) + return sample + + +class ApplyMaskModule(DirectModule): + """Data Transformer for training MRI reconstruction models. + + Masks the input k-space (with key `input_kspace_key`) using a sampling mask with key `sampling_mask_key` onto + a new masked k-space with key `target_kspace_key`. + """ + + def __init__( + self, + sampling_mask_key: str = "sampling_mask", + input_kspace_key: KspaceKey = KspaceKey.KSPACE, + target_kspace_key: KspaceKey = KspaceKey.MASKED_KSPACE, + ) -> None: + """Inits :class:`ApplyMaskModule`. + + Parameters + ---------- + sampling_mask_key: str + Default: "sampling_mask". + input_kspace_key: KspaceKey + Default: KspaceKey.KSPACE. + target_kspace_key: KspaceKey + Default KspaceKey.MASKED_KSPACE. + """ + super().__init__() + self.logger = logging.getLogger(type(self).__name__) + + self.sampling_mask_key = sampling_mask_key + self.input_kspace_key = input_kspace_key + self.target_kspace_key = target_kspace_key + + def forward(self, sample: dict[str, Any]) -> dict[str, Any]: + """Forward pass of :class:`ApplyMaskModule`. + + Applies mask with key `sampling_mask_key` onto kspace `input_kspace_key`. Result is stored as a tensor with + key `target_kspace_key`. + + Parameters + ---------- + sample: dict[str, Any] + Dict sample containing keys `sampling_mask_key` and `input_kspace_key`. + + Returns + ------- + dict[str, Any] + Sample with (new) key `target_kspace_key`. + """ + if self.input_kspace_key not in sample: + raise ValueError(f"Key {self.input_kspace_key} corresponding to `input_kspace_key` not found in sample.") + input_kspace = sample[self.input_kspace_key] + + if self.sampling_mask_key not in sample: + raise ValueError(f"Key {self.sampling_mask_key} corresponding to `sampling_mask_key` not found in sample.") + sampling_mask = sample[self.sampling_mask_key] + + target_kspace, _ = T.apply_mask(input_kspace, sampling_mask) + sample[self.target_kspace_key] = target_kspace + return sample + + +class CropKspace(DirectTransform): + """Data Transformer for training MRI reconstruction models. 
+
+ Crops the k-space by:
+ * It first projects the k-space to the image-domain via the backward operator,
+ * It crops the back-projected k-space to specified shape or key,
+ * It transforms the cropped back-projected k-space to the k-space domain via the forward operator.
+ """
+
+ def __init__(
+ self,
+ crop: Union[str, tuple[int, ...], list[int]],
+ forward_operator: Callable = T.fft2,
+ backward_operator: Callable = T.ifft2,
+ image_space_center_crop: bool = False,
+ random_crop_sampler_type: Optional[str] = "uniform",
+ random_crop_sampler_use_seed: Optional[bool] = True,
+ random_crop_sampler_gaussian_sigma: Optional[list[float]] = None,
+ ) -> None:
+ """Inits :class:`CropKspace`.
+
+ Parameters
+ ----------
+ crop: tuple of ints or str
+ Shape to crop the input to or a string pointing to a crop key (e.g. `reconstruction_size`).
+ forward_operator: Callable
+ The forward operator, e.g. some form of FFT (centered or uncentered).
+ Default: :class:`direct.data.transforms.fft2`.
+ backward_operator: Callable
+ The backward operator, e.g. some form of inverse FFT (centered or uncentered).
+ Default: :class:`direct.data.transforms.ifft2`.
+ image_space_center_crop: bool
+ If set, the crop in the data will be taken in the center.
+ random_crop_sampler_type: Optional[str]
+ If "uniform" the random cropping will be done by uniformly sampling `crop`, as opposed to `gaussian` which
+ will sample from a gaussian distribution. If `image_space_center_crop` is True, then this is ignored.
+ Default: "uniform".
+ random_crop_sampler_use_seed: bool
+ If true, a pseudo-random number based on the filename is computed so that every slice of the volume
+ is cropped the same way. Default: True.
+ random_crop_sampler_gaussian_sigma: Optional[list[float]]
+ Standard deviation of the gaussian when `random_crop_sampler_type` is `gaussian`.
+ If `image_space_center_crop` is True, then this is ignored. Default: None.
+ """
+ super().__init__()
+ self.logger = logging.getLogger(type(self).__name__)
+
+ self.image_space_center_crop = image_space_center_crop
+
+ if not (isinstance(crop, (Iterable, str))):
+ raise ValueError(
+ f"Invalid input for `crop`. Received {crop}. Can be a list or tuple of integers or a string."
+ )
+ self.crop = crop
+
+ if image_space_center_crop:
+ self.crop_func = T.complex_center_crop
+ else:
+ self.crop_func = functools.partial(
+ T.complex_random_crop,
+ sampler=random_crop_sampler_type,
+ sigma=random_crop_sampler_gaussian_sigma,
+ )
+ self.random_crop_sampler_use_seed = random_crop_sampler_use_seed
+
+ self.forward_operator = forward_operator
+ self.backward_operator = backward_operator
+
+ def __call__(self, sample: dict[str, Any]) -> dict[str, Any]:
+ """Calls :class:`CropKspace`.
+
+ Parameters
+ ----------
+ sample: dict[str, Any]
+ Dict sample containing key `kspace`.
+
+ Returns
+ -------
+ dict[str, Any]
+ Cropped and masked sample.
+ """
+
+ kspace = sample["kspace"] # shape (coil, [slice/time], height, width, complex=2)
+
+ dim = self.spatial_dims.TWO_D if kspace.ndim == 4 else self.spatial_dims.THREE_D
+
+ backprojected_kspace = self.backward_operator(kspace, dim=dim) # shape (coil, height, width, complex=2)
+
+ if isinstance(self.crop, IntegerListOrTupleString):
+ crop_shape = IntegerListOrTupleString(self.crop)
+ elif isinstance(self.crop, str):
+ assert self.crop in sample, f"Not found {self.crop} key in sample."
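+ # Keys such as `reconstruction_size` also carry a trailing (complex/channel) entry, so only the
+ # leading spatial part is kept as the crop shape.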
+            crop_shape = sample[self.crop][:-1]
+        else:
+            if kspace.ndim == 5 and len(self.crop) == 2:
+                crop_shape = (kspace.shape[1],) + tuple(self.crop)
+            else:
+                crop_shape = tuple(self.crop)
+
+        cropper_args = {
+            "data_list": [backprojected_kspace],
+            "crop_shape": crop_shape,
+            "contiguous": False,
+        }
+        if not self.image_space_center_crop:
+            cropper_args["seed"] = (
+                None if not self.random_crop_sampler_use_seed else tuple(map(ord, str(sample["filename"])))
+            )
+        cropped_backprojected_kspace = self.crop_func(**cropper_args)
+
+        if "sampling_mask" in sample:
+            sample["sampling_mask"] = T.complex_center_crop(
+                sample["sampling_mask"], (1,) + tuple(crop_shape)[1:] if kspace.ndim == 5 else crop_shape
+            )
+            sample["acs_mask"] = T.complex_center_crop(
+                sample["acs_mask"], (1,) + tuple(crop_shape)[1:] if kspace.ndim == 5 else crop_shape
+            )
+
+        # Compute new k-space for the cropped_backprojected_kspace
+        # shape (coil, [slice/time], new_height, new_width, complex=2)
+        sample["kspace"] = self.forward_operator(cropped_backprojected_kspace, dim=dim)  # The cropped kspace
+
+        return sample
+
+
+class RescaleMode(DirectEnum):
+    AREA = "area"
+    BICUBIC = "bicubic"
+    BILINEAR = "bilinear"
+    NEAREST = "nearest"
+    NEAREST_EXACT = "nearest-exact"
+    TRILINEAR = "trilinear"
+
+
+class RescaleKspace(DirectTransform):
+    """Rescale k-space (downsample/upsample) module.
+
+    Rescales the k-space:
+    * It first projects the k-space to the image-domain via the backward operator,
+    * It rescales the back-projected k-space to the specified shape,
+    * It transforms the rescaled back-projected k-space to the k-space domain via the forward operator.
+
+    Parameters
+    ----------
+    shape : tuple or list of ints
+        Shape to rescale the input. Must correspond to (height, width).
+    forward_operator : Callable
+        The forward operator, e.g. some form of FFT (centered or uncentered).
+        Default: :class:`direct.data.transforms.fft2`.
+    backward_operator : Callable
+        The backward operator, e.g. some form of inverse FFT (centered or uncentered).
+        Default: :class:`direct.data.transforms.ifft2`.
+    rescale_mode : RescaleMode
+        Mode to be used for rescaling. Can be RescaleMode.AREA, RescaleMode.BICUBIC, RescaleMode.BILINEAR,
+        RescaleMode.NEAREST, RescaleMode.NEAREST_EXACT, or RescaleMode.TRILINEAR. Note that not all modes are
+        supported for 2D or 3D data. Default: RescaleMode.NEAREST.
+    kspace_key : KspaceKey
+        K-space key. Default: KspaceKey.KSPACE.
+    rescale_2d_if_3d : bool, optional
+        If True and input k-space data is 3D, rescaling will be done only on the height and width dimensions.
+        Default: False.
+
+    Note
+    ----
+    If the input k-space data is 3D, rescaling will be done only on the height and width dimensions if
+    `rescale_2d_if_3d` is set to True.
+    """
+
+    def __init__(
+        self,
+        shape: Union[tuple[int, int], list[int]],
+        forward_operator: Callable = T.fft2,
+        backward_operator: Callable = T.ifft2,
+        rescale_mode: RescaleMode = RescaleMode.NEAREST,
+        kspace_key: KspaceKey = KspaceKey.KSPACE,
+        rescale_2d_if_3d: Optional[bool] = None,
+    ) -> None:
+        """Inits :class:`RescaleKspace`.
+
+        Parameters
+        ----------
+        shape : tuple or list of ints
+            Shape to rescale the input. Must correspond to (height, width).
+        forward_operator : Callable
+            The forward operator, e.g. some form of FFT (centered or uncentered).
+            Default: :class:`direct.data.transforms.fft2`.
+        backward_operator : Callable
+            The backward operator, e.g. some form of inverse FFT (centered or uncentered).
+            Default: :class:`direct.data.transforms.ifft2`.
+        rescale_mode : RescaleMode
+            Mode to be used for rescaling. Can be RescaleMode.AREA, RescaleMode.BICUBIC, RescaleMode.BILINEAR,
+            RescaleMode.NEAREST, RescaleMode.NEAREST_EXACT, or RescaleMode.TRILINEAR. Note that not all modes are
+            supported for 2D or 3D data. Default: RescaleMode.NEAREST.
+        kspace_key : KspaceKey
+            K-space key. Default: KspaceKey.KSPACE.
+        rescale_2d_if_3d : bool, optional
+            If True and input k-space data is 3D, rescaling will be done only on the height and width dimensions,
+            by combining the slice/time dimension with the batch dimension.
+            Default: False.
+        """
+        super().__init__()
+        self.logger = logging.getLogger(type(self).__name__)
+
+        if len(shape) not in [2, 3]:
+            raise ValueError(
+                f"Shape should be a list or tuple of two integers if 2D or three integers if 3D. "
+                f"Received: {shape}."
+            )
+        self.shape = shape
+        self.forward_operator = forward_operator
+        self.backward_operator = backward_operator
+        self.rescale_mode = rescale_mode
+        self.kspace_key = kspace_key
+
+        self.rescale_2d_if_3d = rescale_2d_if_3d
+        if rescale_2d_if_3d and len(shape) == 3:
+            raise ValueError("Shape cannot have a length of 3 when rescale_2d_if_3d is set to True.")
+
+    def __call__(self, sample: dict[str, Any]) -> dict[str, Any]:
+        """Calls :class:`RescaleKspace`.
+
+        Parameters
+        ----------
+        sample: dict[str, Any]
+            Dict sample containing key `kspace`.
+
+        Returns
+        -------
+        dict[str, Any]
+            Rescaled sample.
+        """
+        kspace = sample[self.kspace_key]  # shape (coil, [slice/time], height, width, complex=2)
+
+        dim = self.spatial_dims.TWO_D if kspace.ndim == 4 else self.spatial_dims.THREE_D
+
+        backprojected_kspace = self.backward_operator(kspace, dim=dim)
+
+        if kspace.ndim == 5 and self.rescale_2d_if_3d:
+            backprojected_kspace = backprojected_kspace.permute(1, 0, 2, 3, 4)
+
+        if (kspace.ndim == 4) or (kspace.ndim == 5 and not self.rescale_2d_if_3d):
+            backprojected_kspace = backprojected_kspace.unsqueeze(0)
+
+        rescaled_backprojected_kspace = T.complex_image_resize(backprojected_kspace, self.shape, self.rescale_mode)
+
+        if (kspace.ndim == 4) or (kspace.ndim == 5 and not self.rescale_2d_if_3d):
+            rescaled_backprojected_kspace = rescaled_backprojected_kspace.squeeze(0)
+
+        if kspace.ndim == 5 and self.rescale_2d_if_3d:
+            rescaled_backprojected_kspace = rescaled_backprojected_kspace.permute(1, 0, 2, 3, 4)
+
+        # Compute new k-space from rescaled_backprojected_kspace
+        # shape (coil, [slice/time if rescale_2d_if_3d else new_slc_or_time], new_height, new_width, complex=2)
+        sample[self.kspace_key] = self.forward_operator(rescaled_backprojected_kspace, dim=dim)  # The rescaled kspace
+
+        return sample
+
+
+class PadKspace(DirectTransform):
+    """Pad k-space with zeros to desired shape module.
+
+    Pads the k-space by:
+    * It first projects the k-space to the image-domain via the backward operator,
+    * It pads the back-projected k-space to the specified shape,
+    * It transforms the padded back-projected k-space to the k-space domain via the forward operator.
+
+    Parameters
+    ----------
+    pad_shape : tuple or list of ints
+        Shape to zero-pad the input. Must correspond to (height, width) or (slice/time, height, width).
+    forward_operator : Callable
+        The forward operator, e.g. some form of FFT (centered or uncentered).
+        Default: :class:`direct.data.transforms.fft2`.
+    backward_operator : Callable
+        The backward operator, e.g. some form of inverse FFT (centered or uncentered).
+        Default: :class:`direct.data.transforms.ifft2`.
+    kspace_key : KspaceKey
+        K-space key. Default: KspaceKey.KSPACE.
+    """
+
+    def __init__(
+        self,
+        pad_shape: Union[tuple[int, ...], list[int]],
+        forward_operator: Callable = T.fft2,
+        backward_operator: Callable = T.ifft2,
+        kspace_key: KspaceKey = KspaceKey.KSPACE,
+    ) -> None:
+        """Inits :class:`PadKspace`.
+
+        Parameters
+        ----------
+        pad_shape : tuple or list of ints
+            Shape to zero-pad the input. Must correspond to (height, width) or (slice/time, height, width).
+        forward_operator : Callable
+            The forward operator, e.g. some form of FFT (centered or uncentered).
+            Default: :class:`direct.data.transforms.fft2`.
+        backward_operator : Callable
+            The backward operator, e.g. some form of inverse FFT (centered or uncentered).
+            Default: :class:`direct.data.transforms.ifft2`.
+        kspace_key : KspaceKey
+            K-space key. Default: KspaceKey.KSPACE.
+        """
+        super().__init__()
+        self.logger = logging.getLogger(type(self).__name__)
+
+        if len(pad_shape) not in [2, 3]:
+            raise ValueError(f"Shape should be a list or tuple of two or three integers. Received: {pad_shape}.")
+
+        self.shape = pad_shape
+        self.forward_operator = forward_operator
+        self.backward_operator = backward_operator
+        self.kspace_key = kspace_key
+
+    def __call__(self, sample: dict[str, Any]) -> dict[str, Any]:
+        """Calls :class:`PadKspace`.
+
+        Parameters
+        ----------
+        sample: dict[str, Any]
+            Dict sample containing key `kspace`.
+
+        Returns
+        -------
+        dict[str, Any]
+            Zero-padded sample.
+        """
+        kspace = sample[self.kspace_key]  # shape (coil, [slice or time], height, width, complex=2)
+        shape = kspace.shape
+
+        sample["original_size"] = shape[1:-1]
+
+        dim = self.spatial_dims.TWO_D if kspace.ndim == 4 else self.spatial_dims.THREE_D
+
+        backprojected_kspace = self.backward_operator(kspace, dim=dim)
+        backprojected_kspace = T.view_as_complex(backprojected_kspace)
+
+        padded_backprojected_kspace = T.pad_tensor(backprojected_kspace, self.shape)
+        padded_backprojected_kspace = T.view_as_real(padded_backprojected_kspace)
+
+        # shape (coil, [slice or time], height, width, complex=2)
+        sample[self.kspace_key] = self.forward_operator(padded_backprojected_kspace, dim=dim)  # The padded kspace
+
+        return sample
+
+
+class ComputeZeroPadding(DirectTransform):
+    r"""Computes zero padding present in multi-coil kspace input.
+
+    Zero-padding is computed from multi-coil kspace with no signal contribution, i.e. its magnitude
+    is close to zero:
+
+    .. math ::
+
+        \text{padding} = \sum_{i=1}^{n_c} |y_i| < \frac{1}{n_x \cdot n_y}
+        \sum_{j=1}^{n_x \cdot n_y} \big\{\sum_{i=1}^{n_c} |y_i|\big\} * \epsilon.
+    """
+
+    def __init__(
+        self,
+        kspace_key: KspaceKey = KspaceKey.KSPACE,
+        padding_key: str = "padding",
+        eps: Optional[float] = 0.0001,
+    ) -> None:
+        """Inits :class:`ComputeZeroPadding`.
+
+        Parameters
+        ----------
+        kspace_key: KspaceKey
+            K-space key. Default: KspaceKey.KSPACE.
+        padding_key: str
+            Target key. Default: "padding".
+        eps: float
+            Epsilon to multiply the sum of signals with. If set too high, no padding will likely be detected.
+            Default: 0.0001.
+        """
+        super().__init__()
+        self.kspace_key = kspace_key
+        self.padding_key = padding_key
+        self.eps = eps
+
+    def __call__(self, sample: dict[str, Any], coil_dim: int = 0) -> dict[str, Any]:
+        """Updates sample with a key `padding_key` with value a binary tensor.
+
+        Non-zero entries indicate samples in kspace with key `kspace_key` which have minor contribution, i.e. padding.
+
+        Parameters
+        ----------
+        sample : dict[str, Any]
+            Dict sample containing key `kspace_key`.
+        coil_dim : int
+            Coil dimension.
Default: 0. + + Returns + ------- + sample : dict[str, Any] + Dict sample containing key `padding_key`. + """ + if self.eps is None: + return sample + shape = sample[self.kspace_key].shape + + kspace = T.modulus(sample[self.kspace_key].clone()).sum(coil_dim) + + if len(shape) == 5: # Check if 3D data + # Assumes that slice dim is 0 + kspace = kspace.sum(0) + + padding = (kspace < (torch.mean(kspace) * self.eps)).to(kspace.device) + + if len(shape) == 5: + padding = padding.unsqueeze(0) + + padding = padding.unsqueeze(coil_dim).unsqueeze(-1) + sample[self.padding_key] = padding + + return sample + + +class ApplyZeroPadding(DirectTransform): + """Applies zero padding present in multi-coil kspace input.""" + + def __init__(self, kspace_key: KspaceKey = KspaceKey.KSPACE, padding_key: str = "padding") -> None: + """Inits :class:`ApplyZeroPadding`. + + Parameters + ---------- + kspace_key: KspaceKey + K-space key. Default: KspaceKey.KSPACE. + padding_key: str + Target key. Default: "padding". + """ + super().__init__() + self.kspace_key = kspace_key + self.padding_key = padding_key + + def __call__(self, sample: dict[str, Any], coil_dim: int = 0) -> dict[str, Any]: + """Applies zero padding on `kspace_key` with value a binary tensor. + + Parameters + ---------- + sample : dict[str, Any] + Dict sample containing key `kspace_key`. + coil_dim : int + Coil dimension. Default: 0. + + Returns + ------- + sample : dict[str, Any] + Dict sample containing key `padding_key`. + """ + + sample[self.kspace_key] = T.apply_padding(sample[self.kspace_key], sample[self.padding_key]) + + return sample + + +class ReconstructionType(DirectEnum): + """Reconstruction method for :class:`ComputeImage` transform.""" + + IFFT = "ifft" + RSS = "rss" + COMPLEX = "complex" + COMPLEX_MOD = "complex_mod" + SENSE = "sense" + SENSE_MOD = "sense_mod" + + +class ComputeImageModule(DirectModule): + """Compute Image transform.""" + + def __init__( + self, + kspace_key: KspaceKey, + target_key: str, + backward_operator: Callable, + type_reconstruction: ReconstructionType = ReconstructionType.RSS, + ) -> None: + """Inits :class:`ComputeImageModule`. + + Parameters + ---------- + kspace_key: KspaceKey + K-space key. + target_key: str + Target key. + backward_operator: callable + The backward operator, e.g. some form of inverse FFT (centered or uncentered). + type_reconstruction: ReconstructionType + Type of reconstruction. Can be ReconstructionType.RSS, ReconstructionType.COMPLEX, + ReconstructionType.COMPLEX_MOD, ReconstructionType.SENSE, ReconstructionType.SENSE_MOD or + ReconstructionType.IFFT. Default: ReconstructionType.RSS. + """ + super().__init__() + self.backward_operator = backward_operator + self.kspace_key = kspace_key + self.target_key = target_key + self.type_reconstruction = type_reconstruction + + def forward(self, sample: dict[str, Any]) -> dict[str, Any]: + """Forward pass of :class:`ComputeImageModule`. + + Parameters + ---------- + sample: dict[str, Any] + Contains key kspace_key with value a torch.Tensor of shape (coil,\*spatial_dims, complex=2). + + Returns + ------- + sample: dict + Contains key target_key with value a torch.Tensor of shape (\*spatial_dims) if `type_reconstruction` is + ReconstructionType.RSS, ReconstructionType.COMPLEX_MOD, ReconstructionType.SENSE_MOD, + and of shape (\*spatial_dims, complex_dim=2) otherwise. 
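+
+        Example
+        -------
+        A minimal sketch; shapes, keys and the resulting shape shown are illustrative assumptions, not guarantees:
+
+        >>> import torch
+        >>> from direct.data import transforms as T
+        >>> module = ComputeImageModule(
+        ...     kspace_key=KspaceKey.KSPACE,
+        ...     target_key="target",
+        ...     backward_operator=T.ifft2,
+        ...     type_reconstruction=ReconstructionType.RSS,
+        ... )
+        >>> sample = {KspaceKey.KSPACE: torch.randn(1, 8, 320, 320, 2)}  # (batch, coil, height, width, complex)
+        >>> sample = module(sample)  # adds "target", e.g. of shape (1, 320, 320) for RSS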
+ """ + kspace_data = sample[self.kspace_key] + dim = self.spatial_dims.TWO_D if kspace_data.ndim == 5 else self.spatial_dims.THREE_D + # Get complex-valued data solution + image = self.backward_operator(kspace_data, dim=dim) + if self.type_reconstruction == ReconstructionType.IFFT: + sample[self.target_key] = image + elif self.type_reconstruction in [ + ReconstructionType.COMPLEX, + ReconstructionType.COMPLEX_MOD, + ]: + sample[self.target_key] = image.sum(self.coil_dim) + elif self.type_reconstruction == ReconstructionType.RSS: + sample[self.target_key] = T.root_sum_of_squares(image, dim=self.coil_dim) + else: + if "sensitivity_map" not in sample: + raise ItemNotFoundException( + "sensitivity map", + "Sensitivity map is required for SENSE reconstruction.", + ) + sample[self.target_key] = T.complex_multiplication(T.conjugate(sample["sensitivity_map"]), image).sum( + self.coil_dim + ) + if self.type_reconstruction in [ + ReconstructionType.COMPLEX_MOD, + ReconstructionType.SENSE_MOD, + ]: + sample[self.target_key] = T.modulus(sample[self.target_key], self.complex_dim) + return sample + + +class EstimateBodyCoilImage(DirectTransform): + """Estimates body coil image.""" + + def __init__(self, mask_func: Callable, backward_operator: Callable, use_seed: bool = True) -> None: + """Inits :class:`EstimateBodyCoilImage'. + + Parameters + ---------- + mask_func: Callable + A function which creates a sampling mask of the appropriate shape. + backward_operator: callable + The backward operator, e.g. some form of inverse FFT (centered or uncentered). + use_seed: bool + If true, a pseudo-random number based on the filename is computed so that every slice of the volume get + the same mask every time. Default: True. + """ + super().__init__() + self.mask_func = mask_func + self.use_seed = use_seed + self.backward_operator = backward_operator + + def __call__(self, sample: dict[str, Any], coil_dim: int = 0) -> dict[str, Any]: + """Calls :class:`EstimateBodyCoilImage`. + + Parameters + ---------- + sample: dict[str, Any] + Contains key kspace_key with value a torch.Tensor of shape (coil, ..., complex=2). + coil_dim: int + Coil dimension. Default: 0. + + Returns + ---------- + sample: dict[str, Any] + Contains key `"body_coil_image`. + """ + kspace = sample["kspace"] + + # We need to create an ACS mask based on the shape of this kspace, as it can be cropped. + seed = None if not self.use_seed else tuple(map(ord, str(sample["filename"]))) + kspace_shape = tuple(sample["kspace"].shape[-3:]) + acs_mask = self.mask_func(shape=kspace_shape, seed=seed, return_acs=True) + + kspace = acs_mask * kspace + 0.0 + dim = self.spatial_dims.TWO_D if kspace.ndim == 4 else self.spatial_dims.THREE_D + acs_image = self.backward_operator(kspace, dim=dim) + + sample["body_coil_image"] = T.root_sum_of_squares(acs_image, dim=coil_dim) + return sample + + +class SensitivityMapType(DirectEnum): + ESPIRIT = "espirit" + RSS_ESTIMATE = "rss_estimate" + UNIT = "unit" + + +class EstimateSensitivityMapModule(DirectModule): + """Data Transformer for training MRI reconstruction models. + + Estimates sensitivity maps given masked k-space data using one of three methods: + + * Unit: unit sensitivity map in case of single coil acquisition. + * RSS-estimate: sensitivity maps estimated by using the root-sum-of-squares of the autocalibration-signal. + * ESPIRIT: sensitivity maps estimated with the ESPIRIT method [1]_. 
Note that this is currently not
+      implemented for 3D data, and attempting to use it in such cases will result in a NotImplementedError.
+
+    References
+    ----------
+
+    .. [1] Uecker M, Lai P, Murphy MJ, Virtue P, Elad M, Pauly JM, Vasanawala SS, Lustig M. ESPIRiT--an eigenvalue
+        approach to autocalibrating parallel MRI: where SENSE meets GRAPPA. Magn Reson Med. 2014 Mar;71(3):990-1001.
+        doi: 10.1002/mrm.24751. PMID: 23649942; PMCID: PMC4142121.
+    """
+
+    def __init__(
+        self,
+        kspace_key: KspaceKey = KspaceKey.ACS_KSPACE,
+        backward_operator: Callable = T.ifft2,
+        type_of_map: Optional[SensitivityMapType] = SensitivityMapType.RSS_ESTIMATE,
+        gaussian_sigma: Optional[float] = None,
+        espirit_threshold: Optional[float] = 0.05,
+        espirit_kernel_size: Optional[int] = 6,
+        espirit_crop: Optional[float] = 0.95,
+        espirit_max_iters: Optional[int] = 30,
+    ) -> None:
+        """Inits :class:`EstimateSensitivityMapModule`.
+
+        Parameters
+        ----------
+        kspace_key: KspaceKey
+            K-space key to compute the ACS image from. If `kspace_key` is not `KspaceKey.ACS_KSPACE`,
+            the ACS mask should be provided in the sample. Default: KspaceKey.ACS_KSPACE.
+        backward_operator: callable
+            The backward operator, e.g. some form of inverse FFT (centered or uncentered).
+        type_of_map: SensitivityMapType, optional
+            Type of map to estimate. Can be SensitivityMapType.RSS_ESTIMATE, SensitivityMapType.UNIT or
+            SensitivityMapType.ESPIRIT. Default: SensitivityMapType.RSS_ESTIMATE.
+        gaussian_sigma: float, optional
+            If non-zero, the ACS image will be calculated from k-space weighted by a Gaussian (with standard
+            deviation `gaussian_sigma`) along the width dimension. Default: None.
+        espirit_threshold: float, optional
+            Threshold for the calibration matrix when `type_of_map` is set to `SensitivityMapType.ESPIRIT`.
+            Default: 0.05.
+        espirit_kernel_size: int, optional
+            Kernel size for the calibration matrix when `type_of_map` is set to `SensitivityMapType.ESPIRIT`.
+            Default: 6.
+        espirit_crop: float, optional
+            Output eigenvalue cropping threshold when `type_of_map` is set to `SensitivityMapType.ESPIRIT`.
+            Default: 0.95.
+        espirit_max_iters: int, optional
+            Power method iterations when `type_of_map` is set to `SensitivityMapType.ESPIRIT`. Default: 30.
+        """
+        super().__init__()
+        self.backward_operator = backward_operator
+        self.kspace_key = kspace_key
+        self.type_of_map = type_of_map
+
+        # RSS estimate attributes
+        self.gaussian_sigma = gaussian_sigma
+        # Espirit attributes
+        if type_of_map == SensitivityMapType.ESPIRIT:
+            self.espirit_calibrator = EspiritCalibration(
+                backward_operator,
+                espirit_threshold,
+                espirit_kernel_size,
+                espirit_crop,
+                espirit_max_iters,
+                kspace_key,
+            )
+        self.espirit_threshold = espirit_threshold
+        self.espirit_kernel_size = espirit_kernel_size
+        self.espirit_crop = espirit_crop
+        self.espirit_max_iters = espirit_max_iters
+
+    def estimate_acs_image(self, sample: dict[str, Any], width_dim: int = -2) -> torch.Tensor:
+        """Estimates the autocalibration (ACS) image by sampling the k-space using the ACS mask.
+
+        Parameters
+        ----------
+        sample: dict[str, Any]
+            Sample dictionary.
+        width_dim: int
+            Dimension corresponding to width. Default: -2.
+
+        Returns
+        -------
+        acs_image: torch.Tensor
+            Estimate of the ACS image.
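+
+        Example
+        -------
+        A minimal sketch; shapes are illustrative assumptions. With `kspace_key=KspaceKey.ACS_KSPACE`
+        no ACS mask is required in the sample:
+
+        >>> import torch
+        >>> from direct.data import transforms as T
+        >>> module = EstimateSensitivityMapModule(kspace_key=KspaceKey.ACS_KSPACE, backward_operator=T.ifft2)
+        >>> sample = {KspaceKey.ACS_KSPACE: torch.randn(1, 8, 320, 320, 2)}
+        >>> acs_image = module.estimate_acs_image(sample)  # same shape as the input k-space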
+ """ + kspace_data = sample[self.kspace_key] + + if self.kspace_key != KspaceKey.ACS_KSPACE: + if TransformKey.ACS_MASK not in sample: + raise ValueError("ACS mask is required for estimating ACS image from k-space but not found.") + kspace_data = kspace_data * sample[TransformKey.ACS_MASK] + + if self.gaussian_sigma == 0 or not self.gaussian_sigma: + kspace_acs = kspace_data + 0.0 # + 0.0 removes the sign of zeros. + else: + gaussian_mask = torch.linspace(-1, 1, kspace_data.size(width_dim), dtype=kspace_data.dtype) + gaussian_mask = torch.exp(-((gaussian_mask / self.gaussian_sigma) ** 2)) + gaussian_mask_shape = torch.ones(len(kspace_data.shape)).int() + gaussian_mask_shape[width_dim] = kspace_data.size(width_dim) + gaussian_mask = gaussian_mask.reshape(tuple(gaussian_mask_shape)) + kspace_acs = kspace_data * gaussian_mask + 0.0 + + # Get complex-valued data solution + # Shape (batch, [slice/time], coil, height, width, complex=2) + dim = self.spatial_dims.TWO_D if kspace_data.ndim == 5 else self.spatial_dims.THREE_D + acs_image = self.backward_operator(kspace_acs, dim=dim) + + return acs_image + + def forward(self, sample: dict[str, Any]) -> dict[str, Any]: + """Calculates sensitivity maps for the input sample. + + Parameters + ---------- + sample: dict[str, Any] + Must contain key matching kspace_key with value a (complex) torch.Tensor + of shape (coil, height, width, complex=2). + + Returns + ------- + sample: dict[str, Any] + Sample with key "sensitivity_map" with value the estimated sensitivity map. + """ + kspace = sample[self.kspace_key] # shape (batch, coil, [slice/time], height, width, complex=2) + + if kspace.shape[self.coil_dim] == 1: + warnings.warn( + "Estimation of sensitivity map of Single-coil data. This warning will be displayed only once." + ) + if "sensitivity_map" in sample: + warnings.warn( + "`sensitivity_map` is given, but will be overwritten. This warning will be displayed only once." + ) + + if self.type_of_map == SensitivityMapType.UNIT: + sensitivity_map = torch.zeros(kspace.shape).float() + # Assumes complex channel is last + assert_complex(kspace, complex_last=True) + sensitivity_map[..., 0] = 1.0 + # Shape (batch, coil, [slice/time], height, width, complex=2) + sensitivity_map = sensitivity_map.to(kspace.device) + + elif self.type_of_map == SensitivityMapType.RSS_ESTIMATE: + # Shape (batch, coil, [slice/time], height, width, complex=2) + acs_image = self.estimate_acs_image(sample) + # Shape (batch, [slice/time], height, width) + acs_image_rss = T.root_sum_of_squares(acs_image, dim=self.coil_dim) + # Shape (batch, 1, [slice/time], height, width, 1) + acs_image_rss = acs_image_rss.unsqueeze(self.coil_dim).unsqueeze(self.complex_dim) + # Shape (batch, coil, [slice/time], height, width, complex=2) + sensitivity_map = T.safe_divide(acs_image, acs_image_rss) + else: + if sample[self.kspace_key].ndim > 5: + raise NotImplementedError( + "EstimateSensitivityMapModule is not yet implemented for " + "Espirit sensitivity map estimation for 3D data." 
+ ) + sensitivity_map = self.espirit_calibrator(sample) + + sensitivity_map_norm = torch.sqrt( + (sensitivity_map**2).sum(self.complex_dim).sum(self.coil_dim) + ) # shape (batch, [slice/time], height, width) + sensitivity_map_norm = sensitivity_map_norm.unsqueeze(self.coil_dim).unsqueeze(self.complex_dim) + + sample[TransformKey.SENSITIVITY_MAP] = T.safe_divide(sensitivity_map, sensitivity_map_norm) + return sample + + +class AddBooleanKeysModule(DirectModule): + """Adds keys with boolean values to sample.""" + + def __init__(self, keys: list[str], values: list[bool]) -> None: + """Inits :class:`AddBooleanKeysModule`. + + Parameters + ---------- + keys : list[str] + A list of keys to be added. + values : list[bool] + A list of values corresponding to the keys. + """ + super().__init__() + self.keys = keys + self.values = values + + def forward(self, sample: dict[str, Any]) -> dict[str, Any]: + """Adds boolean keys to the input sample dictionary. + + Parameters + ---------- + sample : dict[str, Any] + The input sample dictionary. + + Returns + ------- + dict[str, Any] + The modified sample with added boolean keys. + """ + for key, value in zip(self.keys, self.values): + sample[key] = value + + return sample + + +class CopyKeysModule(DirectModule): + """Copy keys to a new name from the sample if present.""" + + def __init__(self, keys: list[str], new_keys: list[str]) -> None: + """Inits :class:`CopyKeysModule`. + + Parameters + ---------- + keys: List[str] + Key(s) to copy. + new_keys: List[str] + Key(s) to create. + """ + super().__init__() + self.keys = keys + self.new_keys = new_keys + + def forward(self, sample: dict[str, Any]) -> dict[str, Any]: + """Forward pass of :class:`CopyKeysModule`. + + Parameters + ---------- + sample: Dict[str, Any] + Dictionary to look for keys and copy them with a new name. + + Returns + ------- + Dict[str, Any] + Dictionary with copied specified keys. + """ + for key, new_key in zip(self.keys, self.new_keys): + if key in sample: + if isinstance(sample[key], np.ndarray): + sample[new_key] = sample[key].copy() # Copy NumPy array + elif isinstance(sample[key], torch.Tensor): + sample[new_key] = sample[key].detach().clone() # Copy Torch tensor + else: + sample[new_key] = copy.deepcopy(sample[key]) + return sample + + +class CompressCoilModule(DirectModule): + """Compresses k-space coils using SVD.""" + + def __init__(self, kspace_key: KspaceKey, num_coils: int) -> None: + """Inits :class:`CompressCoilModule`. + + Parameters + ---------- + kspace_key : KspaceKey + K-space key. + num_coils : int + Number of coils to compress. + """ + super().__init__() + self.kspace_key = kspace_key + self.num_coils = num_coils + + def forward(self, sample: dict[str, Any]) -> dict[str, Any]: + """Performs coil compression to input k-space. + + Parameters + ---------- + sample : dict[str, Any] + Dict sample containing key `kspace_key`. Assumes coil dimension is first axis. + + Returns + ------- + sample : dict[str, Any] + Dict sample with `kspace_key` compressed to num_coils. 
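+
+        Example
+        -------
+        A minimal sketch; shapes are illustrative assumptions:
+
+        >>> import torch
+        >>> compress = CompressCoilModule(kspace_key=KspaceKey.KSPACE, num_coils=4)
+        >>> sample = {KspaceKey.KSPACE: torch.randn(1, 16, 320, 320, 2)}  # 16 measured coils
+        >>> sample = compress(sample)  # k-space compressed to 4 virtual coils: (1, 4, 320, 320, 2)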
+        """
+        k_space = sample[self.kspace_key].clone()  # shape (batch, coil, [slice/time], height, width, complex=2)
+
+        if k_space.shape[1] <= self.num_coils:
+            return sample
+
+        ndim = k_space.ndim
+
+        k_space = torch.view_as_complex(k_space)
+
+        if ndim == 6:  # If 3D, fold the slice/time dimension into the batch dimension, as compression is done per 2D slice
+            num_slice_or_time = k_space.shape[2]
+            k_space = k_space.permute(0, 2, 1, 3, 4)
+            k_space = k_space.reshape(k_space.shape[0] * num_slice_or_time, *k_space.shape[2:])
+
+        shape = k_space.shape
+
+        # Reshape the k-space data to combine spatial dimensions
+        k_space_reshaped = k_space.reshape(shape[0], shape[1], -1)
+
+        # Compute the coil combination matrix using Singular Value Decomposition (SVD)
+        U, _, _ = torch.linalg.svd(k_space_reshaped, full_matrices=False)
+
+        # Select the top ncoils_new singular vectors from the decomposition
+        U_new = U[:, :, : self.num_coils]
+
+        # Perform coil compression
+        compressed_k_space = torch.matmul(U_new.transpose(1, 2), k_space_reshaped)
+
+        # Reshape the compressed k-space back to its original shape
+        compressed_k_space = compressed_k_space.reshape(shape[0], self.num_coils, *shape[2:])
+
+        if ndim == 6:
+            compressed_k_space = compressed_k_space.reshape(
+                shape[0] // num_slice_or_time, num_slice_or_time, self.num_coils, *shape[2:]
+            ).permute(0, 2, 1, 3, 4)
+
+        compressed_k_space = torch.view_as_real(compressed_k_space)
+        sample[self.kspace_key] = compressed_k_space  # shape (batch, new coil, [slice/time], height, width, complex=2)
+
+        return sample
+
+
+class DeleteKeysModule(DirectModule):
+    """Remove keys from the sample if present."""
+
+    def __init__(self, keys: list[str]) -> None:
+        """Inits :class:`DeleteKeys`.
+
+        Parameters
+        ----------
+        keys: list[str]
+            Key(s) to delete.
+        """
+        super().__init__()
+        self.keys = keys
+
+    def forward(self, sample: dict[str, Any]) -> dict[str, Any]:
+        """Forward pass of :class:`DeleteKeys`.
+
+        Parameters
+        ----------
+        sample: dict[str, Any]
+            Dictionary to look for keys and remove them.
+
+        Returns
+        -------
+        dict[str, Any]
+            Dictionary with deleted specified keys.
+        """
+        for key in self.keys:
+            if key in sample:
+                del sample[key]
+
+        return sample
+
+
+class IndexSelectionMode(DirectEnum):
+    RANDOM = "random"
+    CUSTOM = "custom"
+    RANGE = "range"
+
+
+class IndexSelectionModule(DirectModule):
+    """Selects indices from the sample: randomly, from a custom list, or from a range.
+
+    Parameters
+    ----------
+    key: TransformKey
+        Key to select indices from.
+    mode: IndexSelectionMode
+        Mode of index selection. Can be IndexSelectionMode.RANDOM, IndexSelectionMode.CUSTOM or
+        IndexSelectionMode.RANGE. Default: IndexSelectionMode.CUSTOM.
+    indices: list[int], optional
+        List of indices to select if mode is IndexSelectionMode.CUSTOM or range if mode is
+        IndexSelectionMode.RANGE. Default: None.
+    num_indices: int, optional
+        Number of indices to select if mode is IndexSelectionMode.RANDOM. Default: None.
+    out_key: TransformKey, optional
+        Key to store the selected indices. If None, the indices are stored in the same key.
+        Default: None.
+    index_dim: int
+        Dimension along which to select indices. Default: 0.
+    use_seed: bool
+        If true, a pseudo-random number based on the filename is computed so that every slice of the volume gets
+        the same mask every time. Default: True.
+    """
+
+    def __init__(
+        self,
+        key: TransformKey,
+        mode: IndexSelectionMode = IndexSelectionMode.CUSTOM,
+        indices: Optional[list[int]] = None,
+        num_indices: Optional[int] = None,
+        out_key: Optional[TransformKey] = None,
+        index_dim: int = 0,
+        use_seed: bool = True,
+    ) -> None:
+        """Inits :class:`IndexSelection`.
+
+        Parameters
+        ----------
+        key: TransformKey
+            Key to select indices from.
+        mode: IndexSelectionMode
+            Mode of index selection.
Can be IndexSelectionMode.RANDOM, IndexSelectionMode.CUSTOM or
+            IndexSelectionMode.RANGE. Default: IndexSelectionMode.CUSTOM.
+        indices: list[int], optional
+            List of indices to select if mode is IndexSelectionMode.CUSTOM or range if mode is
+            IndexSelectionMode.RANGE. Default: None.
+        num_indices: int, optional
+            Number of indices to select if mode is IndexSelectionMode.RANDOM. Default: None.
+        out_key: TransformKey, optional
+            Key to store the selected indices. If None, the indices are stored in the same key.
+            Default: None.
+        index_dim: int
+            Dimension along which to select indices. Default: 0.
+        use_seed: bool
+            If true, a pseudo-random number based on the filename is computed so that every slice of the volume gets
+            the same mask every time. Default: True.
+        """
+        super().__init__()
+        self.key = key
+        self.out_key = out_key if out_key is not None else key
+        self.mode = mode
+        self.indices = indices
+        self.num_indices = num_indices
+        self.index_dim = index_dim
+        self.use_seed = use_seed
+        self.rng = np.random.RandomState()
+
+    def forward(self, sample: dict[str, Any]) -> dict[str, Any]:
+        """Forward pass of :class:`IndexSelection`.
+
+        Parameters
+        ----------
+        sample: dict[str, Any]
+            Dictionary to look for key and select indices from.
+
+        Returns
+        -------
+        dict[str, Any]
+            Dictionary with selected indices.
+        """
+        if self.key not in sample:
+            return sample
+
+        if self.mode == IndexSelectionMode.RANDOM:
+            seed = None if not self.use_seed else tuple(map(ord, str(sample["filename"])))
+            with temp_seed(self.rng, seed):
+                num_to_keep = max(min(self.num_indices, sample[self.key].shape[self.index_dim]), 1)
+                start = self.rng.randint(0, sample[self.key].shape[self.index_dim] - num_to_keep)
+                keep_indices = torch.arange(start, start + num_to_keep, device=sample[self.key].device)
+        else:
+            if self.mode == IndexSelectionMode.CUSTOM:
+                keep_indices = torch.tensor(
+                    [idx for idx in self.indices if np.abs(idx) < sample[self.key].shape[self.index_dim]],
+                    device=sample[self.key].device,
+                )
+            else:
+                keep_indices = torch.arange(self.indices[0], self.indices[1], device=sample[self.key].device)
+            num_to_keep = len(keep_indices)
+
+        sample[self.out_key] = sample[self.key].index_select(self.index_dim, keep_indices)
+
+        if num_to_keep == 1:
+            sample[self.out_key] = sample[self.out_key].squeeze(self.index_dim)
+
+        return sample
+
+
+class DropIndexModule(DirectModule):
+    """Drop indices from the sample.
+
+    Parameters
+    ----------
+    keys: list[TransformKey]
+        Key(s) to drop indices from.
+    index: int
+        Index to drop.
+    index_dim: int, list[int]
+        Dimension(s) along which to drop indices. If a list, must have the same length as `keys`. Default: 1.
+    store_deleted_keys: list[TransformKey], optional
+        Key(s) to store the deleted indices. If None, the deleted indices are not stored. If the length does not
+        match `keys`, the remaining keys are set to None. Default: None.
+    """
+
+    def __init__(
+        self,
+        keys: list[TransformKey],
+        index: int,
+        index_dim: int | list[int] = 1,
+        store_deleted_keys: Optional[list[TransformKey]] = None,
+    ) -> None:
+        """Inits :class:`DropIndexModule`.
+
+        Parameters
+        ----------
+        keys: list[TransformKey]
+            Key(s) to drop indices from.
+        index: int
+            Index to drop.
+        index_dim: int, list[int]
+            Dimension(s) along which to drop indices. If a list, must have the same length as `keys`. Default: 1.
+        store_deleted_keys: list[TransformKey], optional
+            Key(s) to store the deleted indices. If None, the deleted indices are not stored.
If the length does not + match `keys`, the remaining keys are set to None. Default: None. + """ + super().__init__() + self.keys = keys + self.index = index + self.index_dim = [index_dim] * len(keys) if isinstance(index_dim, int) else index_dim + self.store_deleted_keys = store_deleted_keys + if self.store_deleted_keys is not None and len(keys) > len(self.store_deleted_keys): + self.store_deleted_keys = store_deleted_keys + [None] * (len(keys) - len(self.store_deleted_keys)) + + def forward(self, sample: dict[str, Any]) -> dict[str, Any]: + """Forward pass of :class:`DropIndexModule`. + + Parameters + ---------- + sample: dict[str, Any] + Dictionary to look for key and drop indices from. + + Returns + ------- + dict[str, Any] + Dictionary with dropped index. + """ + + for i, key in enumerate(self.keys): + if key not in sample: + continue + # This might be helpful, for instance, in case a single mask is used for all time frames + if sample[key].shape[self.index_dim[i]] == 1: + continue + if self.store_deleted_keys is not None: + deleted_key = self.store_deleted_keys[i] + if deleted_key: + sample[deleted_key] = sample[key].index_select( + self.index_dim[i], + torch.tensor( + [idx for idx in range(sample[key].shape[self.index_dim[i]]) if idx == self.index], + device=sample[key].device, + ), + ) + sample[key] = sample[key].index_select( + self.index_dim[i], + torch.tensor( + [idx for idx in range(sample[key].shape[self.index_dim[i]]) if idx != self.index], + device=sample[key].device, + ), + ) + + return sample + + +class SqueezeKeyModule(DirectModule): + """Squeeze the specified key(s) in the sample. + + Parameters + ---------- + keys: TransformKey + Key(s) to squeeze. + dim: int + Dimension to squeeze. + """ + + def __init__(self, keys: TransformKey, dim: int) -> None: + """Inits :class:`SqueezeKeyModule`. + + Parameters + ---------- + keys: TransformKey + Key(s) to squeeze. + dim: int + Dimension to squeeze. + """ + super().__init__() + self.keys = keys + self.dim = dim + + def forward(self, sample: dict[str, Any]) -> dict[str, Any]: + """Forward pass of :class:`SqueezeKeyModule`. + + Parameters + ---------- + sample: dict[str, Any] + Dictionary to look for keys to squeeze. + + Returns + ------- + dict[str, Any] + Dictionary with squeezed specified keys. + """ + for key in self.keys: + if key in sample: + sample[key] = sample[key].squeeze(self.dim) + return sample + + +class RenameKeysModule(DirectModule): + """Rename keys from the sample if present.""" + + def __init__(self, old_keys: list[str], new_keys: list[str]) -> None: + """Inits :class:`RenameKeys`. + + Parameters + ---------- + old_keys: list[str] + Key(s) to rename. + new_keys: list[str] + Key(s) to replace old keys. + """ + super().__init__() + self.old_keys = old_keys + self.new_keys = new_keys + + def forward(self, sample: dict[str, Any]) -> dict[str, Any]: + """Forward pass of :class:`RenameKeys`. + + Parameters + ---------- + sample: dict[str, Any] + Dictionary to look for keys and rename them. + + Returns + ------- + dict[str, Any] + Dictionary with renamed specified keys. + """ + for old_key, new_key in zip(self.old_keys, self.new_keys): + if old_key in sample: + sample[new_key] = sample.pop(old_key) + + return sample + + +class PadCoilDimensionModule(DirectModule): + """Pad the coils by zeros to a given number of coils. + + Useful if you want to collate volumes with different coil dimension. 
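+
+    Example
+    -------
+    A minimal sketch; shapes are illustrative assumptions:
+
+    >>> import torch
+    >>> pad = PadCoilDimensionModule(pad_coils=8, key="masked_kspace", coil_dim=1)
+    >>> sample = {"masked_kspace": torch.zeros(1, 5, 320, 320, 2)}  # 5 coils
+    >>> sample = pad(sample)  # zero-padded to 8 coils: (1, 8, 320, 320, 2)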
+    """
+
+    def __init__(
+        self,
+        pad_coils: Optional[int] = None,
+        key: str = "masked_kspace",
+        coil_dim: int = 1,
+    ) -> None:
+        """Inits :class:`PadCoilDimensionModule`.
+
+        Parameters
+        ----------
+        pad_coils: int, optional
+            Number of coils to pad to. Default: None.
+        key: str
+            Key to pad in sample. Default: "masked_kspace".
+        coil_dim: int
+            Coil dimension along which the pad will be done. Default: 1.
+        """
+        super().__init__()
+        self.num_coils = pad_coils
+        self.key = key
+        self.coil_dim = coil_dim
+
+    def forward(self, sample: dict[str, Any]) -> dict[str, Any]:
+        """Forward pass of :class:`PadCoilDimensionModule`.
+
+        Parameters
+        ----------
+        sample: dict[str, Any]
+            Dictionary with key `self.key`.
+
+        Returns
+        -------
+        sample: dict[str, Any]
+            Dictionary with padded coils of sample[self.key] if self.num_coils is not None.
+        """
+        if not self.num_coils:
+            return sample
+
+        if self.key not in sample:
+            return sample
+
+        data = sample[self.key]
+
+        curr_num_coils = data.shape[self.coil_dim]
+        if curr_num_coils > self.num_coils:
+            raise ValueError(
+                f"Tried to pad to {self.num_coils} coils, but already have {curr_num_coils} for "
+                f"{sample['filename']}."
+            )
+        if curr_num_coils == self.num_coils:
+            return sample
+
+        shape = data.shape
+        num_coils = shape[self.coil_dim]
+        padding_data_shape = list(shape).copy()
+        padding_data_shape[self.coil_dim] = max(self.num_coils - num_coils, 0)
+        zeros = torch.zeros(padding_data_shape, dtype=data.dtype, device=data.device)
+        sample[self.key] = torch.cat([zeros, data], dim=self.coil_dim)
+
+        return sample
+
+
+class ComputeScalingFactorModule(DirectModule):
+    """Calculates scaling factor.
+
+    The scaling factor for the input data is based either on a percentile or on the maximum of `normalize_key`.
+    """
+
+    def __init__(
+        self,
+        normalize_key: Union[None, TransformKey] = TransformKey.MASKED_KSPACE,
+        percentile: Union[None, float] = 0.99,
+        scaling_factor_key: TransformKey = TransformKey.SCALING_FACTOR,
+    ) -> None:
+        """Inits :class:`ComputeScalingFactorModule`.
+
+        Parameters
+        ----------
+        normalize_key : TransformKey or None
+            Key name to compute the data for. If the maximum has to be computed on the ACS, ensure the reconstruction
+            on the ACS is available (typically `body_coil_image`). Default: "masked_kspace".
+        percentile : float or None
+            Rescale data with the given percentile. If None, the division is done by the maximum. Default: 0.99.
+        scaling_factor_key : TransformKey
+            Name of how the scaling factor will be stored. Default: "scaling_factor".
+        """
+        super().__init__()
+        self.normalize_key = normalize_key
+        self.percentile = percentile
+        self.scaling_factor_key = scaling_factor_key
+
+    def forward(self, sample: dict[str, Any]) -> dict[str, Any]:
+        """Forward pass of :class:`ComputeScalingFactorModule`.
+
+        Parameters
+        ----------
+        sample: dict[str, Any]
+            Sample with key `normalize_key` to compute scaling_factor.
+
+        Returns
+        -------
+        sample: dict[str, Any]
+            Sample with key `scaling_factor_key`.
+        """
+        if self.normalize_key == "scaling_factor":  # This is a real-valued given number
+            scaling_factor = sample["scaling_factor"]
+        elif not self.normalize_key:
+            kspace = sample["masked_kspace"]
+            scaling_factor = torch.tensor([1.0] * kspace.size(0), device=kspace.device, dtype=kspace.dtype)
+        else:
+            data = sample[self.normalize_key]
+            scaling_factor: Union[list, torch.Tensor] = []
+            # Compute the maximum and scale the input
+            if self.percentile:
+                for _ in range(data.size(0)):
+                    # Used in case the k-space is padded (e.g.
for batches) + non_padded_coil_data = data[_][data[_].sum(dim=tuple(range(1, data[_].ndim))).bool()] + tview = -1.0 * T.modulus(non_padded_coil_data).view(-1) + s, _ = torch.kthvalue(tview, int((1 - self.percentile) * tview.size()[0]) + 1) + scaling_factor += [-1.0 * s] + scaling_factor = torch.tensor(scaling_factor, dtype=data.dtype, device=data.device) + else: + scaling_factor = T.modulus(data).amax(dim=list(range(data.ndim))[1:-1]) + sample[self.scaling_factor_key] = scaling_factor + return sample + + +class NormalizeModule(DirectModule): + """Normalize the input data.""" + + def __init__( + self, + scaling_factor_key: TransformKey = TransformKey.SCALING_FACTOR, + keys_to_normalize: Optional[list[TransformKey]] = None, + ) -> None: + """Inits :class:`NormalizeModule`. + + Parameters + ---------- + scaling_factor_key : TransformKey + Name of scaling factor key expected in sample. Default: 'scaling_factor'. + """ + super().__init__() + self.scaling_factor_key = scaling_factor_key + + self.keys_to_normalize = ( + [ + "masked_kspace", + "target", + "kspace", + "body_coil_image", # sensitivity_map does not require normalization. + "initial_image", + "initial_kspace", + ] + if keys_to_normalize is None + else keys_to_normalize + ) + + def forward(self, sample: dict[str, Any]) -> dict[str, Any]: + """Forward pass of :class:`NormalizeModule`. + + Parameters + ---------- + sample: dict[str, Any] + Sample to normalize. + + Returns + ------- + sample: dict[str, Any] + Sample with normalized values if their respective key is in `keys_to_normalize` and key + `scaling_factor_key` exists in sample. + """ + scaling_factor = sample.get(self.scaling_factor_key, None) + # Normalize data + if scaling_factor is not None: + for key in sample.keys(): + if key not in self.keys_to_normalize: + continue + sample[key] = T.safe_divide( + sample[key], + scaling_factor.reshape(-1, *[1 for _ in range(sample[key].ndim - 1)]), + ) + + sample["scaling_diff"] = 0.0 + return sample + + +class WhitenDataModule(DirectModule): + """Whitens complex data Module.""" + + def __init__(self, epsilon: float = 1e-10, key: str = "complex_image") -> None: + """Inits :class:`WhitenDataModule`. + + Parameters + ---------- + epsilon: float + Default: 1e-10. + key: str + Key to whiten. Default: "complex_image". + """ + super().__init__() + self.epsilon = epsilon + self.key = key + + def complex_whiten(self, complex_image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Whiten complex image. + + Parameters + ---------- + complex_image: torch.Tensor + Complex image tensor to whiten. + + Returns + ------- + mean, std, whitened_image: tuple[torch.Tensor, torch.Tensor, torch.Tensor] + """ + # From: https://github.com/facebookresearch/fastMRI + # blob/da1528585061dfbe2e91ebbe99a5d4841a5c3f43/banding_removal/fastmri/data/transforms.py#L464 # noqa + real = complex_image[..., 0] + imag = complex_image[..., 1] + + # Center around mean. + mean = complex_image.mean() + centered_complex_image = complex_image - mean + + # Determine covariance between real and imaginary. + n_elements = real.nelement() + real_real = (real.mul(real).sum() - real.mean().mul(real.mean())) / n_elements + real_imag = (real.mul(imag).sum() - real.mean().mul(imag.mean())) / n_elements + imag_imag = (imag.mul(imag).sum() - imag.mean().mul(imag.mean())) / n_elements + eig_input = torch.Tensor([[real_real, real_imag], [real_imag, imag_imag]]) + + # Remove correlation by rotating around covariance eigenvectors. 
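+        # Note: the 2x2 covariance matrix of the (real, imag) channels is symmetric,
+        # so its eigenvectors form a rotation that decorrelates the two channels;
+        # dividing by the square root of the (real, non-negative) eigenvalues below
+        # then scales each decorrelated channel to approximately unit variance.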
+ eig_values, eig_vecs = torch.linalg.eig(eig_input) + + # Scale by eigenvalues for unit variance. + std = (eig_values.real + self.epsilon).sqrt() + whitened_image = torch.matmul(centered_complex_image, eig_vecs.real) / std + + return mean, std, whitened_image + + def forward(self, sample: dict[str, Any]) -> dict[str, Any]: + """Forward pass of :class:`WhitenDataModule`. + + Parameters + ---------- + sample: dict[str, Any] + Sample with key `key`. + + Returns + ------- + sample: dict[str, Any] + Sample with value of `key` whitened. + """ + _, _, whitened_image = self.complex_whiten(sample[self.key]) + sample[self.key] = whitened_image + return sample + + +class AddTargetAcceleration(DirectTransform): + """This will replace the acceleration factor in the sample with the target acceleration factor.""" + + def __init__(self, target_acceleration: float): + super().__init__() + self.target_acceleration = target_acceleration + + def __call__(self, sample: dict[str, Any]): + sample["acceleration"][:] = self.target_acceleration + return sample + + +class ModuleWrapper: + class SubWrapper: + def __init__(self, transform: Callable, toggle_dims: bool) -> None: + self.toggle_dims = toggle_dims + self._transform = transform + + def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: + if self.toggle_dims: + for k, v in sample.items(): + if isinstance(v, (torch.Tensor, np.ndarray)): + sample[k] = v[None] + else: + sample[k] = [v] + + sample = self._transform.forward(sample) + + if self.toggle_dims: + for k, v in sample.items(): + if isinstance(v, (torch.Tensor, np.ndarray)): + sample[k] = v.squeeze(0) + else: + sample[k] = v[0] + + return sample + + def __repr__(self) -> str: + return self._transform.__repr__() + + def __init__(self, module: Callable, toggle_dims: bool) -> None: + self._module = module + self.toggle_dims = toggle_dims + + def __call__(self, *args, **kwargs) -> SubWrapper: + return self.SubWrapper(self._module(*args, **kwargs), toggle_dims=self.toggle_dims) + + +ApplyMask = ModuleWrapper(ApplyMaskModule, toggle_dims=False) +ComputeImage = ModuleWrapper(ComputeImageModule, toggle_dims=True) +EstimateSensitivityMap = ModuleWrapper(EstimateSensitivityMapModule, toggle_dims=True) +CopyKeys = ModuleWrapper(CopyKeysModule, toggle_dims=False) +DeleteKeys = ModuleWrapper(DeleteKeysModule, toggle_dims=False) +RenameKeys = ModuleWrapper(RenameKeysModule, toggle_dims=False) +IndexSelection = ModuleWrapper(IndexSelectionModule, toggle_dims=False) +DropIndex = ModuleWrapper(DropIndexModule, toggle_dims=False) +SqueezeKey = ModuleWrapper(SqueezeKeyModule, toggle_dims=False) +CompressCoil = ModuleWrapper(CompressCoilModule, toggle_dims=True) +PadCoilDimension = ModuleWrapper(PadCoilDimensionModule, toggle_dims=True) +ComputeScalingFactor = ModuleWrapper(ComputeScalingFactorModule, toggle_dims=True) +Normalize = ModuleWrapper(NormalizeModule, toggle_dims=False) +WhitenData = ModuleWrapper(WhitenDataModule, toggle_dims=False) +GaussianMaskSplitter = ModuleWrapper(GaussianMaskSplitterModule, toggle_dims=True) +UniformMaskSplitter = ModuleWrapper(UniformMaskSplitterModule, toggle_dims=True) +Displacement = ModuleWrapper(DisplacementModule, toggle_dims=True) +RandomElasticDeformation = ModuleWrapper(RandomElasticDeformationModule, toggle_dims=True) + + +class ToTensor(DirectTransform): + """Transforms all np.array-like values in sample to torch.tensors.""" + + def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: + """Calls :class:`ToTensor`. 
+ + Parameters + ---------- + sample: dict[str, Any] + Contains key 'kspace' with value a np.array of shape (coil, height, width) (2D) + or (coil, slice, height, width) (3D) + + Returns + ------- + sample: dict[str, Any] + Contains key 'kspace' with value a torch.Tensor of shape (coil, height, width) (2D) + or (coil, slice, height, width) (3D) + """ + + ndim = sample["kspace"].ndim - 1 + + if ndim not in [2, 3]: + raise ValueError(f"Can only cast 2D and 3D data (+coil) to tensor. Got {ndim}.") + + # Shape: 2D: (coil, height, width, complex=2), 3D: (coil, slice, height, width, complex=2) + sample["kspace"] = T.to_tensor(sample["kspace"]).float() + # Sensitivity maps are not necessarily available in the dataset. + if "initial_kspace" in sample: + # Shape: 2D: (coil, height, width, complex=2), 3D: (coil, slice, height, width, complex=2) + sample["initial_kspace"] = T.to_tensor(sample["initial_kspace"]).float() + if "initial_image" in sample: + # Shape: 2D: (height, width), 3D: (slice, height, width) + sample["initial_image"] = T.to_tensor(sample["initial_image"]).float() + + if "sensitivity_map" in sample: + # Shape: 2D: (coil, height, width, complex=2), 3D: (coil, slice, height, width, complex=2) + sample["sensitivity_map"] = T.to_tensor(sample["sensitivity_map"]).float() + if "target" in sample: + # Shape: 2D: (coil, height, width), 3D: (coil, slice, height, width) + sample["target"] = torch.from_numpy(sample["target"]).float() + if "sampling_mask" in sample: + sample["sampling_mask"] = torch.from_numpy(sample["sampling_mask"]).bool() + if "acs_mask" in sample: + sample["acs_mask"] = torch.from_numpy(sample["acs_mask"]).bool() + if "scaling_factor" in sample: + sample["scaling_factor"] = torch.tensor(sample["scaling_factor"]).float() + if "loglikelihood_scaling" in sample: + # Shape: (coil, ) + sample["loglikelihood_scaling"] = torch.from_numpy(np.asarray(sample["loglikelihood_scaling"])).float() + + return sample + + +class RegistrationSimulateReferenceType(DirectEnum): + FROM_KEY = "from_key" + ELASTIC = "elastic" + + +# pylint: disable=too-many-arguments +def build_supervised_mri_transforms( + forward_operator: Callable, + backward_operator: Callable, + mask_func: Optional[Callable], + target_acceleration: Optional[float] = None, + crop: Optional[Union[tuple[int, int], str]] = None, + crop_type: Optional[str] = "uniform", + rescale: Optional[Union[tuple[int, int], list[int]]] = None, + rescale_mode: Optional[RescaleMode] = RescaleMode.NEAREST, + rescale_2d_if_3d: Optional[bool] = False, + pad: Optional[Union[tuple[int, int], list[int]]] = None, + image_center_crop: bool = True, + random_rotation_degrees: Optional[Sequence[int]] = (-90, 90), + random_rotation_probability: float = 0.0, + random_flip_type: Optional[RandomFlipType] = RandomFlipType.RANDOM, + random_flip_probability: float = 0.0, + random_reverse_probability: float = 0.0, + padding_eps: float = 0.0001, + estimate_body_coil_image: bool = False, + estimate_sensitivity_maps: bool = True, + sensitivity_maps_type: SensitivityMapType = SensitivityMapType.RSS_ESTIMATE, + sensitivity_maps_gaussian: Optional[float] = None, + sensitivity_maps_espirit_threshold: Optional[float] = 0.05, + sensitivity_maps_espirit_kernel_size: Optional[int] = 6, + sensitivity_maps_espirit_crop: Optional[float] = 0.95, + sensitivity_maps_espirit_max_iters: Optional[int] = 30, + use_acs_as_mask: bool = False, + delete_acs: bool = True, + delete_kspace: bool = True, + image_recon_type: ReconstructionType = ReconstructionType.RSS, + compress_coils: 
Optional[int] = None,
+    pad_coils: Optional[int] = None,
+    scaling_key: TransformKey = TransformKey.MASKED_KSPACE,
+    scale_percentile: Optional[float] = 0.99,
+    registration: bool = False,
+    registration_simulate_reference: Optional[RegistrationSimulateReferenceType] = None,
+    registration_simulate_elastic_sigma: float = 3.0,
+    registration_simulate_elastic_points: int = 3,
+    registration_simulate_elastic_rotate: float = 0.0,
+    registration_simulate_elastic_zoom: float = 0.0,
+    registration_estimate_displacement: bool = True,
+    registration_simulate_reference_from_key_index: int = 0,
+    registration_moving_key: TransformKey = TransformKey.TARGET,
+    demons_filter_type: DemonsFilterType = DemonsFilterType.SYMMETRIC_FORCES,
+    demons_num_iterations: int = 100,
+    demons_smooth_displacement_field: bool = True,
+    demons_standard_deviations: float = 1.5,
+    demons_intensity_difference_threshold: Optional[float] = None,
+    demons_maximum_rms_error: Optional[float] = None,
+    use_seed: bool = True,
+) -> DirectTransform:
+    r"""Builds supervised MRI transforms.
+
+    More specifically, the following transformations are applied:
+
+    * Converts input to (complex-valued) tensor.
+    * Applies k-space (center) crop if requested.
+    * Applies k-space rescaling if requested.
+    * Applies k-space padding if requested.
+    * Applies random augmentations (rotation, flip, reverse) if requested.
+    * Adds a sampling mask if `mask_func` is defined.
+    * Compresses the coil dimension if requested.
+    * Pads the coil dimension if requested.
+    * Adds coil sensitivity maps and/or the body coil image.
+    * Masks the fully sampled k-space, if there is a mask function or a mask in the sample.
+    * Computes a scaling factor based on the masked k-space and normalizes data.
+    * Computes a target (image).
+    * Deletes the acs mask and the fully sampled k-space if requested.
+
+    Parameters
+    ----------
+    forward_operator : Callable
+        The forward operator, e.g. some form of FFT (centered or uncentered).
+    backward_operator : Callable
+        The backward operator, e.g. some form of inverse FFT (centered or uncentered).
+    mask_func : Callable or None
+        A function which creates a sampling mask of the appropriate shape.
+    target_acceleration : float, optional
+        Target acceleration factor. Default: None.
+    crop : tuple[int, int] or str, optional
+        If not None, this will transform the "kspace" to the image domain, crop it, and transform it back.
+        If a tuple of integers is given then it will crop the backprojected kspace to that size. If
+        "reconstruction_size" is given, then it will crop the backprojected kspace according to it, but
+        a key "reconstruction_size" must be present in the sample. Default: None.
+    crop_type : str, optional
+        Type of cropping, either "gaussian" or "uniform". This will be ignored if `crop` is None. Default: "uniform".
+    rescale : tuple or list, optional
+        If not None, this will transform the "kspace" to the image domain, rescale it, and transform it back.
+        Must correspond to (height, width). It is not recommended to use this in combination with `crop`.
+        Default: None.
+    rescale_mode : RescaleMode
+        Mode to be used for rescaling. Can be RescaleMode.AREA, RescaleMode.BICUBIC, RescaleMode.BILINEAR,
+        RescaleMode.NEAREST, RescaleMode.NEAREST_EXACT, or RescaleMode.TRILINEAR. Note that not all modes are
+        supported for 2D or 3D data. Default: RescaleMode.NEAREST.
+    rescale_2d_if_3d : bool, optional
+        If True and k-space data is 3D, rescaling will be done only on the height
+        and width dimensions, by combining the slice/time dimension with the batch dimension.
+        This is ignored if `rescale` is None. Default: False.
+    pad : tuple or list, optional
+        If not None, this will zero-pad the "kspace" to the given size. Must correspond to (height, width)
+        or (slice/time, height, width). Default: None.
+    image_center_crop : bool
+        If True the backprojected kspace will be cropped around the center, otherwise randomly.
+        This will be ignored if `crop` is None. Default: True.
+    random_rotation_degrees : Sequence[int], optional
+        Default: (-90, 90).
+    random_rotation_probability : float, optional
+        If greater than 0.0, random rotations of `random_rotation_degrees` degrees will be applied with probability
+        `random_rotation_probability`. Default: 0.0.
+    random_flip_type : RandomFlipType, optional
+        Default: RandomFlipType.RANDOM.
+    random_flip_probability : float, optional
+        If greater than 0.0, random flips of `random_flip_type` type will be applied with probability
+        `random_flip_probability`. Default: 0.0.
+    random_reverse_probability : float
+        If greater than 0.0, will perform random reversion along the time or slice dimension (2) with probability
+        `random_reverse_probability`. Default: 0.0.
+    padding_eps: float
+        Padding epsilon. Default: 0.0001.
+    estimate_body_coil_image : bool
+        Estimate body coil image. Default: False.
+    estimate_sensitivity_maps : bool
+        Estimate sensitivity maps using the acs region. Default: True.
+    sensitivity_maps_type : SensitivityMapType
+        Can be SensitivityMapType.RSS_ESTIMATE, SensitivityMapType.UNIT or SensitivityMapType.ESPIRIT.
+        Will be ignored if `estimate_sensitivity_maps` is False. Default: SensitivityMapType.RSS_ESTIMATE.
+    sensitivity_maps_gaussian : float
+        Optional sigma for gaussian weighting of sensitivity map.
+    sensitivity_maps_espirit_threshold : float, optional
+        Threshold for the calibration matrix when `type_of_map` is set to `SensitivityMapType.ESPIRIT`.
+        Default: 0.05.
+    sensitivity_maps_espirit_kernel_size : int, optional
+        Kernel size for the calibration matrix when `type_of_map` is set to `SensitivityMapType.ESPIRIT`. Default: 6.
+    sensitivity_maps_espirit_crop : float, optional
+        Output eigenvalue cropping threshold when `type_of_map` is set to `SensitivityMapType.ESPIRIT`. Default: 0.95.
+    sensitivity_maps_espirit_max_iters : int, optional
+        Power method iterations when `type_of_map` is set to `SensitivityMapType.ESPIRIT`. Default: 30.
+    use_acs_as_mask : bool
+        If True, will use the acs region as the mask. Default: False.
+    delete_acs : bool
+        If True will delete key `acs_mask`. Default: True.
+    delete_kspace : bool
+        If True will delete key `kspace` (fully sampled k-space). Default: True.
+    image_recon_type : ReconstructionType
+        Type to reconstruct target image. Default: ReconstructionType.RSS.
+    compress_coils : int, optional
+        Number of coils to compress input k-space. It is not recommended to be used in combination with `pad_coils`.
+        Default: None.
+    pad_coils : int, optional
+        Number of coils to pad data to. Default: None.
+    scaling_key : TransformKey
+        Key in sample to scale scalable items in sample. Default: TransformKey.MASKED_KSPACE.
+    scale_percentile : float, optional
+        Data will be rescaled with the given percentile. If None, the division is done by the maximum. Default: 0.99.
+    registration : bool
+        If True, will compute a displacement field between the target and the moving image. Default: False.
+ registration_simulate_reference : RegistrationSimulateReferenceType + If not None, will simulate a reference image for displacement field computation. Otherwise, this expects a key + in the sample. Can be RegistrationSimulateReferenceType.FROM_KEY or RegistrationSimulateReferenceType.ELASTIC. + Default: None. + registration_simulate_elastic_sigma : float + Standard deviation for the elastic simulation. Default: 3.0. + registration_simulate_elastic_points : int + Number of points for the elastic simulation. Default: 3. + registration_simulate_elastic_rotate : float + Rotation for the elastic simulation. Default: 0.0. + registration_estimate_displacement : bool + If True, will estimate the displacement field between the target and the moving image using the + demons algorithm. Default: True + registration_simulate_elastic_zoom : float + Zoom for the elastic simulation. Default: 0.0. + registration_simulate_reference_from_key_index : int + Index to drop from the key to simulate the reference image. Default: 0. + demons_filter_type : DemonsFilterType + Type of filter to apply to the displacement field. Default: DemonsFilterType.SYMMETRIC_FORCES. + demons_num_iterations : int + Number of iterations for the demons algorithm. Default: 100. + demons_smooth_displacement_field : bool + If True, will smooth the displacement field. Default: True. + demons_standard_deviations : float + Standard deviation for the smoothing of the displacement field. Default: 1.5. + demons_intensity_difference_threshold : float, optional + Intensity difference threshold for the demons algorithm. Default: None. + demons_maximum_rms_error : float, optional + Maximum RMS error for the demons algorithm. Default: None. + use_seed : bool + If true, a pseudo-random number based on the filename is computed so that every slice of the volume get + the same mask every time. Default: True. + + Returns + ------- + DirectTransform + An MRI transformation object. 
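For orientation, a minimal sketch of how this builder might be called. The partial-applied `fft2`/`ifft2` operators and `FastMRIRandomMaskFunc` follow the `direct.data.transforms` and `direct.common.subsample` interfaces as I understand them; they are assumptions for illustration, not part of this diff.

import functools

import numpy as np

from direct.common.subsample import FastMRIRandomMaskFunc  # assumed location/signature
from direct.data import transforms as T

# Centered FFT pair as forward/backward operators (assumed interface).
forward_operator = functools.partial(T.fft2, centered=True)
backward_operator = functools.partial(T.ifft2, centered=True)

transform = build_supervised_mri_transforms(
    forward_operator=forward_operator,
    backward_operator=backward_operator,
    mask_func=FastMRIRandomMaskFunc(accelerations=[4], center_fractions=[0.08]),
    crop=(320, 320),
    estimate_sensitivity_maps=True,
)

# Illustrative sample dict; ToTensor expects a complex "kspace" array, here (coil, height, width).
kspace = np.random.randn(8, 640, 372) + 1j * np.random.randn(8, 640, 372)
sample = transform({"kspace": kspace, "filename": "example.h5", "slice_no": 0})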
+ """ + mri_transforms: list[Callable] = [ToTensor()] + if crop: + mri_transforms += [ + CropKspace( + crop=crop, + forward_operator=forward_operator, + backward_operator=backward_operator, + image_space_center_crop=image_center_crop, + random_crop_sampler_type=crop_type, + random_crop_sampler_use_seed=use_seed, + ) + ] + if rescale: + mri_transforms += [ + RescaleKspace( + shape=rescale, + forward_operator=forward_operator, + backward_operator=backward_operator, + rescale_mode=rescale_mode, + rescale_2d_if_3d=rescale_2d_if_3d, + kspace_key=KspaceKey.KSPACE, + ) + ] + if pad: + mri_transforms += [ + PadKspace( + pad_shape=pad, + forward_operator=forward_operator, + backward_operator=backward_operator, + kspace_key=KspaceKey.KSPACE, + ) + ] + if random_rotation_probability > 0.0: + mri_transforms += [ + RandomRotation( + degrees=random_rotation_degrees, + p=random_rotation_probability, + keys_to_rotate=(TransformKey.KSPACE, TransformKey.SENSITIVITY_MAP), + ) + ] + if random_flip_probability > 0.0: + mri_transforms += [ + RandomFlip( + flip=random_flip_type, + p=random_flip_probability, + keys_to_flip=(TransformKey.KSPACE, TransformKey.SENSITIVITY_MAP), + ) + ] + if random_reverse_probability > 0.0: + mri_transforms += [ + RandomReverse( + p=random_reverse_probability, + keys_to_reverse=(TransformKey.KSPACE, TransformKey.SENSITIVITY_MAP), + ) + ] + if padding_eps > 0.0: + mri_transforms += [ + ComputeZeroPadding(KspaceKey.KSPACE, "padding", padding_eps), + ApplyZeroPadding(KspaceKey.KSPACE, "padding"), + ] + if mask_func: + mri_transforms += [ + CreateSamplingMask( + mask_func, + shape=(None if (isinstance(crop, str)) else crop), + use_seed=use_seed, + return_acs=estimate_sensitivity_maps, + ), + ] + if use_acs_as_mask: + mri_transforms += [CopyKeys(keys=[TransformKey.ACS_MASK], new_keys=[TransformKey.SAMPLING_MASK])] + if target_acceleration: + mri_transforms += [AddTargetAcceleration(target_acceleration)] + if compress_coils: + mri_transforms += [CompressCoil(num_coils=compress_coils, kspace_key=KspaceKey.KSPACE)] + if pad_coils: + mri_transforms += [PadCoilDimension(pad_coils=pad_coils, key=KspaceKey.KSPACE)] + + if estimate_body_coil_image and mask_func is not None: + mri_transforms.append(EstimateBodyCoilImage(mask_func, backward_operator=backward_operator, use_seed=use_seed)) + mri_transforms += [ + ApplyMask( + sampling_mask_key=TransformKey.ACS_MASK, + input_kspace_key=KspaceKey.KSPACE, + target_kspace_key=KspaceKey.ACS_KSPACE, + ), + ] + if estimate_sensitivity_maps: + mri_transforms += [ + EstimateSensitivityMap( + kspace_key=KspaceKey.ACS_KSPACE, + backward_operator=backward_operator, + type_of_map=sensitivity_maps_type, + gaussian_sigma=sensitivity_maps_gaussian, + espirit_threshold=sensitivity_maps_espirit_threshold, + espirit_kernel_size=sensitivity_maps_espirit_kernel_size, + espirit_crop=sensitivity_maps_espirit_crop, + espirit_max_iters=sensitivity_maps_espirit_max_iters, + ) + ] + mri_transforms += [ + ApplyMask( + sampling_mask_key=TransformKey.SAMPLING_MASK, + input_kspace_key=KspaceKey.KSPACE, + target_kspace_key=KspaceKey.MASKED_KSPACE, + ), + ] + if registration: + if registration_simulate_reference is not None: + mri_transforms += [ + DropIndex( + keys=[ + TransformKey.KSPACE, + TransformKey.ACS_KSPACE, + TransformKey.MASKED_KSPACE, + TransformKey.ACS_MASK, + TransformKey.SAMPLING_MASK, + TransformKey.PADDING, + TransformKey.SENSITIVITY_MAP, + TransformKey.ACCELERATION, + TransformKey.CENTER_FRACTION, + ], + index=registration_simulate_reference_from_key_index, + 
index_dim=1, + store_deleted_keys=[TransformKey.REFERENCE_KSPACE], + ) + ] + mri_transforms += [ + ComputeScalingFactor( + normalize_key=scaling_key, percentile=scale_percentile, scaling_factor_key=TransformKey.SCALING_FACTOR + ), + Normalize( + scaling_factor_key=TransformKey.SCALING_FACTOR, + keys_to_normalize=[ + KspaceKey.ACS_KSPACE, + KspaceKey.KSPACE, + KspaceKey.MASKED_KSPACE, + KspaceKey.REFERENCE_KSPACE, + ], # Only the keys present in the sample will be normalized + ), + ] + mri_transforms += [ + ComputeImage( + kspace_key=KspaceKey.KSPACE, + target_key=TransformKey.TARGET, + backward_operator=backward_operator, + type_reconstruction=image_recon_type, + ) + ] + if registration: + if registration_simulate_reference is not None: + mri_transforms += [ + ComputeImage( + kspace_key=KspaceKey.REFERENCE_KSPACE, + target_key=TransformKey.REFERENCE_IMAGE, + backward_operator=backward_operator, + type_reconstruction=image_recon_type, + ), + SqueezeKey(keys=[TransformKey.REFERENCE_IMAGE], dim=0), + ] + if registration_simulate_reference == RegistrationSimulateReferenceType.ELASTIC: + mri_transforms += [ + RandomElasticDeformation( + image_key=TransformKey.REFERENCE_IMAGE, + target_key=TransformKey.REFERENCE_IMAGE, + use_seed=use_seed, + sigma=registration_simulate_elastic_sigma, + points=registration_simulate_elastic_points, + rotate=registration_simulate_elastic_rotate, + zoom=registration_simulate_elastic_zoom, + ) + ] + if registration_estimate_displacement: + mri_transforms += [ + Displacement( + transform_type=DisplacementTransformType.MULTISCALE_DEMONS, + demons_filter_type=demons_filter_type, + demons_num_iterations=demons_num_iterations, + demons_smooth_displacement_field=demons_smooth_displacement_field, + demons_standard_deviations=demons_standard_deviations, + demons_intensity_difference_threshold=demons_intensity_difference_threshold, + demons_maximum_rms_error=demons_maximum_rms_error, + reference_image_key=TransformKey.REFERENCE_IMAGE, + moving_image_key=registration_moving_key, + ) + ] + if delete_acs: + mri_transforms += [DeleteKeys(keys=[TransformKey.ACS_MASK, KspaceKey.ACS_KSPACE])] + if delete_kspace: + mri_transforms += [DeleteKeys(keys=[KspaceKey.KSPACE])] + + return Compose(mri_transforms) + + +class TransformsType(DirectEnum): + SUPERVISED = "supervised" + SSL_SSDU = "ssl_ssdu" + + +# pylint: disable=too-many-arguments +def build_mri_transforms( + forward_operator: Callable, + backward_operator: Callable, + mask_func: Optional[Callable], + target_acceleration: Optional[float] = None, + crop: Optional[Union[tuple[int, int], str]] = None, + crop_type: Optional[str] = "uniform", + rescale: Optional[Union[tuple[int, int], list[int]]] = None, + rescale_mode: Optional[RescaleMode] = RescaleMode.NEAREST, + rescale_2d_if_3d: Optional[bool] = False, + pad: Optional[Union[tuple[int, int], list[int]]] = None, + image_center_crop: bool = True, + random_rotation_degrees: Optional[Sequence[int]] = (-90, 90), + random_rotation_probability: float = 0.0, + random_flip_type: Optional[RandomFlipType] = RandomFlipType.RANDOM, + random_flip_probability: float = 0.0, + random_reverse_probability: float = 0.0, + padding_eps: float = 0.0001, + estimate_body_coil_image: bool = False, + estimate_sensitivity_maps: bool = True, + sensitivity_maps_type: SensitivityMapType = SensitivityMapType.RSS_ESTIMATE, + sensitivity_maps_gaussian: Optional[float] = None, + sensitivity_maps_espirit_threshold: Optional[float] = 0.05, + sensitivity_maps_espirit_kernel_size: Optional[int] = 6, +
sensitivity_maps_espirit_crop: Optional[float] = 0.95, + sensitivity_maps_espirit_max_iters: Optional[int] = 30, + use_acs_as_mask: bool = False, + delete_acs: bool = True, + delete_kspace: bool = True, + image_recon_type: ReconstructionType = ReconstructionType.RSS, + compress_coils: Optional[int] = None, + pad_coils: Optional[int] = None, + scaling_key: TransformKey = TransformKey.MASKED_KSPACE, + scale_percentile: Optional[float] = 0.99, + registration: bool = False, + registration_simulate_reference: Optional[RegistrationSimulateReferenceType] = None, + registration_simulate_elastic_sigma: float = 3.0, + registration_simulate_elastic_points: int = 3, + registration_simulate_elastic_rotate: float = 0.0, + registration_simulate_elastic_zoom: float = 0.0, + registration_estimate_displacement: bool = True, + registration_simulate_reference_from_key_index: int = 0, + registration_moving_key: TransformKey = TransformKey.TARGET, + demons_filter_type: DemonsFilterType = DemonsFilterType.SYMMETRIC_FORCES, + demons_num_iterations: int = 100, + demons_smooth_displacement_field: bool = True, + demons_standard_deviations: float = 1.5, + demons_intensity_difference_threshold: Optional[float] = None, + demons_maximum_rms_error: Optional[float] = None, + use_seed: bool = True, + transforms_type: Optional[TransformsType] = TransformsType.SUPERVISED, + mask_split_ratio: Union[float, list[float], tuple[float, ...]] = 0.4, + mask_split_acs_region: Union[list[int], tuple[int, int]] = (0, 0), + mask_split_keep_acs: Optional[bool] = False, + mask_split_type: MaskSplitterType = MaskSplitterType.GAUSSIAN, + mask_split_gaussian_std: float = 3.0, + mask_split_half_direction: HalfSplitType = HalfSplitType.VERTICAL, +) -> DirectTransform: + r"""Builds transforms for MRI. + + More specifically, the following transformations are applied: + + * Converts input to (complex-valued) tensor. + * Applies k-space (center) crop if requested. + * Applies k-space rescaling if requested. + * Applies k-space padding if requested. + * Applies random augmentations (rotation, flip, reverse) if requested. + * Adds a sampling mask if `mask_func` is defined. + * Compresses the coil dimension if requested. + * Pads the coil dimension if requested. + * Adds coil sensitivities and / or the body coil image. + * Masks the fully sampled k-space, if there is a mask function or a mask in the sample. + * Computes a scaling factor based on the masked k-space and normalizes data. + * Computes a target (image). + * Deletes the acs mask and the fully sampled k-space if requested. + * Splits the mask if requested for self-supervised learning. + + Parameters + ---------- + forward_operator : Callable + The forward operator, e.g. some form of FFT (centered or uncentered). + backward_operator : Callable + The backward operator, e.g. some form of inverse FFT (centered or uncentered). + mask_func : Callable or None + A function which creates a sampling mask of the appropriate shape. + target_acceleration : float, optional + Target acceleration factor. Default: None. + crop : tuple[int, int] or str, optional + If not None, this will transform the "kspace" to the image domain, crop it, and transform it back. + If a tuple of integers is given then it will crop the backprojected kspace to that size. If + "reconstruction_size" is given, then it will crop the backprojected kspace according to it, but + a key "reconstruction_size" must be present in the sample. Default: None. + crop_type : Optional[str] + Type of cropping, either "gaussian" or "uniform".
This will be ignored if `crop` is None. Default: "uniform". + rescale : tuple or list, optional + If not None, this will transform the "kspace" to the image domain, rescale it, and transform it back. + Must correspond to (height, width). It is not recommended to be used in combination with `crop`. + Default: None. + rescale_mode : RescaleMode + Mode to be used for rescaling. Can be RescaleMode.AREA, RescaleMode.BICUBIC, RescaleMode.BILINEAR, + RescaleMode.NEAREST, RescaleMode.NEAREST_EXACT, or RescaleMode.TRILINEAR. Note that not all modes are + supported for 2D or 3D data. Default: RescaleMode.NEAREST. + rescale_2d_if_3d : bool, optional + If True and k-space data is 3D, rescaling will be done only on the height + and width dimensions, by combining the slice/time dimension with the batch dimension. + This is ignored if `rescale` is None. Default: False. + pad : tuple or list, optional + If not None, this will zero-pad the "kspace" to the given size. Must correspond to (height, width) + or (slice/time, height, width). Default: None. + image_center_crop : bool + If True, the backprojected kspace will be cropped around the center, otherwise randomly. + This will be ignored if `crop` is None. Default: True. + random_rotation_degrees : Sequence[int], optional + Default: (-90, 90). + random_rotation_probability : float, optional + If greater than 0.0, random rotations of `random_rotation_degrees` degrees will be applied, with probability + `random_rotation_probability`. Default: 0.0. + random_flip_type : RandomFlipType, optional + Default: RandomFlipType.RANDOM. + random_flip_probability : float, optional + If greater than 0.0, random flips of `random_flip_type` type will be applied, with probability + `random_flip_probability`. Default: 0.0. + random_reverse_probability : float + If greater than 0.0, will perform random reversion along the time or slice dimension (2) with probability + `random_reverse_probability`. Default: 0.0. + padding_eps: float + Padding epsilon. Default: 0.0001. + estimate_body_coil_image : bool + Estimate body coil image. Default: False. + estimate_sensitivity_maps : bool + Estimate sensitivity maps using the acs region. Default: True. + sensitivity_maps_type : SensitivityMapType + Can be SensitivityMapType.RSS_ESTIMATE, SensitivityMapType.UNIT or SensitivityMapType.ESPIRIT. + Will be ignored if `estimate_sensitivity_maps` is False. Default: SensitivityMapType.RSS_ESTIMATE. + sensitivity_maps_gaussian : float + Optional sigma for gaussian weighting of sensitivity map. + sensitivity_maps_espirit_threshold : float, optional + Threshold for the calibration matrix when `type_of_map` is set to `SensitivityMapType.ESPIRIT`. + Default: 0.05. + sensitivity_maps_espirit_kernel_size : int, optional + Kernel size for the calibration matrix when `type_of_map` is set to `SensitivityMapType.ESPIRIT`. Default: 6. + sensitivity_maps_espirit_crop : float, optional + Output eigenvalue cropping threshold when `type_of_map` is set to `SensitivityMapType.ESPIRIT`. Default: 0.95. + sensitivity_maps_espirit_max_iters : int, optional + Power method iterations when `type_of_map` is set to `SensitivityMapType.ESPIRIT`. Default: 30. + use_acs_as_mask : bool + If True, will use the acs region as the mask. Default: False. + delete_acs : bool + If True, will delete key `acs_mask`. Default: True. + delete_kspace : bool + If True, will delete key `kspace` (fully sampled k-space). Default: True. + image_recon_type : ReconstructionType + Type to reconstruct target image.
Default: ReconstructionType.RSS. + compress_coils : int, optional + Number of coils to compress input k-space. It is not recommended to be used in combination with `pad_coils`. + Default: None. + pad_coils : int, optional + Number of coils to pad data to. Default: None. + scaling_key : TransformKey + Key in sample to scale scalable items in sample. Default: TransformKey.MASKED_KSPACE. + scale_percentile : float, optional + Data will be rescaled with the given percentile. If None, the division is done by the maximum. Default: 0.99. + registration : bool + If True, will compute a displacement field between the target and the moving image. Default: False. + registration_simulate_reference : RegistrationSimulateReferenceType + If not None, will simulate a reference image for displacement field computation. Otherwise, this expects a key + in the sample. Can be RegistrationSimulateReferenceType.FROM_KEY or RegistrationSimulateReferenceType.ELASTIC. + Default: None. + registration_simulate_elastic_sigma : float + Standard deviation for the elastic simulation. Default: 3.0. + registration_simulate_elastic_points : int + Number of points for the elastic simulation. Default: 3. + registration_simulate_elastic_rotate : float + Rotation for the elastic simulation. Default: 0.0. + registration_simulate_elastic_zoom : float + Zoom for the elastic simulation. Default: 0.0. + registration_estimate_displacement : bool + If True, will estimate the displacement field between the target and the moving image using the + demons algorithm. Default: True. + registration_simulate_reference_from_key_index : int + Index to drop from the keys to simulate the reference image. Default: 0. + registration_moving_key : TransformKey + Key in sample to compute displacement field from. Default: TransformKey.TARGET. + demons_filter_type : DemonsFilterType + Type of filter to apply to the displacement field. Default: DemonsFilterType.SYMMETRIC_FORCES. + demons_num_iterations : int + Number of iterations for the demons algorithm. Default: 100. + demons_smooth_displacement_field : bool + If True, will smooth the displacement field. Default: True. + demons_standard_deviations : float + Standard deviation for the smoothing of the displacement field. Default: 1.5. + demons_intensity_difference_threshold : float, optional + Intensity difference threshold for the demons algorithm. Default: None. + demons_maximum_rms_error : float, optional + Maximum RMS error for the demons algorithm. Default: None. + use_seed : bool + If True, a pseudo-random number based on the filename is computed so that every slice of the volume gets + the same mask every time. Default: True. + transforms_type : TransformsType, optional + Can be `TransformsType.SUPERVISED` for supervised learning transforms or `TransformsType.SSL_SSDU` for + self-supervised learning transforms. Default: `TransformsType.SUPERVISED`. + mask_split_ratio : Union[float, list[float], tuple[float, ...]] + The ratio(s) of the sampling mask splitting. If `transforms_type` is TransformsType.SUPERVISED, this is + ignored. Default: 0.4. + mask_split_acs_region : Union[list[int], tuple[int, int]] + A rectangle for the acs region that will be used in the input mask. This applies only if `transforms_type` is + set to TransformsType.SSL_SSDU. Default: (0, 0). + mask_split_keep_acs : Optional[bool] + If True, the acs region according to the "acs_mask" of the sample will be used in both mask splits. + This applies only if `transforms_type` is set to TransformsType.SSL_SSDU. Default: False.
+ mask_split_type : MaskSplitterType + How the sampling mask will be split. Can be MaskSplitterType.UNIFORM, MaskSplitterType.GAUSSIAN, or + MaskSplitterType.HALF. This applies only if `transforms_type` is + set to TransformsType.SSL_SSDU. Default: MaskSplitterType.GAUSSIAN. + mask_split_gaussian_std : float + Standard deviation of gaussian mask splitting. This applies only if `transforms_type` is + set to TransformsType.SSL_SSDU. Ignored if `mask_split_type` is not set to MaskSplitterType.GAUSSIAN. + Default: 3.0. + mask_split_half_direction : HalfSplitType + Split type if `mask_split_type` is `MaskSplitterType.HALF`. Can be `HalfSplitType.VERTICAL`, + `HalfSplitType.HORIZONTAL`, `HalfSplitType.DIAGONAL_LEFT` or `HalfSplitType.DIAGONAL_RIGHT`. + This applies only if `transforms_type` is set to `TransformsType.SSL_SSDU`. Ignored if `mask_split_type` is not + set to `MaskSplitterType.HALF`. Default: `HalfSplitType.VERTICAL`. + + Returns + ------- + DirectTransform + An MRI transformation object. + """ + logger = logging.getLogger(build_mri_transforms.__name__) + logger.info("Creating %s MRI transforms.", transforms_type) + + if crop and rescale: + logger.warning( + "Rescale and crop are both given. Rescale will be applied after cropping. This is not recommended." + ) + + if compress_coils and pad_coils: + logger.warning( + "Compress coils and pad coils are both given. Compress coils will be applied before padding. " + "This is not recommended." + ) + + mri_transforms = build_supervised_mri_transforms( + forward_operator=forward_operator, + backward_operator=backward_operator, + mask_func=mask_func, + target_acceleration=target_acceleration, + crop=crop, + crop_type=crop_type, + rescale=rescale, + rescale_mode=rescale_mode, + rescale_2d_if_3d=rescale_2d_if_3d, + pad=pad, + image_center_crop=image_center_crop, + random_rotation_degrees=random_rotation_degrees, + random_rotation_probability=random_rotation_probability, + random_flip_type=random_flip_type, + random_flip_probability=random_flip_probability, + random_reverse_probability=random_reverse_probability, + padding_eps=padding_eps, + estimate_sensitivity_maps=estimate_sensitivity_maps, + sensitivity_maps_type=sensitivity_maps_type, + estimate_body_coil_image=estimate_body_coil_image, + sensitivity_maps_gaussian=sensitivity_maps_gaussian, + sensitivity_maps_espirit_threshold=sensitivity_maps_espirit_threshold, + sensitivity_maps_espirit_kernel_size=sensitivity_maps_espirit_kernel_size, + sensitivity_maps_espirit_crop=sensitivity_maps_espirit_crop, + sensitivity_maps_espirit_max_iters=sensitivity_maps_espirit_max_iters, + use_acs_as_mask=use_acs_as_mask, + delete_acs=delete_acs if transforms_type == TransformsType.SUPERVISED else False, + delete_kspace=delete_kspace if transforms_type == TransformsType.SUPERVISED else False, + image_recon_type=image_recon_type, + compress_coils=compress_coils, + pad_coils=pad_coils, + scaling_key=scaling_key, + scale_percentile=scale_percentile, + registration=registration, + registration_simulate_reference=registration_simulate_reference, + registration_simulate_elastic_sigma=registration_simulate_elastic_sigma, + registration_simulate_elastic_points=registration_simulate_elastic_points, + registration_simulate_elastic_rotate=registration_simulate_elastic_rotate, + registration_simulate_elastic_zoom=registration_simulate_elastic_zoom, + registration_estimate_displacement=registration_estimate_displacement, +
registration_simulate_reference_from_key_index=registration_simulate_reference_from_key_index, + registration_moving_key=registration_moving_key, + demons_filter_type=demons_filter_type, + demons_num_iterations=demons_num_iterations, + demons_smooth_displacement_field=demons_smooth_displacement_field, + demons_standard_deviations=demons_standard_deviations, + demons_intensity_difference_threshold=demons_intensity_difference_threshold, + demons_maximum_rms_error=demons_maximum_rms_error, + use_seed=use_seed, + ).transforms + + mri_transforms += [AddBooleanKeysModule(["is_ssl"], [transforms_type != TransformsType.SUPERVISED])] + + if transforms_type == TransformsType.SUPERVISED: + return Compose(mri_transforms) + + mask_splitter_kwargs = { + "ratio": mask_split_ratio, + "acs_region": mask_split_acs_region, + "keep_acs": mask_split_keep_acs, + "use_seed": use_seed, + "kspace_key": KspaceKey.MASKED_KSPACE, + } + mri_transforms += [ + ( + GaussianMaskSplitter(**mask_splitter_kwargs, std_scale=mask_split_gaussian_std) + if mask_split_type == MaskSplitterType.GAUSSIAN + else ( + UniformMaskSplitter(**mask_splitter_kwargs) + if mask_split_type == MaskSplitterType.UNIFORM + else HalfMaskSplitterModule( + **{k: v for k, v in mask_splitter_kwargs.items() if k != "ratio"}, + direction=mask_split_half_direction, + ) + ) + ), + DeleteKeys([TransformKey.ACS_MASK]), + ] + + mri_transforms += [ + RenameKeys( + [ + SSLTransformMaskPrefixes.INPUT_ + TransformKey.MASKED_KSPACE, + SSLTransformMaskPrefixes.TARGET_ + TransformKey.MASKED_KSPACE, + ], + ["input_kspace", "kspace"], + ), + DeleteKeys(["masked_kspace", "sampling_mask"]), + ] # Rename keys for SSL engine + + mri_transforms += [ + ComputeImage( + kspace_key=KspaceKey.KSPACE, + target_key=TransformKey.TARGET, + backward_operator=backward_operator, + type_reconstruction=image_recon_type, + ) + ] + + return Compose(mri_transforms) diff --git a/direct/nn/mri_models.py b/direct/nn/mri_models.py index 26ec0286..e984f526 100644 --- a/direct/nn/mri_models.py +++ b/direct/nn/mri_models.py @@ -210,7 +210,7 @@ def _do_iteration( return DoIterationOutput( output_image=( - (output_image, registered_image) + (output_image, registered_image, displacement_field) if (self.ndim == 3 and "registration_model" in self.models) else output_image ), @@ -998,7 +998,7 @@ def reconstruct_volumes( # type: ignore iteration_output = self._do_iteration(data, loss_fns=loss_fns, regularizer_fns=regularizer_fns) output = iteration_output.output_image if "registration_model" in self.models: - output, registered_output = output + output, registered_output, displacement_field = output sampling_mask = iteration_output.sampling_mask if sampling_mask is not None: @@ -1012,6 +1012,7 @@ def reconstruct_volumes( # type: ignore resolution=resolution, complex_axis=self._complex_dim, ) + if "registration_model" in self.models: registered_output_abs = _process_output( registered_output, @@ -1019,6 +1020,10 @@ def reconstruct_volumes( # type: ignore resolution=resolution, complex_axis=self._complex_dim, ) + if resolution is not None: + output_df = T.center_crop(displacement_field, resolution) + else: + output_df = displacement_field if add_target: target_abs = _process_output( @@ -1044,6 +1049,7 @@ def reconstruct_volumes( # type: ignore curr_registration_volume = torch.zeros( *(volume_size, *registered_output_abs.shape[1:]), dtype=registered_output_abs.dtype ) + curr_df_volume = torch.zeros(*(volume_size, *output_df.shape[1:]), dtype=output_df.dtype) curr_mask = ( torch.zeros(*(volume_size, 
*sampling_mask.shape[1:]), dtype=sampling_mask.dtype) @@ -1061,6 +1067,8 @@ def reconstruct_volumes( # type: ignore curr_registration_volume[instance_counter : instance_counter + output_abs.shape[0], ...] = ( registered_output_abs.cpu() ) + curr_df_volume[instance_counter : instance_counter + output_abs.shape[0], ...] = output_df.cpu() + if sampling_mask is not None: curr_mask[instance_counter : instance_counter + output_abs.shape[0], ...] = sampling_mask.cpu() if add_target: @@ -1088,7 +1096,7 @@ def reconstruct_volumes( # type: ignore del data if "registration_model" in self.models: - curr_volume = (curr_volume, curr_registration_volume) + curr_volume = (curr_volume, curr_registration_volume, curr_df_volume) if add_target and "registration_model" in self.models: curr_target = (curr_target, curr_registration_target) @@ -1123,9 +1131,10 @@ def reconstruct_and_evaluate( # type: ignore ): volume, target, mask, volume_loss_dict, filename = output if isinstance(volume, tuple): - volume, registration_volume = volume + volume, registration_volume, displacement_field = volume else: registration_volume = None + displacement_field = None if isinstance(target, tuple): target, registration_target = target else: @@ -1176,7 +1185,11 @@ def reconstruct_and_evaluate( # type: ignore inf_losses.append(volume_loss_dict) out.append( - (volume if registration_volume is None else (volume, registration_volume), mask, filename), + ( + volume if registration_volume is None else (volume, registration_volume, displacement_field), + mask, + filename, + ), ) # Average loss dict @@ -1235,7 +1248,7 @@ def evaluate( # type: ignore ): volume, target, mask, volume_loss_dict, filename = output if isinstance(volume, tuple): - volume, registration_volume = volume + volume, registration_volume, _ = volume else: registration_volume = None if isinstance(target, tuple): diff --git a/direct/nn/registration/config.py b/direct/nn/registration/config.py index ae9f334c..36395a68 100644 --- a/direct/nn/registration/config.py +++ b/direct/nn/registration/config.py @@ -13,6 +13,8 @@ class RegistrationModelConfig(ModelConfig): warp_num_integration_steps: int = 1 train_end_to_end: bool = False + # TODO: Needs to be defined outside of the config + reg_loss_factor: float = 1.0 # Regularization loss weight factor @dataclass diff --git a/direct/nn/vsharp/config.py b/direct/nn/vsharp/config.py index 58e37d54..b18cbd2f 100644 --- a/direct/nn/vsharp/config.py +++ b/direct/nn/vsharp/config.py @@ -1,53 +1,56 @@ -# Copyright (c) DIRECT Contributors - -from __future__ import annotations - -from dataclasses import dataclass - -from direct.config.defaults import ModelConfig -from direct.nn.types import ActivationType, InitType, ModelName - - -@dataclass -class VSharpNetConfig(ModelConfig): - num_steps: int = 10 - num_steps_dc_gd: int = 8 - image_init: InitType = InitType.SENSE - no_parameter_sharing: bool = True - auxiliary_steps: int = 0 - image_model_architecture: ModelName = ModelName.UNET - initializer_channels: tuple[int, ...] = (32, 32, 64, 64) - initializer_dilations: tuple[int, ...] 
= (1, 1, 2, 4) - initializer_multiscale: int = 1 - initializer_activation: ActivationType = ActivationType.PRELU - image_resnet_hidden_channels: int = 128 - image_resnet_num_blocks: int = 15 - image_resnet_batchnorm: bool = True - image_resnet_scale: float = 0.1 - image_unet_num_filters: int = 32 - image_unet_num_pool_layers: int = 4 - image_unet_dropout: float = 0.0 - image_didn_hidden_channels: int = 16 - image_didn_num_dubs: int = 6 - image_didn_num_convs_recon: int = 9 - image_conv_hidden_channels: int = 64 - image_conv_n_convs: int = 15 - image_conv_activation: str = ActivationType.RELU - image_conv_batchnorm: bool = False - - -@dataclass -class VSharpNet3DConfig(ModelConfig): - num_steps: int = 8 - num_steps_dc_gd: int = 6 - image_init: InitType = InitType.SENSE - no_parameter_sharing: bool = True - auxiliary_steps: int = -1 - initializer_channels: tuple[int, ...] = (32, 32, 64, 64) - initializer_dilations: tuple[int, ...] = (1, 1, 2, 4) - initializer_multiscale: int = 1 - initializer_activation: ActivationType = ActivationType.PRELU - unet_num_filters: int = 32 - unet_num_pool_layers: int = 4 - unet_dropout: float = 0.0 - unet_norm: bool = False +# Copyright (c) DIRECT Contributors + +from __future__ import annotations + +from dataclasses import dataclass + +from direct.config.defaults import ModelConfig +from direct.nn.types import ActivationType, InitType, ModelName +from direct.nn.vsharp.vsharp import LagrangeMultipliersInitialization + + +@dataclass +class VSharpNetConfig(ModelConfig): + num_steps: int = 10 + num_steps_dc_gd: int = 8 + image_init: InitType = InitType.SENSE + no_parameter_sharing: bool = True + auxiliary_steps: int = 0 + lagrange_initialization: LagrangeMultipliersInitialization = LagrangeMultipliersInitialization.LEARNED + image_model_architecture: ModelName = ModelName.UNET + initializer_channels: tuple[int, ...] = (32, 32, 64, 64) + initializer_dilations: tuple[int, ...] = (1, 1, 2, 4) + initializer_multiscale: int = 1 + initializer_activation: ActivationType = ActivationType.PRELU + image_resnet_hidden_channels: int = 128 + image_resnet_num_blocks: int = 15 + image_resnet_batchnorm: bool = True + image_resnet_scale: float = 0.1 + image_unet_num_filters: int = 32 + image_unet_num_pool_layers: int = 4 + image_unet_dropout: float = 0.0 + image_didn_hidden_channels: int = 16 + image_didn_num_dubs: int = 6 + image_didn_num_convs_recon: int = 9 + image_conv_hidden_channels: int = 64 + image_conv_n_convs: int = 15 + image_conv_activation: str = ActivationType.RELU + image_conv_batchnorm: bool = False + + +@dataclass +class VSharpNet3DConfig(ModelConfig): + num_steps: int = 8 + num_steps_dc_gd: int = 6 + image_init: InitType = InitType.SENSE + no_parameter_sharing: bool = True + auxiliary_steps: int = -1 + lagrange_initialization: LagrangeMultipliersInitialization = LagrangeMultipliersInitialization.LEARNED + initializer_channels: tuple[int, ...] = (32, 32, 64, 64) + initializer_dilations: tuple[int, ...] = (1, 1, 2, 4) + initializer_multiscale: int = 1 + initializer_activation: ActivationType = ActivationType.PRELU + unet_num_filters: int = 32 + unet_num_pool_layers: int = 4 + unet_dropout: float = 0.0 + unet_norm: bool = False diff --git a/direct/nn/vsharp/vsharp.py b/direct/nn/vsharp/vsharp.py index fb417f99..1f92b5bb 100644 --- a/direct/nn/vsharp/vsharp.py +++ b/direct/nn/vsharp/vsharp.py @@ -1,612 +1,640 @@ -# Copyright (c) DIRECT Contributors - -"""This module provides the implementation of vSHARP model. 
- -Most specifically, vSHARP is the variable Splitting Half-quadratic ADMM algorithm for Reconstruction -of inverse-Problems (vSHARPP) model as presented in [1]_. - - -References ----------- - -.. [1] George Yiasemis et. al. vSHARP: variable Splitting Half-quadratic ADMM algorithm for Reconstruction - of inverse-Problems (2023). https://arxiv.org/abs/2309.09954. - -""" - - -from __future__ import annotations - -from typing import Any, Callable - -import numpy as np -import torch -import torch.nn.functional as F -from torch import nn - -from direct.constants import COMPLEX_SIZE -from direct.data.transforms import apply_mask, expand_operator, reduce_operator -from direct.nn.get_nn_model_config import ModelName, _get_model_config, _get_relu_activation -from direct.nn.types import ActivationType, InitType -from direct.nn.unet.unet_3d import NormUnetModel3d, UnetModel3d - - -class LagrangeMultipliersInitializer(nn.Module): - """A convolutional neural network model that initializers the Lagrange multiplier of the :class:`vSHARPNet` [1]_. - - More specifically, it produces an initial value for the Lagrange Multiplier based on the zero-filled image: - - .. math:: - - u^0 = \mathcal{G}_{\psi}(x^0). - - References - ---------- - .. [1] George Yiasemis et al., "VSHARP: Variable Splitting Half-quadratic ADMM Algorithm for Reconstruction - of Inverse Problems" (2023). https://arxiv.org/abs/2309.09954. - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - channels: tuple[int, ...], - dilations: tuple[int, ...], - multiscale_depth: int = 1, - activation: ActivationType = ActivationType.PRELU, - ) -> None: - """Inits :class:`LagrangeMultipliersInitializer`. - - Parameters - ---------- - in_channels : int - Number of input channels. - out_channels : int - Number of output channels. - channels : tuple of ints - Tuple of integers specifying the number of output channels for each convolutional layer in the network. - dilations : tuple of ints - Tuple of integers specifying the dilation factor for each convolutional layer in the network. - multiscale_depth : int - Number of multiscale features to include in the output. Default: 1. - activation : ActivationType - Activation function to use on the output. Default: ActivationType.PRELU. - """ - super().__init__() - - # Define convolutional blocks - self.conv_blocks = nn.ModuleList() - tch = in_channels - for curr_channels, curr_dilations in zip(channels, dilations): - block = nn.Sequential( - nn.ReplicationPad2d(curr_dilations), - nn.Conv2d(tch, curr_channels, 3, padding=0, dilation=curr_dilations), - ) - tch = curr_channels - self.conv_blocks.append(block) - - # Define output block - tch = np.sum(channels[-multiscale_depth:]) - block = nn.Conv2d(tch, out_channels, 1, padding=0) - self.out_block = nn.Sequential(block) - - self.multiscale_depth = multiscale_depth - - self.activation = _get_relu_activation(activation) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Forward pass of :class:`LagrangeMultipliersInitializer`. - - Parameters - ---------- - x : torch.Tensor - Input tensor of shape (batch_size, in_channels, height, width). - - Returns - ------- - torch.Tensor - Output tensor of shape (batch_size, out_channels, height, width). 
- """ - - features = [] - for block in self.conv_blocks: - x = F.relu(block(x), inplace=True) - if self.multiscale_depth > 1: - features.append(x) - - if self.multiscale_depth > 1: - x = torch.cat(features[-self.multiscale_depth :], dim=1) - - return self.activation(self.out_block(x)) - - -class VSharpNet(nn.Module): - """ - Variable Splitting Half-quadratic ADMM algorithm for Reconstruction of Parallel MRI [1]_. - - Variable Splitting Half Quadratic VSharpNet is a deep learning model that solves - the augmented Lagrangian derivation of the variable half quadratic splitting problem - using ADMM (Alternating Direction Method of Multipliers). It is specifically designed - for solving inverse problems in magnetic resonance imaging (MRI). - - The VSharpNet model incorporates an iterative optimization algorithm that consists of - three steps: z-step, x-step, and u-step. These steps are detailed mathematically as follows: - - .. math:: - - z^{t+1} = \mathrm{argmin}_{z} \\lambda \mathcal{G}(z) + \\frac{\\rho}{2} || x^{t} - z + - \\frac{u^t}{\\rho} ||_2^2 \\quad \mathrm{[z-step]} - - .. math:: - - x^{t+1} = \mathrm{argmin}_{x} \\frac{1}{2} || \mathcal{A}_{\mathbf{U},\mathbf{S}}(x) - \\tilde{y} ||_2^2 + - \\frac{\\rho}{2} || x - z^{t+1} + \\frac{u^t}{\\rho} ||_2^2 \\quad \mathrm{[x-step]} - - .. math:: - - u^{t+1} = u^t + \\rho (x^{t+1} - z^{t+1}) \\quad \mathrm{[u-step]} - - During the z-step, the model minimizes the augmented Lagrangian function with respect to z, utilizing - DL-based denoisers. In the x-step, it optimizes x by minimizing the data consistency term through - unrolling a gradient descent scheme (DC-GD). The u-step involves updating the Lagrange multiplier u. - These steps are iterated for a specified number of cycles. - - The model includes an initializer for Lagrange multipliers. - - It also allows for outputting auxiliary steps. - - :class:`VSharpNet` is tailored for 2D MRI data reconstruction. - - References - ---------- - .. [1] George Yiasemis et al., "VSHARP: Variable Splitting Half-quadratic ADMM Algorithm for Reconstruction - of Inverse Problems" (2023). https://arxiv.org/abs/2309.09954. - """ - - def __init__( - self, - forward_operator: Callable[[tuple[Any, ...]], torch.Tensor], - backward_operator: Callable[[tuple[Any, ...]], torch.Tensor], - num_steps: int, - num_steps_dc_gd: int, - image_init: InitType = InitType.SENSE, - no_parameter_sharing: bool = True, - image_model_architecture: ModelName = ModelName.UNET, - initializer_channels: tuple[int, ...] = (32, 32, 64, 64), - initializer_dilations: tuple[int, ...] = (1, 1, 2, 4), - initializer_multiscale: int = 1, - initializer_activation: ActivationType = ActivationType.PRELU, - auxiliary_steps: int = 0, - **kwargs, - ) -> None: - """Inits :class:`VSharpNet`. - - Parameters - ---------- - forward_operator : Callable[[tuple[Any, ...]], torch.Tensor] - Forward operator function. - backward_operator : Callable[[tuple[Any, ...]], torch.Tensor] - Backward operator function. - num_steps : int - Number of steps in the ADMM algorithm. - num_steps_dc_gd : int - Number of steps in the Data Consistency using Gradient Descent step of ADMM. - image_init : str - Image initialization method. Default: 'sense'. - no_parameter_sharing : bool - Flag indicating whether parameter sharing is enabled in the denoiser blocks. - image_model_architecture : ModelName - Image model architecture. Default: ModelName.UNET. - initializer_channels : tuple[int, ...] 
- Tuple of integers specifying the number of output channels for each convolutional layer in the - Lagrange multiplier initializer. Default: (32, 32, 64, 64). - initializer_dilations : tuple[int, ...] - Tuple of integers specifying the dilation factor for each convolutional layer in the Lagrange multiplier - initializer. Default: (1, 1, 2, 4). - initializer_multiscale : int - Number of multiscale features to include in the Lagrange multiplier initializer output. Default: 1. - initializer_activation : ActivationType - Activation type for the Lagrange multiplier initializer. Default: ActivationType.PRELU. - auxiliary_steps : int - Number of auxiliary steps to output. Can be -1 or a positive integer lower or equal to `num_steps`. - If -1, it uses all steps. If I, the last I steps will be used. - **kwargs: Additional keyword arguments. - Can be `model_name` or `image_model_` where `` represent parameters of the selected - image model architecture beyond the standard parameters. - Depending on the `image_model_architecture` chosen, different kwargs will be applicable. - """ - # pylint: disable=too-many-locals - super().__init__() - for extra_key in kwargs: - if extra_key != "model_name" and not extra_key.startswith("image_"): - raise ValueError(f"{type(self).__name__} got key `{extra_key}` which is not supported.") - self.num_steps = num_steps - self.num_steps_dc_gd = num_steps_dc_gd - - self.no_parameter_sharing = no_parameter_sharing - - image_model, image_model_kwargs = _get_model_config( - image_model_architecture, - in_channels=COMPLEX_SIZE * 3, - out_channels=COMPLEX_SIZE, - **{k.replace("image_", ""): v for (k, v) in kwargs.items() if "image_" in k}, - ) - - self.denoiser_blocks = nn.ModuleList() - for _ in range(num_steps if self.no_parameter_sharing else 1): - self.denoiser_blocks.append(image_model(**image_model_kwargs)) - - self.initializer = LagrangeMultipliersInitializer( - in_channels=COMPLEX_SIZE, - out_channels=COMPLEX_SIZE, - channels=initializer_channels, - dilations=initializer_dilations, - multiscale_depth=initializer_multiscale, - activation=initializer_activation, - ) - - self.learning_rate_eta = nn.Parameter(torch.ones(num_steps_dc_gd, requires_grad=True)) - nn.init.trunc_normal_(self.learning_rate_eta, 0.0, 1.0, 0.0) - - self.rho = nn.Parameter(torch.ones(num_steps, requires_grad=True)) - nn.init.trunc_normal_(self.rho, 0, 0.1, 0.0) - - self.forward_operator = forward_operator - self.backward_operator = backward_operator - - if image_init not in [InitType.SENSE, InitType.ZERO_FILLED]: - raise ValueError( - f"Unknown image_initialization. Expected `InitType.SENSE` or `InitType.ZERO_FILLED`. " - f"Got {image_init}." - ) - - self.image_init = image_init - - if not ((auxiliary_steps == -1) or (0 < auxiliary_steps <= num_steps)): - raise ValueError( - f"Number of auxiliary steps should be -1 to use all steps or a positive" - f" integer <= than `num_steps`. Received {auxiliary_steps}." - ) - if auxiliary_steps == -1: - self.auxiliary_steps = list(range(num_steps)) - else: - self.auxiliary_steps = list(range(num_steps - min(auxiliary_steps, num_steps), num_steps)) - - self._coil_dim = 1 - self._complex_dim = -1 - self._spatial_dims = (2, 3) - - def forward( - self, - masked_kspace: torch.Tensor, - sensitivity_map: torch.Tensor, - sampling_mask: torch.Tensor, - ) -> list[torch.Tensor]: - """Computes forward pass of :class:`VSharpNet`. - - Parameters - ---------- - masked_kspace: torch.Tensor - Masked k-space of shape (N, coil, height, width, complex=2). 
- sensitivity_map: torch.Tensor - Sensitivity map of shape (N, coil, height, width, complex=2). - sampling_mask: torch.Tensor - Sampling mask of shape (N, 1, height, width, 1). - - Returns - ------- - out : list of torch.Tensors - List of output images of shape (N, height, width, complex=2). - """ - out = [] - if self.image_init == InitType.SENSE: - x = reduce_operator( - coil_data=self.backward_operator(masked_kspace, dim=self._spatial_dims), - sensitivity_map=sensitivity_map, - dim=self._coil_dim, - ) - else: - x = self.backward_operator(masked_kspace, dim=self._spatial_dims).sum(self._coil_dim) - - z = x.clone() - - u = self.initializer(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) - - for admm_step in range(self.num_steps): - z = self.denoiser_blocks[admm_step if self.no_parameter_sharing else 0]( - torch.cat( - [z, x, u / self.rho[admm_step]], - dim=self._complex_dim, - ).permute(0, 3, 1, 2) - ).permute(0, 2, 3, 1) - - for dc_gd_step in range(self.num_steps_dc_gd): - dc = apply_mask( - self.forward_operator(expand_operator(x, sensitivity_map, self._coil_dim), dim=self._spatial_dims) - - masked_kspace, - sampling_mask, - return_mask=False, - ) - dc = self.backward_operator(dc, dim=self._spatial_dims) - dc = reduce_operator(dc, sensitivity_map, self._coil_dim) - - x = x - self.learning_rate_eta[dc_gd_step] * (dc + self.rho[admm_step] * (x - z) + u) - - if admm_step in self.auxiliary_steps: - out.append(x) - - u = u + self.rho[admm_step] * (x - z) - - return out - - -class LagrangeMultipliersInitializer3D(torch.nn.Module): - """A convolutional neural network model that initializes the Lagrange multiplier of :class:`VSharpNet3D`. - - This is an extension to 3D data of :class:`LagrangeMultipliersInitializer`. - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - channels: tuple[int, ...], - dilations: tuple[int, ...], - multiscale_depth: int = 1, - activation: ActivationType = ActivationType.PRELU, - ): - """Initializes LagrangeMultipliersInitializer3D. - - Parameters - ---------- - in_channels : int - Number of input channels. - out_channels : int - Number of output channels. - channels : tuple of ints - Tuple of integers specifying the number of output channels for each convolutional layer in the network. - dilations : tuple of ints - Tuple of integers specifying the dilation factor for each convolutional layer in the network. - multiscale_depth : int - Number of multiscale features to include in the output. Default: 1. - activation : ActivationType - Activation function to use on the output. Default: ActivationType.PRELU. - """ - super().__init__() - - # Define convolutional blocks - self.conv_blocks = nn.ModuleList() - tch = in_channels - for curr_channels, curr_dilations in zip(channels, dilations): - block = nn.Sequential( - nn.ReplicationPad3d(curr_dilations), - nn.Conv3d(tch, curr_channels, 3, padding=0, dilation=curr_dilations), - ) - tch = curr_channels - self.conv_blocks.append(block) - - # Define output block - tch = np.sum(channels[-multiscale_depth:]) - block = nn.Conv3d(tch, out_channels, 1, padding=0) - self.out_block = nn.Sequential(block) - - self.multiscale_depth = multiscale_depth - self.activation = _get_relu_activation(activation) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Forward pass of :class:`LagrangeMultipliersInitializer3D`. - - Parameters - ---------- - x : torch.Tensor - Input tensor of shape (batch_size, in_channels, z, x, y). - - Returns - ------- - torch.Tensor - Output tensor of shape (batch_size, out_channels, z, x, y). 
- """ - - features = [] - for block in self.conv_blocks: - x = F.relu(block(x), inplace=True) - if self.multiscale_depth > 1: - features.append(x) - - if self.multiscale_depth > 1: - x = torch.cat(features[-self.multiscale_depth :], dim=1) - - return self.activation(self.out_block(x)) - - -class VSharpNet3D(nn.Module): - """VharpNet 3D version using 3D U-Nets as denoisers. - - This is an extension to 3D of :class:`VSharpNet`. For the original paper refer to [1]_. - - References - ---------- - .. [1] George Yiasemis et al., "VSHARP: Variable Splitting Half-quadratic ADMM Algorithm for Reconstruction - of Inverse Problems" (2023). https://arxiv.org/abs/2309.09954. - """ - - def __init__( - self, - forward_operator: Callable[[tuple[Any, ...]], torch.Tensor], - backward_operator: Callable[[tuple[Any, ...]], torch.Tensor], - num_steps: int, - num_steps_dc_gd: int, - image_init: InitType = InitType.SENSE, - no_parameter_sharing: bool = True, - initializer_channels: tuple[int, ...] = (32, 32, 64, 64), - initializer_dilations: tuple[int, ...] = (1, 1, 2, 4), - initializer_multiscale: int = 1, - initializer_activation: ActivationType = ActivationType.PRELU, - auxiliary_steps: int = -1, - unet_num_filters: int = 32, - unet_num_pool_layers: int = 4, - unet_dropout: float = 0.0, - unet_norm: bool = False, - **kwargs, - ): - """Inits :class:`VSharpNet3D`. - - Parameters - ---------- - forward_operator : Callable[[tuple[Any, ...]], torch.Tensor] - Forward operator function. - backward_operator : Callable[[tuple[Any, ...]], torch.Tensor] - Backward operator function. - num_steps : int - Number of steps in the ADMM algorithm. - num_steps_dc_gd : int - Number of steps in the Data Consistency using Gradient Descent step of ADMM. - image_init : str - Image initialization method. Default: 'sense'. - no_parameter_sharing : bool - Flag indicating whether parameter sharing is enabled in the denoiser blocks. - initializer_channels : tuple[int, ...] - Tuple of integers specifying the number of output channels for each convolutional layer in the - Lagrange multiplier initializer. Default: (32, 32, 64, 64). - initializer_dilations : tuple[int, ...] - Tuple of integers specifying the dilation factor for each convolutional layer in the Lagrange multiplier - initializer. Default: (1, 1, 2, 4). - initializer_multiscale : int - Number of multiscale features to include in the Lagrange multiplier initializer output. Default: 1. - initializer_activation : ActivationType - Activation type for the Lagrange multiplier initializer. Default: ActivationType.PReLU. - auxiliary_steps : int - Number of auxiliary steps to output. Can be -1 or a positive integer lower or equal to `num_steps`. - If -1, it uses all steps. If I, the last I steps will be used. - unet_num_filters : int - U-Net denoisers number of output channels of the first convolutional layer. Default: 32. - unet_num_pool_layers : int - U-Net denoisers number of down-sampling and up-sampling layers (depth). Default: 4. - unet_dropout : float - U-Net denoisers dropout probability. Default: 0.0 - unet_norm : bool - Whether to use normalized U-Net as denoiser or not. Default: False. - **kwargs: Additional keyword arguments. - Can be `model_name`. 
- """ - # pylint: disable=too-many-locals - super().__init__() - for extra_key in kwargs: - if extra_key != "model_name": - raise ValueError(f"{type(self).__name__} got key `{extra_key}` which is not supported.") - self.num_steps = num_steps - self.num_steps_dc_gd = num_steps_dc_gd - - self.no_parameter_sharing = no_parameter_sharing - - self.denoiser_blocks = nn.ModuleList() - for _ in range(num_steps if self.no_parameter_sharing else 1): - self.denoiser_blocks.append( - (UnetModel3d if not unet_norm else NormUnetModel3d)( - in_channels=COMPLEX_SIZE * 3, - out_channels=COMPLEX_SIZE, - num_filters=unet_num_filters, - num_pool_layers=unet_num_pool_layers, - dropout_probability=unet_dropout, - ) - ) - - self.initializer = LagrangeMultipliersInitializer3D( - in_channels=COMPLEX_SIZE, - out_channels=COMPLEX_SIZE, - channels=initializer_channels, - dilations=initializer_dilations, - multiscale_depth=initializer_multiscale, - activation=initializer_activation, - ) - - self.learning_rate_eta = nn.Parameter(torch.ones(num_steps_dc_gd, requires_grad=True)) - nn.init.trunc_normal_(self.learning_rate_eta, 0.0, 1.0, 0.0) - - self.rho = nn.Parameter(torch.ones(num_steps, requires_grad=True)) - nn.init.trunc_normal_(self.rho, 0, 0.1, 0.0) - - self.forward_operator = forward_operator - self.backward_operator = backward_operator - - if image_init not in ["sense", "zero_filled"]: - raise ValueError(f"Unknown image_initialization. Expected 'sense' or 'zero_filled'. " f"Got {image_init}.") - - self.image_init = image_init - - if not (auxiliary_steps == -1 or 0 < auxiliary_steps <= num_steps): - raise ValueError( - f"Number of auxiliary steps should be -1 to use all steps or a positive" - f" integer <= than `num_steps`. Received {auxiliary_steps}." - ) - if auxiliary_steps == -1: - self.auxiliary_steps = list(range(num_steps)) - else: - self.auxiliary_steps = list(range(num_steps - min(auxiliary_steps, num_steps), num_steps)) - - self._coil_dim = 1 - self._complex_dim = -1 - self._spatial_dims = (3, 4) - - def forward( - self, - masked_kspace: torch.Tensor, - sensitivity_map: torch.Tensor, - sampling_mask: torch.Tensor, - ) -> list[torch.Tensor]: - """Computes forward pass of :class:`VSharpNet3D`. - - Parameters - ---------- - masked_kspace : torch.Tensor - Masked k-space of shape (N, coil, slice, height, width, complex=2). - sensitivity_map : torch.Tensor - Sensitivity map of shape (N, coil, slice, height, width, complex=2). - sampling_mask : torch.Tensor - Sampling mask of shape (N, 1, 1 or slice, height, width, 1). - - Returns - ------- - out : list of torch.Tensors - List of output images each of shape (N, slice, height, width, complex=2). 
- """ - out = [] - if self.image_init == InitType.SENSE: - x = reduce_operator( - coil_data=self.backward_operator(masked_kspace, dim=self._spatial_dims), - sensitivity_map=sensitivity_map, - dim=self._coil_dim, - ) - else: - x = self.backward_operator(masked_kspace, dim=self._spatial_dims).sum(self._coil_dim) - - z = x.clone() - - u = self.initializer(x.permute(0, 4, 1, 2, 3)).permute(0, 2, 3, 4, 1) - - for admm_step in range(self.num_steps): - z = self.denoiser_blocks[admm_step if self.no_parameter_sharing else 0]( - torch.cat( - [z, x, u / self.rho[admm_step]], - dim=self._complex_dim, - ).permute(0, 4, 1, 2, 3) - ).permute(0, 2, 3, 4, 1) - - for dc_gd_step in range(self.num_steps_dc_gd): - dc = apply_mask( - self.forward_operator(expand_operator(x, sensitivity_map, self._coil_dim), dim=self._spatial_dims) - - masked_kspace, - sampling_mask, - return_mask=False, - ) - dc = self.backward_operator(dc, dim=self._spatial_dims) - dc = reduce_operator(dc, sensitivity_map, self._coil_dim) - - x = x - self.learning_rate_eta[dc_gd_step] * (dc + self.rho[admm_step] * (x - z) + u) - - if admm_step in self.auxiliary_steps: - out.append(x) - - u = u + self.rho[admm_step] * (x - z) - - return out +# Copyright (c) DIRECT Contributors + +"""This module provides the implementation of vSHARP model. + +Most specifically, vSHARP is the variable Splitting Half-quadratic ADMM algorithm for Reconstruction +of inverse-Problems (vSHARPP) model as presented in [1]_. + + +References +---------- +.. [1] Yiasemis, G., Moriakov, N., Sonke, J.-J., Teuwen, J.: vSHARP: Variable Splitting Half-quadratic ADMM algorithm + for reconstruction of inverse-problems. Magnetic Resonance Imaging. 110266 (2024). + https://doi.org/10.1016/j.mri.2024.110266. + +""" + + +from __future__ import annotations + +from typing import Any, Callable + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +from direct.constants import COMPLEX_SIZE +from direct.data.transforms import apply_mask, expand_operator, reduce_operator +from direct.nn.get_nn_model_config import ModelName, _get_model_config, _get_relu_activation +from direct.nn.types import ActivationType, DirectEnum, InitType +from direct.nn.unet.unet_3d import NormUnetModel3d, UnetModel3d + + +class LagrangeMultipliersInitialization(DirectEnum): + LEARNED = "learned" + ZEROS = "zeros" + + +class LagrangeMultipliersInitializer(nn.Module): + """A convolutional neural network model that initializers the Lagrange multiplier of the :class:`vSHARPNet` [1]_. + + More specifically, it produces an initial value for the Lagrange Multiplier based on the zero-filled image: + + .. math:: + + u^0 = \mathcal{G}_{\psi}(x^0). + + References + ---------- + .. [1] George Yiasemis et al., "VSHARP: Variable Splitting Half-quadratic ADMM Algorithm for Reconstruction + of Inverse Problems" (2023). https://arxiv.org/abs/2309.09954. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + channels: tuple[int, ...], + dilations: tuple[int, ...], + multiscale_depth: int = 1, + activation: ActivationType = ActivationType.PRELU, + ) -> None: + """Inits :class:`LagrangeMultipliersInitializer`. + + Parameters + ---------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels. + channels : tuple of ints + Tuple of integers specifying the number of output channels for each convolutional layer in the network. 
+ dilations : tuple of ints + Tuple of integers specifying the dilation factor for each convolutional layer in the network. + multiscale_depth : int + Number of multiscale features to include in the output. Default: 1. + activation : ActivationType + Activation function to use on the output. Default: ActivationType.PRELU. + """ + super().__init__() + + # Define convolutional blocks + self.conv_blocks = nn.ModuleList() + tch = in_channels + for curr_channels, curr_dilations in zip(channels, dilations): + block = nn.Sequential( + nn.ReplicationPad2d(curr_dilations), + nn.Conv2d(tch, curr_channels, 3, padding=0, dilation=curr_dilations), + ) + tch = curr_channels + self.conv_blocks.append(block) + + # Define output block + tch = np.sum(channels[-multiscale_depth:]) + block = nn.Conv2d(tch, out_channels, 1, padding=0) + self.out_block = nn.Sequential(block) + + self.multiscale_depth = multiscale_depth + + self.activation = _get_relu_activation(activation) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass of :class:`LagrangeMultipliersInitializer`. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape (batch_size, in_channels, height, width). + + Returns + ------- + torch.Tensor + Output tensor of shape (batch_size, out_channels, height, width). + """ + + features = [] + for block in self.conv_blocks: + x = F.relu(block(x), inplace=True) + if self.multiscale_depth > 1: + features.append(x) + + if self.multiscale_depth > 1: + x = torch.cat(features[-self.multiscale_depth :], dim=1) + + return self.activation(self.out_block(x)) + + +class VSharpNet(nn.Module): + """ + Variable Splitting Half-quadratic ADMM algorithm for Reconstruction of Parallel MRI [1]_. + + Variable Splitting Half Quadratic VSharpNet is a deep learning model that solves + the augmented Lagrangian derivation of the variable half quadratic splitting problem + using ADMM (Alternating Direction Method of Multipliers). It is specifically designed + for solving inverse problems in magnetic resonance imaging (MRI). + + The VSharpNet model incorporates an iterative optimization algorithm that consists of + three steps: z-step, x-step, and u-step. These steps are detailed mathematically as follows: + + .. math:: + + z^{t+1} = \mathrm{argmin}_{z} \\lambda \mathcal{G}(z) + \\frac{\\rho}{2} || x^{t} - z + + \\frac{u^t}{\\rho} ||_2^2 \\quad \mathrm{[z-step]} + + .. math:: + + x^{t+1} = \mathrm{argmin}_{x} \\frac{1}{2} || \mathcal{A}_{\mathbf{U},\mathbf{S}}(x) - \\tilde{y} ||_2^2 + + \\frac{\\rho}{2} || x - z^{t+1} + \\frac{u^t}{\\rho} ||_2^2 \\quad \mathrm{[x-step]} + + .. math:: + + u^{t+1} = u^t + \\rho (x^{t+1} - z^{t+1}) \\quad \mathrm{[u-step]} + + During the z-step, the model minimizes the augmented Lagrangian function with respect to z, utilizing + DL-based denoisers. In the x-step, it optimizes x by minimizing the data consistency term through + unrolling a gradient descent scheme (DC-GD). The u-step involves updating the Lagrange multiplier u. + These steps are iterated for a specified number of cycles. + + The model includes an initializer for Lagrange multipliers. + + It also allows for outputting auxiliary steps. + + :class:`VSharpNet` is tailored for 2D MRI data reconstruction. + + References + ---------- + .. [1] George Yiasemis et al., "VSHARP: Variable Splitting Half-quadratic ADMM Algorithm for Reconstruction + of Inverse Problems" (2023). https://arxiv.org/abs/2309.09954. 
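To make the three updates above concrete, here is a purely didactic toy version of the same iteration structure, with the identity as forward operator and soft-thresholding standing in for the learned denoiser; nothing in it is part of the model or this diff.

import torch

torch.manual_seed(0)
y = torch.randn(8)          # toy "measurements"; forward operator A is the identity
x = y.clone()               # image variable
z = x.clone()               # auxiliary (denoised) variable
u = torch.zeros_like(x)     # Lagrange multipliers
rho, eta = 0.5, 0.1

for _ in range(10):  # ADMM steps (num_steps)
    # z-step: denoise x + u / rho; soft-thresholding stands in for the learned prior G.
    v = x + u / rho
    z = torch.sign(v) * torch.clamp(v.abs() - 0.05, min=0.0)
    # x-step: unrolled gradient descent on the data-consistency term (num_steps_dc_gd).
    for _ in range(4):
        x = x - eta * ((x - y) + rho * (x - z) + u)
    # u-step: dual ascent on the multipliers.
    u = u + rho * (x - z)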
+ """ + + def __init__( + self, + forward_operator: Callable[[tuple[Any, ...]], torch.Tensor], + backward_operator: Callable[[tuple[Any, ...]], torch.Tensor], + num_steps: int, + num_steps_dc_gd: int, + image_init: InitType = InitType.SENSE, + no_parameter_sharing: bool = True, + image_model_architecture: ModelName = ModelName.UNET, + initializer_channels: tuple[int, ...] = (32, 32, 64, 64), + initializer_dilations: tuple[int, ...] = (1, 1, 2, 4), + initializer_multiscale: int = 1, + initializer_activation: ActivationType = ActivationType.PRELU, + auxiliary_steps: int = 0, + lagrange_initialization: LagrangeMultipliersInitialization = LagrangeMultipliersInitialization.LEARNED, + **kwargs, + ) -> None: + """Inits :class:`VSharpNet`. + + Parameters + ---------- + forward_operator : Callable[[tuple[Any, ...]], torch.Tensor] + Forward operator function. + backward_operator : Callable[[tuple[Any, ...]], torch.Tensor] + Backward operator function. + num_steps : int + Number of steps in the ADMM algorithm. + num_steps_dc_gd : int + Number of steps in the Data Consistency using Gradient Descent step of ADMM. + image_init : str + Image initialization method. Default: 'sense'. + no_parameter_sharing : bool + Flag indicating whether parameter sharing is enabled in the denoiser blocks. + image_model_architecture : ModelName + Image model architecture. Default: ModelName.UNET. + initializer_channels : tuple[int, ...] + Tuple of integers specifying the number of output channels for each convolutional layer in the + Lagrange multiplier initializer. Default: (32, 32, 64, 64). + initializer_dilations : tuple[int, ...] + Tuple of integers specifying the dilation factor for each convolutional layer in the Lagrange multiplier + initializer. Default: (1, 1, 2, 4). + initializer_multiscale : int + Number of multiscale features to include in the Lagrange multiplier initializer output. Default: 1. + initializer_activation : ActivationType + Activation type for the Lagrange multiplier initializer. Default: ActivationType.PRELU. + auxiliary_steps : int + Number of auxiliary steps to output. Can be -1 or a positive integer lower or equal to `num_steps`. + If -1, it uses all steps. If I, the last I steps will be used. + lagrange_initialization : LagrangeMultipliersInitialization + Lagrange multiplier initialization method. Can be LagrangeMultipliersInitialization.LEARNED or + LagrangeMultipliersInitialization.ZEROS, corresponding to learned initialization or zero initialization. + Default: LagrangeMultipliersInitialization.LEARNED. + **kwargs: Additional keyword arguments. + Can be `model_name` or `image_model_` where `` represent parameters of the selected + image model architecture beyond the standard parameters. + Depending on the `image_model_architecture` chosen, different kwargs will be applicable. 
+ """ + # pylint: disable=too-many-locals + super().__init__() + for extra_key in kwargs: + if extra_key != "model_name" and not extra_key.startswith("image_"): + raise ValueError(f"{type(self).__name__} got key `{extra_key}` which is not supported.") + self.num_steps = num_steps + self.num_steps_dc_gd = num_steps_dc_gd + + self.no_parameter_sharing = no_parameter_sharing + + image_model, image_model_kwargs = _get_model_config( + image_model_architecture, + in_channels=COMPLEX_SIZE * 3, + out_channels=COMPLEX_SIZE, + **{k.replace("image_", ""): v for (k, v) in kwargs.items() if "image_" in k}, + ) + + self.denoiser_blocks = nn.ModuleList() + for _ in range(num_steps if self.no_parameter_sharing else 1): + self.denoiser_blocks.append(image_model(**image_model_kwargs)) + + self.lagrange_initialization = lagrange_initialization + if lagrange_initialization == LagrangeMultipliersInitialization.LEARNED: + self.initializer = LagrangeMultipliersInitializer( + in_channels=COMPLEX_SIZE, + out_channels=COMPLEX_SIZE, + channels=initializer_channels, + dilations=initializer_dilations, + multiscale_depth=initializer_multiscale, + activation=initializer_activation, + ) + + self.learning_rate_eta = nn.Parameter(torch.ones(num_steps_dc_gd, requires_grad=True)) + nn.init.trunc_normal_(self.learning_rate_eta, 0.0, 1.0, 0.0) + + self.rho = nn.Parameter(torch.ones(num_steps, requires_grad=True)) + nn.init.trunc_normal_(self.rho, 0, 0.1, 0.0) + + self.forward_operator = forward_operator + self.backward_operator = backward_operator + + if image_init not in [InitType.SENSE, InitType.ZERO_FILLED]: + raise ValueError( + f"Unknown image_initialization. Expected `InitType.SENSE` or `InitType.ZERO_FILLED`. " + f"Got {image_init}." + ) + + self.image_init = image_init + + if not ((auxiliary_steps == -1) or (0 < auxiliary_steps <= num_steps)): + raise ValueError( + f"Number of auxiliary steps should be -1 to use all steps or a positive" + f" integer <= than `num_steps`. Received {auxiliary_steps}." + ) + if auxiliary_steps == -1: + self.auxiliary_steps = list(range(num_steps)) + else: + self.auxiliary_steps = list(range(num_steps - min(auxiliary_steps, num_steps), num_steps)) + + self._coil_dim = 1 + self._complex_dim = -1 + self._spatial_dims = (2, 3) + + def forward( + self, + masked_kspace: torch.Tensor, + sensitivity_map: torch.Tensor, + sampling_mask: torch.Tensor, + ) -> list[torch.Tensor]: + """Computes forward pass of :class:`VSharpNet`. + + Parameters + ---------- + masked_kspace: torch.Tensor + Masked k-space of shape (N, coil, height, width, complex=2). + sensitivity_map: torch.Tensor + Sensitivity map of shape (N, coil, height, width, complex=2). + sampling_mask: torch.Tensor + Sampling mask of shape (N, 1, height, width, 1). + + Returns + ------- + out : list of torch.Tensors + List of output images of shape (N, height, width, complex=2). 
+ """ + out = [] + if self.image_init == InitType.SENSE: + x = reduce_operator( + coil_data=self.backward_operator(masked_kspace, dim=self._spatial_dims), + sensitivity_map=sensitivity_map, + dim=self._coil_dim, + ) + else: + x = self.backward_operator(masked_kspace, dim=self._spatial_dims).sum(self._coil_dim) + + z = x.clone() + + if self.lagrange_initialization == LagrangeMultipliersInitialization.LEARNED: + u = self.initializer(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) + else: + u = torch.zeros_like(x).to(x.device) + + for admm_step in range(self.num_steps): + z = self.denoiser_blocks[admm_step if self.no_parameter_sharing else 0]( + torch.cat( + [z, x, u / self.rho[admm_step]], + dim=self._complex_dim, + ).permute(0, 3, 1, 2) + ).permute(0, 2, 3, 1) + + for dc_gd_step in range(self.num_steps_dc_gd): + dc = apply_mask( + self.forward_operator(expand_operator(x, sensitivity_map, self._coil_dim), dim=self._spatial_dims) + - masked_kspace, + sampling_mask, + return_mask=False, + ) + dc = self.backward_operator(dc, dim=self._spatial_dims) + dc = reduce_operator(dc, sensitivity_map, self._coil_dim) + + x = x - self.learning_rate_eta[dc_gd_step] * (dc + self.rho[admm_step] * (x - z) + u) + + if admm_step in self.auxiliary_steps: + out.append(x) + + u = u + self.rho[admm_step] * (x - z) + + return out + + +class LagrangeMultipliersInitializer3D(torch.nn.Module): + """A convolutional neural network model that initializes the Lagrange multiplier of :class:`VSharpNet3D`. + + This is an extension to 3D data of :class:`LagrangeMultipliersInitializer`. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + channels: tuple[int, ...], + dilations: tuple[int, ...], + multiscale_depth: int = 1, + activation: ActivationType = ActivationType.PRELU, + ): + """Initializes LagrangeMultipliersInitializer3D. + + Parameters + ---------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels. + channels : tuple of ints + Tuple of integers specifying the number of output channels for each convolutional layer in the network. + dilations : tuple of ints + Tuple of integers specifying the dilation factor for each convolutional layer in the network. + multiscale_depth : int + Number of multiscale features to include in the output. Default: 1. + activation : ActivationType + Activation function to use on the output. Default: ActivationType.PRELU. + """ + super().__init__() + + # Define convolutional blocks + self.conv_blocks = nn.ModuleList() + tch = in_channels + for curr_channels, curr_dilations in zip(channels, dilations): + block = nn.Sequential( + nn.ReplicationPad3d(curr_dilations), + nn.Conv3d(tch, curr_channels, 3, padding=0, dilation=curr_dilations), + ) + tch = curr_channels + self.conv_blocks.append(block) + + # Define output block + tch = np.sum(channels[-multiscale_depth:]) + block = nn.Conv3d(tch, out_channels, 1, padding=0) + self.out_block = nn.Sequential(block) + + self.multiscale_depth = multiscale_depth + self.activation = _get_relu_activation(activation) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass of :class:`LagrangeMultipliersInitializer3D`. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape (batch_size, in_channels, z, x, y). + + Returns + ------- + torch.Tensor + Output tensor of shape (batch_size, out_channels, z, x, y). 
+ """ + + features = [] + for block in self.conv_blocks: + x = F.relu(block(x), inplace=True) + if self.multiscale_depth > 1: + features.append(x) + + if self.multiscale_depth > 1: + x = torch.cat(features[-self.multiscale_depth :], dim=1) + + return self.activation(self.out_block(x)) + + +class VSharpNet3D(nn.Module): + """VharpNet 3D version using 3D U-Nets as denoisers. + + This is an extension to 3D of :class:`VSharpNet`. For the original paper refer to [1]_. + + References + ---------- + .. [1] George Yiasemis et al., "VSHARP: Variable Splitting Half-quadratic ADMM Algorithm for Reconstruction + of Inverse Problems" (2023). https://arxiv.org/abs/2309.09954. + """ + + def __init__( + self, + forward_operator: Callable[[tuple[Any, ...]], torch.Tensor], + backward_operator: Callable[[tuple[Any, ...]], torch.Tensor], + num_steps: int, + num_steps_dc_gd: int, + image_init: InitType = InitType.SENSE, + no_parameter_sharing: bool = True, + initializer_channels: tuple[int, ...] = (32, 32, 64, 64), + initializer_dilations: tuple[int, ...] = (1, 1, 2, 4), + initializer_multiscale: int = 1, + initializer_activation: ActivationType = ActivationType.PRELU, + auxiliary_steps: int = -1, + lagrange_initialization: LagrangeMultipliersInitialization = LagrangeMultipliersInitialization.LEARNED, + unet_num_filters: int = 32, + unet_num_pool_layers: int = 4, + unet_dropout: float = 0.0, + unet_norm: bool = False, + **kwargs, + ): + """Inits :class:`VSharpNet3D`. + + Parameters + ---------- + forward_operator : Callable[[tuple[Any, ...]], torch.Tensor] + Forward operator function. + backward_operator : Callable[[tuple[Any, ...]], torch.Tensor] + Backward operator function. + num_steps : int + Number of steps in the ADMM algorithm. + num_steps_dc_gd : int + Number of steps in the Data Consistency using Gradient Descent step of ADMM. + image_init : str + Image initialization method. Default: 'sense'. + no_parameter_sharing : bool + Flag indicating whether parameter sharing is enabled in the denoiser blocks. + initializer_channels : tuple[int, ...] + Tuple of integers specifying the number of output channels for each convolutional layer in the + Lagrange multiplier initializer. Default: (32, 32, 64, 64). + initializer_dilations : tuple[int, ...] + Tuple of integers specifying the dilation factor for each convolutional layer in the Lagrange multiplier + initializer. Default: (1, 1, 2, 4). + initializer_multiscale : int + Number of multiscale features to include in the Lagrange multiplier initializer output. Default: 1. + initializer_activation : ActivationType + Activation type for the Lagrange multiplier initializer. Default: ActivationType.PReLU. + auxiliary_steps : int + Number of auxiliary steps to output. Can be -1 or a positive integer lower or equal to `num_steps`. + If -1, it uses all steps. If I, the last I steps will be used. + lagrange_initialization : LagrangeMultipliersInitialization + Lagrange multiplier initialization method. Can be LagrangeMultipliersInitialization.LEARNED or + LagrangeMultipliersInitialization.ZEROS, corresponding to learned initialization or zero initialization. + Default: LagrangeMultipliersInitialization.LEARNED. + unet_num_filters : int + U-Net denoisers number of output channels of the first convolutional layer. Default: 32. + unet_num_pool_layers : int + U-Net denoisers number of down-sampling and up-sampling layers (depth). Default: 4. + unet_dropout : float + U-Net denoisers dropout probability. 
Default: 0.0 + unet_norm : bool + Whether to use normalized U-Net as denoiser or not. Default: False. + **kwargs: Additional keyword arguments. + Can be `model_name`. + """ + # pylint: disable=too-many-locals + super().__init__() + for extra_key in kwargs: + if extra_key != "model_name": + raise ValueError(f"{type(self).__name__} got key `{extra_key}` which is not supported.") + self.num_steps = num_steps + self.num_steps_dc_gd = num_steps_dc_gd + + self.no_parameter_sharing = no_parameter_sharing + + self.denoiser_blocks = nn.ModuleList() + for _ in range(num_steps if self.no_parameter_sharing else 1): + self.denoiser_blocks.append( + (UnetModel3d if not unet_norm else NormUnetModel3d)( + in_channels=COMPLEX_SIZE * 3, + out_channels=COMPLEX_SIZE, + num_filters=unet_num_filters, + num_pool_layers=unet_num_pool_layers, + dropout_probability=unet_dropout, + ) + ) + + self.lagrange_initialization = lagrange_initialization + if lagrange_initialization == LagrangeMultipliersInitialization.LEARNED: + self.initializer = LagrangeMultipliersInitializer3D( + in_channels=COMPLEX_SIZE, + out_channels=COMPLEX_SIZE, + channels=initializer_channels, + dilations=initializer_dilations, + multiscale_depth=initializer_multiscale, + activation=initializer_activation, + ) + + self.learning_rate_eta = nn.Parameter(torch.ones(num_steps_dc_gd, requires_grad=True)) + nn.init.trunc_normal_(self.learning_rate_eta, 0.0, 1.0, 0.0) + + self.rho = nn.Parameter(torch.ones(num_steps, requires_grad=True)) + nn.init.trunc_normal_(self.rho, 0, 0.1, 0.0) + + self.forward_operator = forward_operator + self.backward_operator = backward_operator + + if image_init not in [InitType.SENSE, InitType.ZERO_FILLED]: + raise ValueError( + f"Unknown image_initialization. Expected `InitType.SENSE` or `InitType.ZERO_FILLED`. " + f"Got {image_init}." + ) + + self.image_init = image_init + + if not (auxiliary_steps == -1 or 0 < auxiliary_steps <= num_steps): + raise ValueError( + f"Number of auxiliary steps should be -1 to use all steps or a positive" + f" integer <= than `num_steps`. Received {auxiliary_steps}." + ) + if auxiliary_steps == -1: + self.auxiliary_steps = list(range(num_steps)) + else: + self.auxiliary_steps = list(range(num_steps - min(auxiliary_steps, num_steps), num_steps)) + + self._coil_dim = 1 + self._complex_dim = -1 + self._spatial_dims = (3, 4) + + def forward( + self, + masked_kspace: torch.Tensor, + sensitivity_map: torch.Tensor, + sampling_mask: torch.Tensor, + ) -> list[torch.Tensor]: + """Computes forward pass of :class:`VSharpNet3D`. + + Parameters + ---------- + masked_kspace : torch.Tensor + Masked k-space of shape (N, coil, slice, height, width, complex=2). + sensitivity_map : torch.Tensor + Sensitivity map of shape (N, coil, slice, height, width, complex=2). + sampling_mask : torch.Tensor + Sampling mask of shape (N, 1, 1 or slice, height, width, 1). + + Returns + ------- + out : list of torch.Tensors + List of output images each of shape (N, slice, height, width, complex=2). 
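+
+        Examples
+        --------
+        An illustrative call on random tensors (a sketch; spatial sizes are chosen to be compatible
+        with the U-Net depth, and `fft2`/`ifft2` are DIRECT's centered FFT operators):
+
+        >>> import torch
+        >>> from direct.data.transforms import fft2, ifft2
+        >>> net3d = VSharpNet3D(fft2, ifft2, num_steps=2, num_steps_dc_gd=2,
+        ...                     auxiliary_steps=1, unet_num_pool_layers=2)
+        >>> y = torch.rand(1, 2, 8, 32, 32, 2)  # masked k-space
+        >>> smap = torch.rand(1, 2, 8, 32, 32, 2)  # sensitivity map
+        >>> mask = torch.rand(1, 1, 1, 32, 32, 1) > 0.5  # sampling mask
+        >>> [o.shape for o in net3d(y, smap, mask)]
+        [torch.Size([1, 8, 32, 32, 2])]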
+ """ + out = [] + if self.image_init == InitType.SENSE: + x = reduce_operator( + coil_data=self.backward_operator(masked_kspace, dim=self._spatial_dims), + sensitivity_map=sensitivity_map, + dim=self._coil_dim, + ) + else: + x = self.backward_operator(masked_kspace, dim=self._spatial_dims).sum(self._coil_dim) + + z = x.clone() + + if self.lagrange_initialization == LagrangeMultipliersInitialization.LEARNED: + u = self.initializer(x.permute(0, 4, 1, 2, 3)).permute(0, 2, 3, 4, 1) + else: + u = torch.zeros_like(x).to(x.device) + + for admm_step in range(self.num_steps): + z = self.denoiser_blocks[admm_step if self.no_parameter_sharing else 0]( + torch.cat( + [z, x, u / self.rho[admm_step]], + dim=self._complex_dim, + ).permute(0, 4, 1, 2, 3) + ).permute(0, 2, 3, 4, 1) + + for dc_gd_step in range(self.num_steps_dc_gd): + dc = apply_mask( + self.forward_operator(expand_operator(x, sensitivity_map, self._coil_dim), dim=self._spatial_dims) + - masked_kspace, + sampling_mask, + return_mask=False, + ) + dc = self.backward_operator(dc, dim=self._spatial_dims) + dc = reduce_operator(dc, sensitivity_map, self._coil_dim) + + x = x - self.learning_rate_eta[dc_gd_step] * (dc + self.rho[admm_step] * (x - z) + u) + + if admm_step in self.auxiliary_steps: + out.append(x) + + u = u + self.rho[admm_step] * (x - z) + + return out diff --git a/direct/nn/vsharp/vsharp_engine.py b/direct/nn/vsharp/vsharp_engine.py index 9efb16e4..51578f93 100644 --- a/direct/nn/vsharp/vsharp_engine.py +++ b/direct/nn/vsharp/vsharp_engine.py @@ -6,9 +6,9 @@ References ---------- -.. [1] Yiasemis, G., Moriakov, N., Sánchez, C.I., Sonke, J.-J., Teuwen, J.: JSSL: Joint Supervised and - Self-supervised Learning for MRI Reconstruction, http://arxiv.org/abs/2311.15856, (2023). - https://doi.org/10.48550/arXiv.2311.15856. +.. [1] Yiasemis, G., Moriakov, N., Sonke, J.-J., Teuwen, J.: vSHARP: Variable Splitting Half-quadratic ADMM algorithm + for reconstruction of inverse-problems. Magnetic Resonance Imaging. 110266 (2024). + https://doi.org/10.1016/j.mri.2024.110266. .. [2] Yiasemis, G., Moriakov, N., Sánchez, C.I., Sonke, J.-J., Teuwen, J.: JSSL: Joint Supervised and Self-supervised Learning for MRI Reconstruction, http://arxiv.org/abs/2311.15856, (2023). https://doi.org/10.48550/arXiv.2311.15856. 
@@ -165,13 +165,21 @@ def _do_iteration( if shape == registered_image.shape else data["reference_image"].tile((1, registered_image.shape[1], *([1] * len(shape[1:])))) ), + weight=self.cfg.additional_models.registration_model.reg_loss_factor, ) + + if "displacement_field" in data: + target_displacement_field = data["displacement_field"] + else: + target_displacement_field = None + loss_dict = self.compute_loss_on_data( loss_dict, loss_fns, data, output_displacement_field=displacement_field, - target_displacement_field=data["displacement_field"], + target_displacement_field=target_displacement_field, + weight=self.cfg.additional_models.registration_model.reg_loss_factor, ) loss = sum(loss_dict.values()) # type: ignore @@ -183,7 +191,11 @@ def _do_iteration( output_image = output_images[-1] return DoIterationOutput( - output_image=(output_image, registered_image) if "registration_model" in self.models else output_image, + output_image=( + (output_image, registered_image, displacement_field) + if "registration_model" in self.models + else output_image + ), sensitivity_map=data["sensitivity_map"], sampling_mask=data["sampling_mask"], data_dict={**loss_dict}, diff --git a/direct/train.py b/direct/train.py index 31207ba0..b3400c1e 100644 --- a/direct/train.py +++ b/direct/train.py @@ -1,332 +1,332 @@ -# coding=utf-8 -# Copyright (c) DIRECT Contributors -import argparse -import functools -import logging -import os -import pathlib -import sys -import urllib.parse -from collections import defaultdict -from typing import Callable, Dict, List, Optional, Union - -import numpy as np -import torch -from omegaconf import DictConfig - -from direct.cli.utils import check_train_val -from direct.common.subsample import build_masking_function -from direct.data.datasets import build_dataset_from_input -from direct.data.lr_scheduler import WarmupMultiStepLR -from direct.data.mri_transforms import build_mri_transforms -from direct.environment import setup_training_environment -from direct.launch import launch -from direct.types import PathOrString -from direct.utils import dict_flatten, remove_keys, set_all_seeds, str_to_class -from direct.utils.dataset import get_filenames_for_datasets_from_config -from direct.utils.io import check_is_valid_url, read_json - -logger = logging.getLogger(__name__) - - -def parse_noise_dict(noise_dict: dict, percentile: float = 1.0, multiplier: float = 1.0): - logger.info("Parsing noise dictionary...") - output: Dict = defaultdict(dict) - for filename in noise_dict: - data_per_volume = noise_dict[filename] - for slice_no in data_per_volume: - curr_data = data_per_volume[slice_no] - if percentile != 1.0: - lower_clip = np.percentile(curr_data, 100 * (1 - percentile)) - upper_clip = np.percentile(curr_data, 100 * percentile) - curr_data = np.clip(curr_data, lower_clip, upper_clip) - - output[filename][int(slice_no)] = ( - curr_data * multiplier - ) ** 2 # np.asarray(curr_data) * multiplier# (np.clip(curr_data, lower_clip, upper_clip) * multiplier) ** 2 - - return output - - -def get_root_of_file(filename: PathOrString): - """Get the root directory of the file or URL to file. 
- - Examples - -------- - >>> get_root_of_file('/mnt/archive/data.txt') - >>> /mnt/archive - >>> get_root_of_file('https://aiforoncology.nl/people') - >>> https://aiforoncology.nl/ - - Parameters - ---------- - filename: pathlib.Path or str - - Returns - ------- - pathlib.Path or str - """ - if check_is_valid_url(str(filename)): - filename = urllib.parse.urljoin(str(filename), ".") - else: - filename = pathlib.Path(filename).parents[0] - - return filename - - -def build_transforms_from_environment(env, dataset_config: DictConfig) -> Callable: - masking = dataset_config.transforms.masking # Masking func can be None - mask_func = None if masking is None else build_masking_function(**masking) - mri_transforms_func = functools.partial( - build_mri_transforms, - forward_operator=env.engine.forward_operator, - backward_operator=env.engine.backward_operator, - mask_func=mask_func, - ) - return mri_transforms_func(**dict_flatten(dict(remove_keys(dataset_config.transforms, "masking")))) # type: ignore - - -def build_training_datasets_from_environment( - env, - datasets_config: List[DictConfig], - lists_root: Optional[PathOrString] = None, - data_root: Optional[PathOrString] = None, - initial_images: Optional[Union[List[pathlib.Path], None]] = None, - initial_kspaces: Optional[Union[List[pathlib.Path], None]] = None, - pass_text_description: bool = True, - pass_dictionaries: Optional[Dict[str, Dict]] = None, -): - datasets = [] - for idx, dataset_config in enumerate(datasets_config): - if pass_text_description: - if not "text_description" in dataset_config: - dataset_config.text_description = f"ds{idx}" if len(datasets_config) > 1 else None - else: - dataset_config.text_description = None - if dataset_config.transforms.masking is None: # type: ignore - logger.info( - "Masking function set to None for %s.", - dataset_config.text_description, # type: ignore - ) - transforms = build_transforms_from_environment(env, dataset_config) - dataset_args = {"transforms": transforms, "dataset_config": dataset_config} - if initial_images is not None: - dataset_args.update({"initial_images": initial_images}) - if initial_kspaces is not None: - dataset_args.update({"initial_kspaces": initial_kspaces}) - if data_root is not None: - dataset_args.update({"data_root": data_root}) - filenames_filter = get_filenames_for_datasets_from_config(dataset_config, lists_root, data_root) - dataset_args.update({"filenames_filter": filenames_filter}) - if pass_dictionaries is not None: - dataset_args.update({"pass_dictionaries": pass_dictionaries}) - dataset = build_dataset_from_input(**dataset_args) - - logger.debug("Transforms %s / %s :\n%s", idx + 1, len(datasets_config), transforms) - datasets.append(dataset) - logger.info( - "Data size for %s (%s/%s): %s.", - dataset_config.text_description, # type: ignore - idx + 1, - len(datasets_config), - len(dataset), - ) - - return datasets - - -def setup_train( - run_name: str, - training_root: Union[pathlib.Path, None], - validation_root: Union[pathlib.Path, None], - base_directory: pathlib.Path, - cfg_filename: PathOrString, - force_validation: bool, - initialization_checkpoint: PathOrString, - initial_images: Optional[Union[List[pathlib.Path], None]], - initial_kspace: Optional[Union[List[pathlib.Path], None]], - noise: Optional[Union[List[pathlib.Path], None]], - device: str, - num_workers: int, - resume: bool, - machine_rank: int, - mixed_precision: bool, - debug: bool, -): - env = setup_training_environment( - run_name, - base_directory, - cfg_filename, - device, - machine_rank, - 
mixed_precision, - debug=debug, - ) - - # Trigger cudnn benchmark and remove the associated cache - torch.backends.cudnn.benchmark = True - torch.cuda.empty_cache() - - if initial_kspace is not None and initial_images is not None: - raise ValueError("Cannot both provide initial kspace or initial images.") - # Create training data - training_dataset_args = {"env": env, "datasets_config": env.cfg.training.datasets, "pass_text_description": True} - pass_dictionaries = {} - if noise is not None: - if not env.cfg.physics.use_noise_matrix: - raise ValueError("cfg.physics.use_noise_matrix is null, yet command line passed noise files.") - - noise = [read_json(fn) for fn in noise] - pass_dictionaries["loglikelihood_scaling"] = [ - parse_noise_dict(_, percentile=0.999, multiplier=env.cfg.physics.noise_matrix_scaling) for _ in noise - ] - training_dataset_args.update({"pass_dictionaries": pass_dictionaries}) - - if training_root is not None: - training_dataset_args.update({"data_root": training_root}) - # Get the lists_root. Assume now the given path is with respect to the config file. - lists_root = get_root_of_file(cfg_filename) - if lists_root is not None: - training_dataset_args.update({"lists_root": lists_root}) - if initial_images is not None: - training_dataset_args.update({"initial_images": initial_images[0]}) - if initial_kspace is not None: - training_dataset_args.update({"initial_kspaces": initial_kspace[0]}) - - # Build training datasets - training_datasets = build_training_datasets_from_environment(**training_dataset_args) - training_data_sizes = [len(_) for _ in training_datasets] - logger.info("Training data sizes: %s (sum=%s).", training_data_sizes, sum(training_data_sizes)) - - # Create validation data - if "validation" in env.cfg: - validation_dataset_args = { - "env": env, - "datasets_config": env.cfg.validation.datasets, - "pass_text_description": True, - } - if validation_root is not None: - validation_dataset_args.update({"data_root": validation_root}) - lists_root = get_root_of_file(cfg_filename) - if lists_root is not None: - validation_dataset_args.update({"lists_root": lists_root}) - if initial_images is not None: - validation_dataset_args.update({"initial_images": initial_images[1]}) - if initial_kspace is not None: - validation_dataset_args.update({"initial_kspaces": initial_kspace[1]}) - - # Build validation datasets - validation_data = build_training_datasets_from_environment(**validation_dataset_args) - else: - logger.info("No validation data.") - validation_data = None - - # Create the optimizers - logger.info("Building optimizers.") - optimizer_params = [{"params": env.engine.model.parameters()}] - for curr_model_name in env.engine.models: - # TODO(jt): Can get learning rate from the config per additional model too. - curr_learning_rate = env.cfg.training.lr - logger.info("Adding model parameters of %s with learning rate %s.", curr_model_name, curr_learning_rate) - optimizer_params.append( - { - "params": env.engine.models[curr_model_name].parameters(), - "lr": curr_learning_rate, - } - ) - - optimizer: torch.optim.Optimizer = str_to_class("torch.optim", env.cfg.training.optimizer)( # noqa - optimizer_params, - lr=env.cfg.training.lr, - weight_decay=env.cfg.training.weight_decay, - ) # noqa - - # Build the LR scheduler, we use a fixed LR schedule step size, no adaptive training schedule. 
- solver_steps = list( - range( - env.cfg.training.lr_step_size, - env.cfg.training.num_iterations, - env.cfg.training.lr_step_size, - ) - ) - lr_scheduler = WarmupMultiStepLR( - optimizer, - solver_steps, - env.cfg.training.lr_gamma, - warmup_factor=1 / 3.0, - warmup_iterations=env.cfg.training.lr_warmup_iter, - warmup_method="linear", - ) - - # Just to make sure. - torch.cuda.empty_cache() - - # Check the initialization checkpoint - if env.cfg.training.model_checkpoint: - if initialization_checkpoint: - logger.warning( - "`--initialization-checkpoint is set, and config has a set `training.model_checkpoint`: %s. " - "Will overwrite config variable with the command line: %s.", - env.cfg.training.model_checkpoint, - initialization_checkpoint, - ) - # Now overwrite this in the configuration, so the correct value is dumped. - env.cfg.training.model_checkpoint = str(initialization_checkpoint) - else: - initialization_checkpoint = env.cfg.training.model_checkpoint - - env.engine.train( - optimizer, - lr_scheduler, - training_datasets, - env.experiment_dir, - validation_datasets=validation_data, - resume=resume, - initialization=initialization_checkpoint, - start_with_validation=force_validation, - num_workers=num_workers, - ) - - -def train_from_argparse(args: argparse.Namespace): - # This sets MKL threads to 1. - # DataLoader can otherwise bring a lot of difficulties when computing CPU FFTs in the transforms. - torch.set_num_threads(1) - os.environ["OMP_NUM_THREADS"] = "1" - # Disable Tensorboard warnings. - os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" - - if args.initialization_images is not None and args.initialization_kspace is not None: - sys.exit("--initialization-images and --initialization-kspace are mutually exclusive.") - check_train_val(args.initialization_images, "initialization-images") - check_train_val(args.initialization_kspace, "initialization-kspace") - check_train_val(args.noise, "noise") - - set_all_seeds(args.seed) - - run_name = args.name if args.name is not None else os.path.basename(args.cfg_file)[:-5] - - # TODO(jt): Duplicate params - launch( - setup_train, - args.num_machines, - args.num_gpus, - args.machine_rank, - args.dist_url, - run_name, - args.training_root, - args.validation_root, - args.experiment_dir, - args.cfg_file, - args.force_validation, - args.initialization_checkpoint, - args.initialization_images, - args.initialization_kspace, - args.noise, - args.device, - args.num_workers, - args.resume, - args.machine_rank, - args.mixed_precision, - args.debug, - ) +# coding=utf-8 +# Copyright (c) DIRECT Contributors +import argparse +import functools +import logging +import os +import pathlib +import sys +import urllib.parse +from collections import defaultdict +from typing import Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from omegaconf import DictConfig + +from direct.cli.utils import check_train_val +from direct.common.subsample import build_masking_function +from direct.data.datasets import build_dataset_from_input +from direct.data.lr_scheduler import WarmupMultiStepLR +from direct.data.mri_transforms import build_mri_transforms +from direct.environment import setup_training_environment +from direct.launch import launch +from direct.types import PathOrString +from direct.utils import dict_flatten, remove_keys, set_all_seeds, str_to_class +from direct.utils.dataset import get_filenames_for_datasets_from_config +from direct.utils.io import check_is_valid_url, read_json + +logger = logging.getLogger(__name__) + + +def 
parse_noise_dict(noise_dict: dict, percentile: float = 1.0, multiplier: float = 1.0):
+    logger.info("Parsing noise dictionary...")
+    output: Dict = defaultdict(dict)
+    for filename in noise_dict:
+        data_per_volume = noise_dict[filename]
+        for slice_no in data_per_volume:
+            curr_data = data_per_volume[slice_no]
+            if percentile != 1.0:
+                lower_clip = np.percentile(curr_data, 100 * (1 - percentile))
+                upper_clip = np.percentile(curr_data, 100 * percentile)
+                curr_data = np.clip(curr_data, lower_clip, upper_clip)
+
+            output[filename][int(slice_no)] = (curr_data * multiplier) ** 2
+
+    return output
+
+
+def get_root_of_file(filename: PathOrString):
+    """Get the root directory of the file or URL to file.
+
+    Examples
+    --------
+    >>> get_root_of_file('/mnt/archive/data.txt')
+    PosixPath('/mnt/archive')
+    >>> get_root_of_file('https://aiforoncology.nl/people')
+    'https://aiforoncology.nl/'
+
+    Parameters
+    ----------
+    filename: pathlib.Path or str
+
+    Returns
+    -------
+    pathlib.Path or str
+    """
+    if check_is_valid_url(str(filename)):
+        filename = urllib.parse.urljoin(str(filename), ".")
+    else:
+        filename = pathlib.Path(filename).parents[0]
+
+    return filename
+
+
+def build_transforms_from_environment(env, dataset_config: DictConfig) -> Callable:
+    masking = dataset_config.transforms.masking  # Masking func can be None
+    mask_func = None if masking is None else build_masking_function(**masking)
+    mri_transforms_func = functools.partial(
+        build_mri_transforms,
+        forward_operator=env.engine.forward_operator,
+        backward_operator=env.engine.backward_operator,
+        mask_func=mask_func,
+    )
+    return mri_transforms_func(**dict_flatten(dict(remove_keys(dataset_config.transforms, "masking"))))  # type: ignore
+
+
+def build_training_datasets_from_environment(
+    env,
+    datasets_config: List[DictConfig],
+    lists_root: Optional[PathOrString] = None,
+    data_root: Optional[PathOrString] = None,
+    initial_images: Optional[Union[List[pathlib.Path], None]] = None,
+    initial_kspaces: Optional[Union[List[pathlib.Path], None]] = None,
+    pass_text_description: bool = True,
+    pass_dictionaries: Optional[Dict[str, Dict]] = None,
+):
+    datasets = []
+    for idx, dataset_config in enumerate(datasets_config):
+        if pass_text_description:
+            if "text_description" not in dataset_config:
+                dataset_config.text_description = f"ds{idx}" if len(datasets_config) > 1 else None
+        else:
+            dataset_config.text_description = None
+        if dataset_config.transforms.masking is None:  # type: ignore
+            logger.info(
+                "Masking function set to None for %s.",
+                dataset_config.text_description,  # type: ignore
+            )
+        transforms = build_transforms_from_environment(env, dataset_config)
+        dataset_args = {"transforms": transforms, "dataset_config": dataset_config}
+        if initial_images is not None:
+            dataset_args.update({"initial_images": initial_images})
+        if initial_kspaces is not None:
+            dataset_args.update({"initial_kspaces": initial_kspaces})
+        if data_root is not None:
+            dataset_args.update({"data_root": data_root})
+            filenames_filter = get_filenames_for_datasets_from_config(dataset_config, lists_root, data_root)
+            dataset_args.update({"filenames_filter": filenames_filter})
+        if pass_dictionaries is not None:
+            dataset_args.update({"pass_dictionaries": pass_dictionaries})
+        dataset = build_dataset_from_input(**dataset_args)
+
+        logger.debug("Transforms %s / %s :\n%s", idx + 1, len(datasets_config), transforms)
+        datasets.append(dataset)
+        logger.info(
+            "Data size for %s (%s/%s): %s.",
+            dataset_config.text_description,  # type: ignore
+            idx + 1,
+            len(datasets_config),
+            len(dataset),
+        )
+
+    return datasets
+
+
+def setup_train(
+    run_name: str,
+    training_root: Union[pathlib.Path, None],
+    validation_root: Union[pathlib.Path, None],
+    base_directory: pathlib.Path,
+    cfg_filename: PathOrString,
+    force_validation: bool,
+    initialization_checkpoint: PathOrString,
+    initial_images: Optional[Union[List[pathlib.Path], None]],
+    initial_kspace: Optional[Union[List[pathlib.Path], None]],
+    noise: Optional[Union[List[pathlib.Path], None]],
+    device: str,
+    num_workers: int,
+    resume: bool,
+    machine_rank: int,
+    mixed_precision: bool,
+    debug: bool,
+):
+    env = setup_training_environment(
+        run_name,
+        base_directory,
+        cfg_filename,
+        device,
+        machine_rank,
+        mixed_precision,
+        debug=debug,
+    )
+
+    # Trigger cudnn benchmark and remove the associated cache
+    torch.backends.cudnn.benchmark = True
+    torch.cuda.empty_cache()
+
+    if initial_kspace is not None and initial_images is not None:
+        raise ValueError("Cannot provide both initial k-space and initial images.")
+    # Create training data
+    training_dataset_args = {"env": env, "datasets_config": env.cfg.training.datasets, "pass_text_description": True}
+    pass_dictionaries = {}
+    if noise is not None:
+        if not env.cfg.physics.use_noise_matrix:
+            raise ValueError("cfg.physics.use_noise_matrix is null, yet command line passed noise files.")
+
+        noise = [read_json(fn) for fn in noise]
+        pass_dictionaries["loglikelihood_scaling"] = [
+            parse_noise_dict(_, percentile=0.999, multiplier=env.cfg.physics.noise_matrix_scaling) for _ in noise
+        ]
+        training_dataset_args.update({"pass_dictionaries": pass_dictionaries})
+
+    if training_root is not None:
+        training_dataset_args.update({"data_root": training_root})
+        # Get the lists_root. Assume now the given path is with respect to the config file.
+        lists_root = get_root_of_file(cfg_filename)
+        if lists_root is not None:
+            training_dataset_args.update({"lists_root": lists_root})
+    if initial_images is not None:
+        training_dataset_args.update({"initial_images": initial_images[0]})
+    if initial_kspace is not None:
+        training_dataset_args.update({"initial_kspaces": initial_kspace[0]})
+
+    # Build training datasets
+    training_datasets = build_training_datasets_from_environment(**training_dataset_args)
+    training_data_sizes = [len(_) for _ in training_datasets]
+    logger.info("Training data sizes: %s (sum=%s).", training_data_sizes, sum(training_data_sizes))
+
+    # Create validation data
+    if "validation" in env.cfg:
+        validation_dataset_args = {
+            "env": env,
+            "datasets_config": env.cfg.validation.datasets,
+            "pass_text_description": True,
+        }
+        if validation_root is not None:
+            validation_dataset_args.update({"data_root": validation_root})
+            lists_root = get_root_of_file(cfg_filename)
+            if lists_root is not None:
+                validation_dataset_args.update({"lists_root": lists_root})
+        if initial_images is not None:
+            validation_dataset_args.update({"initial_images": initial_images[1]})
+        if initial_kspace is not None:
+            validation_dataset_args.update({"initial_kspaces": initial_kspace[1]})
+
+        # Build validation datasets
+        validation_data = build_training_datasets_from_environment(**validation_dataset_args)
+    else:
+        logger.info("No validation data.")
+        validation_data = None
+
+    # Create the optimizers
+    logger.info("Building optimizers.")
+    optimizer_params = [{"params": env.engine.model.parameters()}]
+    for curr_model_name in env.engine.models:
+        # TODO(jt): Can get learning rate from the config per additional model too.
+        curr_learning_rate = env.cfg.training.lr
+        logger.info("Adding model parameters of %s with learning rate %s.", curr_model_name, curr_learning_rate)
+        optimizer_params.append(
+            {
+                "params": env.engine.models[curr_model_name].parameters(),
+                "lr": curr_learning_rate,
+            }
+        )
+
+    optimizer: torch.optim.Optimizer = str_to_class("torch.optim", env.cfg.training.optimizer)(  # noqa
+        optimizer_params,
+        lr=env.cfg.training.lr,
+        weight_decay=env.cfg.training.weight_decay,
+    )  # noqa
+
+    # Build the LR scheduler, we use a fixed LR schedule step size, no adaptive training schedule.
+    solver_steps = list(
+        range(
+            env.cfg.training.lr_step_size,
+            env.cfg.training.num_iterations,
+            env.cfg.training.lr_step_size,
+        )
+    )
+    lr_scheduler = WarmupMultiStepLR(
+        optimizer,
+        solver_steps,
+        env.cfg.training.lr_gamma,
+        warmup_factor=1 / 3.0,
+        warmup_iterations=env.cfg.training.lr_warmup_iter,
+        warmup_method="linear",
+    )
+
+    # Just to make sure.
+    torch.cuda.empty_cache()
+
+    # Check the initialization checkpoint
+    if env.cfg.training.model_checkpoint:
+        if initialization_checkpoint:
+            logger.warning(
+                "`--initialization-checkpoint` is set, and the config already sets `training.model_checkpoint`: %s. "
+                "Will overwrite the config variable with the command line value: %s.",
+                env.cfg.training.model_checkpoint,
+                initialization_checkpoint,
+            )
+            # Now overwrite this in the configuration, so the correct value is dumped.
+ env.cfg.training.model_checkpoint = str(initialization_checkpoint) + else: + initialization_checkpoint = env.cfg.training.model_checkpoint + + env.engine.train( + optimizer, + lr_scheduler, + training_datasets, + env.experiment_dir, + validation_datasets=validation_data, + resume=resume, + initialization=initialization_checkpoint, + start_with_validation=force_validation, + num_workers=num_workers, + ) + + +def train_from_argparse(args: argparse.Namespace): + # This sets MKL threads to 1. + # DataLoader can otherwise bring a lot of difficulties when computing CPU FFTs in the transforms. + torch.set_num_threads(1) + os.environ["OMP_NUM_THREADS"] = "1" + # Disable Tensorboard warnings. + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" + + if args.initialization_images is not None and args.initialization_kspace is not None: + sys.exit("--initialization-images and --initialization-kspace are mutually exclusive.") + check_train_val(args.initialization_images, "initialization-images") + check_train_val(args.initialization_kspace, "initialization-kspace") + check_train_val(args.noise, "noise") + + set_all_seeds(args.seed) + + run_name = args.name if args.name is not None else os.path.basename(args.cfg_file)[:-5] + + # TODO(jt): Duplicate params + launch( + setup_train, + args.num_machines, + args.num_gpus, + args.machine_rank, + args.dist_url, + run_name, + args.training_root, + args.validation_root, + args.experiment_dir, + args.cfg_file, + args.force_validation, + args.initialization_checkpoint, + args.initialization_images, + args.initialization_kspace, + args.noise, + args.device, + args.num_workers, + args.resume, + args.machine_rank, + args.mixed_precision, + args.debug, + ) diff --git a/direct/types.py b/direct/types.py index a10d0d78..16040d7c 100644 --- a/direct/types.py +++ b/direct/types.py @@ -1,163 +1,166 @@ -# Copyright (c) DIRECT Contributors - -"""direct.types module.""" - -from __future__ import annotations - -import pathlib -from enum import Enum -from typing import NewType, Union - -import numpy as np -import torch -from omegaconf.omegaconf import DictConfig -from torch import nn as nn -from torch.cuda.amp import GradScaler - -DictOrDictConfig = Union[dict, DictConfig] -Number = Union[float, int] -PathOrString = Union[pathlib.Path, str] -FileOrUrl = NewType("FileOrUrl", PathOrString) -HasStateDict = Union[nn.Module, torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler, GradScaler] -TensorOrNone = Union[None, torch.Tensor] -TensorOrNdarray = Union[torch.Tensor, np.ndarray] - - -class DirectEnum(str, Enum): - """Type of any enumerator with allowed comparison to string invariant to cases.""" - - @classmethod - def from_str(cls, value: str) -> DirectEnum | None: - statuses = cls.__members__.keys() - for st in statuses: - if st.lower() == value.lower(): - return cls[st] - return None - - def __eq__(self, other: object) -> bool: - if isinstance(other, Enum): - _other = str(other.value) - else: - _other = str(other) - return bool(self.value.lower() == _other.lower()) - - def __hash__(self) -> int: - # re-enable hashtable so it can be used as a dict key or in a set - return hash(self.value.lower()) - - -class KspaceKey(DirectEnum): - ACS_KSPACE = "acs_kspace" - KSPACE = "kspace" - MASKED_KSPACE = "masked_kspace" - REFERENCE_KSPACE = "reference_kspace" - - -class TransformKey(DirectEnum): - # K-space keys - ACS_KSPACE = "acs_kspace" - KSPACE = "kspace" - MASKED_KSPACE = "masked_kspace" - # Mask keys - SAMPLING_MASK = "sampling_mask" - ACS_MASK = "acs_mask" - PADDING = "padding" - # 
Image keys
-    TARGET = "target"
-    # Other keys
-    SENSITIVITY_MAP = "sensitivity_map"
-    SCALING_FACTOR = "scaling_factor"
-    # Registration keys
-    DISPLACEMENT_FIELD = "displacement_field"
-    REFERENCE_IMAGE = "reference_image"
-    REFERENCE_KSPACE = "reference_kspace"
-    MOVING_IMAGE = "moving_image"
-    WARPED_IMAGE = "warped_image"
-
-
-class MaskFuncMode(DirectEnum):
-    STATIC = "static"
-    DYNAMIC = "dynamic"
-    MULTISLICE = "multislice"
-
-
-class IntegerListOrTupleStringMeta(type):
-    """Metaclass for the :class:`IntegerListOrTupleString` class.
-
-    Returns
-    -------
-    bool
-        True if the instance is a valid representation of IntegerListOrTupleString, False otherwise.
-    """
-
-    def __instancecheck__(cls, instance):
-        """Check if the given instance is a valid representation of an IntegerListOrTupleString.
-
-        Parameters
-        ----------
-        cls : type
-            The class being checked, i.e., IntegerListOrTupleStringMeta.
-        instance : object
-            The instance being checked.
-
-        Returns
-        -------
-        bool
-            True if the instance is a valid representation of IntegerListOrTupleString, False otherwise.
-        """
-        if isinstance(instance, str):
-            try:
-                assert (instance.startswith("[") and instance.endswith("]")) or (
-                    instance.startswith("(") and instance.endswith(")")
-                )
-                elements = instance.strip()[1:-1].split(",")
-                integers = [int(element) for element in elements]
-                return all(isinstance(num, int) for num in integers)
-            except (AssertionError, ValueError, AttributeError):
-                pass
-        return False
-
-
-class IntegerListOrTupleString(metaclass=IntegerListOrTupleStringMeta):
-    """IntegerListOrTupleString class represents a list or tuple of integers based on a string representation.
-
-    Examples
-    --------
-    s1 = "[1, 2, 45, -1, 0]"
-    print(isinstance(s1, IntegerListOrTupleString))  # True
-    print(IntegerListOrTupleString(s1))  # [1, 2, 45, -1, 0]
-    print(type(IntegerListOrTupleString(s1)))  # <class 'list'>
-    print(type(IntegerListOrTupleString(s1)[0]))  # <class 'int'>
-
-    s2 = "(10, -9, 20)"
-    print(isinstance(s2, IntegerListOrTupleString))  # True
-    print(IntegerListOrTupleString(s2))  # (10, -9, 20)
-    print(type(IntegerListOrTupleString(s2)))  # <class 'tuple'>
-    print(type(IntegerListOrTupleString(s2)[0]))  # <class 'int'>
-
-    s3 = "[a, we, 2]"
-    print(isinstance(s3, IntegerListOrTupleString))  # False
-
-    s4 = "(1, 2, 3]"
-    print(isinstance(s4 IntegerListOrTupleString))  # False
-    """
-
-    def __new__(cls, string):
-        """
-        Create a new instance of IntegerListOrTupleString based on the given string representation.
-
-        Parameters
-        ----------
-        string : str
-            The string representation of the integer list or tuple.
-
-        Returns
-        -------
-        list or tuple
-            A new instance of IntegerListOrTupleString.
- """ - list_or_tuple = list if string.startswith("[") else tuple - string = string.strip()[1:-1] # Remove outer brackets - elements = string.split(",") - integers = [int(element) for element in elements] - return list_or_tuple(integers) +# Copyright (c) DIRECT Contributors + +"""direct.types module.""" + +from __future__ import annotations + +import pathlib +from enum import Enum +from typing import NewType, Union + +import numpy as np +import torch +from omegaconf.omegaconf import DictConfig +from torch import nn as nn +from torch.cuda.amp import GradScaler + +DictOrDictConfig = Union[dict, DictConfig] +Number = Union[float, int] +PathOrString = Union[pathlib.Path, str] +FileOrUrl = NewType("FileOrUrl", PathOrString) +HasStateDict = Union[nn.Module, torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler, GradScaler] +TensorOrNone = Union[None, torch.Tensor] +TensorOrNdarray = Union[torch.Tensor, np.ndarray] + + +class DirectEnum(str, Enum): + """Type of any enumerator with allowed comparison to string invariant to cases.""" + + @classmethod + def from_str(cls, value: str) -> DirectEnum | None: + statuses = cls.__members__.keys() + for st in statuses: + if st.lower() == value.lower(): + return cls[st] + return None + + def __eq__(self, other: object) -> bool: + if isinstance(other, Enum): + _other = str(other.value) + else: + _other = str(other) + return bool(self.value.lower() == _other.lower()) + + def __hash__(self) -> int: + # re-enable hashtable so it can be used as a dict key or in a set + return hash(self.value.lower()) + + +class KspaceKey(DirectEnum): + ACS_KSPACE = "acs_kspace" + KSPACE = "kspace" + MASKED_KSPACE = "masked_kspace" + REFERENCE_KSPACE = "reference_kspace" + + +class TransformKey(DirectEnum): + # K-space keys + ACS_KSPACE = "acs_kspace" + KSPACE = "kspace" + MASKED_KSPACE = "masked_kspace" + # Mask keys + SAMPLING_MASK = "sampling_mask" + ACS_MASK = "acs_mask" + PADDING = "padding" + # Image keys + TARGET = "target" + # Other keys + SENSITIVITY_MAP = "sensitivity_map" + SCALING_FACTOR = "scaling_factor" + # Registration keys + DISPLACEMENT_FIELD = "displacement_field" + REFERENCE_IMAGE = "reference_image" + REFERENCE_KSPACE = "reference_kspace" + MOVING_IMAGE = "moving_image" + WARPED_IMAGE = "warped_image" + # Other keys + ACCELERATION = "acceleration" + CENTER_FRACTION = "center_fraction" + + +class MaskFuncMode(DirectEnum): + STATIC = "static" + DYNAMIC = "dynamic" + MULTISLICE = "multislice" + + +class IntegerListOrTupleStringMeta(type): + """Metaclass for the :class:`IntegerListOrTupleString` class. + + Returns + ------- + bool + True if the instance is a valid representation of IntegerListOrTupleString, False otherwise. + """ + + def __instancecheck__(cls, instance): + """Check if the given instance is a valid representation of an IntegerListOrTupleString. + + Parameters + ---------- + cls : type + The class being checked, i.e., IntegerListOrTupleStringMeta. + instance : object + The instance being checked. + + Returns + ------- + bool + True if the instance is a valid representation of IntegerListOrTupleString, False otherwise. 
+ """ + if isinstance(instance, str): + try: + assert (instance.startswith("[") and instance.endswith("]")) or ( + instance.startswith("(") and instance.endswith(")") + ) + elements = instance.strip()[1:-1].split(",") + integers = [int(element) for element in elements] + return all(isinstance(num, int) for num in integers) + except (AssertionError, ValueError, AttributeError): + pass + return False + + +class IntegerListOrTupleString(metaclass=IntegerListOrTupleStringMeta): + """IntegerListOrTupleString class represents a list or tuple of integers based on a string representation. + + Examples + -------- + s1 = "[1, 2, 45, -1, 0]" + print(isinstance(s1, IntegerListOrTupleString)) # True + print(IntegerListOrTupleString(s1)) # [1, 2, 45, -1, 0] + print(type(IntegerListOrTupleString(s1))) # + print(type(IntegerListOrTupleString(s1)[0])) # + + s2 = "(10, -9, 20)" + print(isinstance(s2, IntegerListOrTupleString)) # True + print(IntegerListOrTupleString(s2)) # (10, -9, 20) + print(type(IntegerListOrTupleString(s2))) # + print(type(IntegerListOrTupleString(s2)[0])) # + + s3 = "[a, we, 2]" + print(isinstance(s3, IntegerListOrTupleString)) # False + + s4 = "(1, 2, 3]" + print(isinstance(s4 IntegerListOrTupleString)) # False + """ + + def __new__(cls, string): + """ + Create a new instance of IntegerListOrTupleString based on the given string representation. + + Parameters + ---------- + string : str + The string representation of the integer list or tuple. + + Returns + ------- + list or tuple + A new instance of IntegerListOrTupleString. + """ + list_or_tuple = list if string.startswith("[") else tuple + string = string.strip()[1:-1] # Remove outer brackets + elements = string.split(",") + integers = [int(element) for element in elements] + return list_or_tuple(integers) diff --git a/direct/utils/writers.py b/direct/utils/writers.py index d5720c10..d1fcd42f 100644 --- a/direct/utils/writers.py +++ b/direct/utils/writers.py @@ -48,13 +48,22 @@ def write_output_to_h5( with open(output_directory / "metrics_inference.json", "w") as f: f.write(json.dumps(metrics, indent=4)) - for idx, (volume, sampling_mask, filename) in enumerate(output[0]): + for idx, (data, sampling_mask, filename) in enumerate(output[0]): if isinstance(filename, pathlib.PosixPath): filename = filename.name logger.info(f"({idx + 1}/{len(output[0])}): Writing {output_directory / filename}...") + if isinstance(data, tuple): + volume, registration_volume, displacement_field = data + else: + volume = data + registration_volume = None + reconstruction = volume.numpy()[:, 0, ...].astype(np.float32) + if registration_volume is not None: + registration_volume = registration_volume.numpy()[:, 0, ...].astype(np.float32) + displacement_field = displacement_field.numpy().astype(np.float32) if sampling_mask is not None: sampling_mask = sampling_mask.numpy()[:, 0, ...].astype(np.float32) @@ -66,3 +75,6 @@ def write_output_to_h5( f.create_dataset(output_key, data=reconstruction) if sampling_mask is not None: f.create_dataset("sampling_mask", data=sampling_mask) + if registration_volume is not None: + f.create_dataset("registration_volume", data=registration_volume) + f.create_dataset("displacement_field", data=displacement_field) \ No newline at end of file