From 520edd4fa5e5a6cee6dc161e3b375152d02ac2c1 Mon Sep 17 00:00:00 2001
From: George Yiasemis
Date: Tue, 18 Oct 2022 18:58:34 +0300
Subject: [PATCH] New loss functions, refactored engines to only implement
 forward_function, quality fixes (#226)

(Closes #225)

* New loss functions (`NMSE`, `NRMSE`, `NMAE`, `SobelGradL1Loss`, `SobelGradL2Loss`)
* `mri_models` performs the `_do_iteration` method; child engines perform `forward_function`, which returns output_image and/or output_kspace
* Changes/additions in `mri_transforms`:
  * Padding computed as a Tensor with the `ComputeZeroPadding` transform (this is helpful when cropping the image and transforming back to k-space)
  * `ApplyZeroPadding` transform
  * `ComputeImage` transform with a choice of modulus output or not
  * `RenameKeys` transform
  * `Normalize` split into `ComputeScalingFactor` and `Normalize`
* Some quality changes
* Some documentation changes
---
 direct/common/subsample.py                    |  17 +-
 direct/data/datasets.py                       |   3 +
 direct/data/datasets_config.py                |  19 +-
 direct/data/h5_data.py                        |   2 +
 direct/data/mri_transforms.py                 | 412 ++++++++++++----
 direct/data/transforms.py                     |  26 +-
 direct/engine.py                              |  15 +-
 direct/exceptions.py                          |  10 +
 direct/functionals/__init__.py                |   4 +
 direct/functionals/grad.py                    | 203 ++++++++
 direct/functionals/nmae.py                    |  43 ++
 direct/functionals/nmse.py                    |  80 ++++
 direct/nn/cirim/cirim.py                      |  12 +-
 direct/nn/cirim/cirim_engine.py               |  26 +-
 direct/nn/jointicnet/jointicnet_engine.py     |  83 +---
 direct/nn/kikinet/kikinet_engine.py           |  87 +---
 direct/nn/lpd/lpd_engine.py                   |  82 +---
 direct/nn/mri_models.py                       | 447 ++++++++++++++----
 .../multidomainnet/multidomainnet_engine.py   |  82 +---
 .../recurrentvarnet/recurrentvarnet_engine.py |  89 +---
 direct/nn/rim/config.py                       |  19 -
 direct/nn/rim/rim_engine.py                   |  83 ++--
 direct/nn/unet/config.py                      |   9 +
 direct/nn/unet/unet_2d.py                     |   2 +-
 direct/nn/unet/unet_engine.py                 |  85 +---
 direct/nn/varnet/varnet_engine.py             |  86 +---
 direct/nn/xpdnet/xpdnet_engine.py             |  85 +---
 direct/types.py                               |   1 +
 setup.py                                      |   2 +-
 tests/tests_common/test_subsample.py          |  12 +-
 tests/tests_data/test_mri_transforms.py       | 111 +++--
 tests/tests_data/test_transforms.py           |  15 +
 tests/tests_functionals/test_gradloss.py      |  33 ++
 tests/tests_functionals/test_nmae.py          |  39 ++
 tests/tests_functionals/test_nmse.py          |  12 +-
 tests/tests_nn/test_jointicnet_engine.py      |   2 +-
 tests/tests_nn/test_recurrentvarnet_engine.py |  34 +-
 37 files changed, 1446 insertions(+), 926 deletions(-)
 create mode 100644 direct/functionals/grad.py
 create mode 100644 direct/functionals/nmae.py
 create mode 100644 direct/functionals/nmse.py
 create mode 100644 tests/tests_functionals/test_gradloss.py
 create mode 100644 tests/tests_functionals/test_nmae.py

diff --git a/direct/common/subsample.py b/direct/common/subsample.py
index 9b5988b0..eabc93eb 100644
--- a/direct/common/subsample.py
+++ b/direct/common/subsample.py
@@ -135,7 +135,7 @@ def __init__(
         )
 
     @staticmethod
-    def center_mask_func(num_cols, num_low_freqs):
+    def center_mask_func(num_cols: int, num_low_freqs: int) -> np.ndarray:
 
         # create the mask
         mask = np.zeros(num_cols, dtype=bool)
@@ -415,7 +415,7 @@ def mask_func(
             mask_negative = np.flip(mask_negative)
 
         mask = np.fft.fftshift(np.concatenate((mask_positive, mask_negative)))
-        mask = mask | acs_mask
+        mask = np.logical_or(mask, acs_mask)
 
         return torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask))
 
@@ -781,11 +781,12 @@ class VariableDensityPoissonMaskFunc(BaseMaskFunc):
     def __init__(
         self,
         accelerations: Union[List[Number], Tuple[Number, ...]],
-        center_scales: Union[List[float], Tuple[float, ...]],
+        center_fractions: 
Union[List[float], Tuple[float, ...]], crop_corner: Optional[bool] = False, max_attempts: Optional[int] = 10, tol: Optional[float] = 0.2, slopes: Optional[Union[List[float], Tuple[float, ...]]] = None, + **kwargs, ): """Inits :class:`VariableDensityPoissonMaskFunc`. @@ -793,8 +794,8 @@ def __init__( ---------- accelerations: list or tuple of positive numbers Amount of under-sampling. - center_scales: list or tuple of floats - Must have the same lenght as `accelerations`. Amount of center fully-sampling. + center_fractions: list or tuple of floats + Must have the same length as `accelerations`. Amount of center fully-sampling. For center_scale='r', then a centered disk area with radius equal to :math:`R = \sqrt{{n_r}^2 + {n_c}^2} \times r` will be fully sampled, where :math:`n_r` and :math:`n_c` denote the input shape. @@ -810,7 +811,7 @@ def __init__( """ super().__init__( accelerations=accelerations, - center_fractions=center_scales, + center_fractions=center_fractions, uniform_range=False, ) self.crop_corner = crop_corner @@ -864,9 +865,9 @@ def mask_func( if return_acs: return torch.from_numpy( self.centered_disk_mask((num_rows, num_cols), center_fraction)[np.newaxis, ..., np.newaxis] - ) + ).bool() mask = self.poisson(num_rows, num_cols, center_fraction, acceleration, cython_seed) - return torch.from_numpy(mask[np.newaxis, ..., np.newaxis]) + return torch.from_numpy(mask[np.newaxis, ..., np.newaxis]).bool() def poisson( self, diff --git a/direct/data/datasets.py b/direct/data/datasets.py index 23f28732..9d2e01e8 100644 --- a/direct/data/datasets.py +++ b/direct/data/datasets.py @@ -451,6 +451,9 @@ def __getitem__(self, idx: int) -> Dict[str, Any]: num_z = kspace.shape[1] kspace[:, int(np.ceil(num_z * self.sampling_rate_slice_encode)) :, :] = 0.0 + 0.0 * 1j + sample["padding_left"] = 0 + sample["padding_right"] = np.all(np.abs(kspace).sum(-1) == 0, axis=0).nonzero()[0][0] + # Downstream code expects the coils to be at the first axis. 
sample["kspace"] = np.ascontiguousarray(kspace.transpose(2, 0, 1)) diff --git a/direct/data/datasets_config.py b/direct/data/datasets_config.py index d39f2720..f780df9b 100644 --- a/direct/data/datasets_config.py +++ b/direct/data/datasets_config.py @@ -3,7 +3,7 @@ """Classes holding the typed configurations for the datasets.""" -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import List, Optional, Tuple from omegaconf import MISSING @@ -14,15 +14,20 @@ @dataclass class TransformsConfig(BaseConfig): - crop: Optional[Tuple[int, int]] = field(default_factory=lambda: (320, 320)) - crop_type: str = "uniform" - estimate_sensitivity_maps: bool = False + masking: MaskingConfig = MaskingConfig() + crop: Optional[Tuple[int, int]] = None + crop_type: Optional[str] = "uniform" + image_center_crop: bool = False + padding_eps: float = 0.001 + estimate_sensitivity_maps: bool = True estimate_body_coil_image: bool = False sensitivity_maps_gaussian: Optional[float] = 0.7 - image_center_crop: bool = True + delete_acs_mask: bool = True + delete_kspace: bool = True + image_recon_type: str = "rss" pad_coils: Optional[int] = None - scaling_key: Optional[str] = None - masking: MaskingConfig = MaskingConfig() + scaling_key: Optional[str] = "masked_kspace" + use_seed: bool = True @dataclass diff --git a/direct/data/h5_data.py b/direct/data/h5_data.py index 4b667cab..91306a7e 100644 --- a/direct/data/h5_data.py +++ b/direct/data/h5_data.py @@ -114,6 +114,8 @@ def __init__( self.logger.info("Attempting to load %s filenames.", len(filenames_filter)) filenames = filenames_filter + filenames = [pathlib.Path(_) for _ in filenames] + if regex_filter: filenames = [_ for _ in filenames if re.match(regex_filter, str(_))] diff --git a/direct/data/mri_transforms.py b/direct/data/mri_transforms.py index 2e9f3ae9..5b5da3a9 100644 --- a/direct/data/mri_transforms.py +++ b/direct/data/mri_transforms.py @@ -1,16 +1,17 @@ # coding=utf-8 # Copyright (c) DIRECT Contributors - import functools import logging import warnings +from enum import Enum from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union import numpy as np import torch from direct.data import transforms as T +from direct.exceptions import ItemNotFoundException from direct.utils import DirectModule, DirectTransform from direct.utils.asserts import assert_complex @@ -133,19 +134,8 @@ def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]: sampling_mask = self.mask_func(shape=shape, seed=seed, return_acs=False) - if sample.get("padding_left", 0) > 0 or sample.get("padding_right", 0) > 0: - - if sample["kspace"].shape[2] != shape[-2]: - raise ValueError( - "Currently only support for the `width` axis to be at the 2th position when padding. " - + "When padding in left or right is present, you cannot crop in the phase-encoding direction!" - ) - - padding_left = sample["padding_left"] - padding_right = sample["padding_right"] - - sampling_mask[:, :, :padding_left, :] = 0 - sampling_mask[:, :, padding_right:, :] = 0 + if "padding" in sample: + sampling_mask = T.apply_padding(sampling_mask, sample["padding"]) # Shape (1, [slice], height, width, 1) sample["sampling_mask"] = sampling_mask @@ -160,37 +150,60 @@ def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]: class ApplyMask(DirectModule): """Data Transformer for training MRI reconstruction models. - Masks the k-space using a sampling mask. 
+    Masks the input k-space (with key `input_kspace_key`) using a sampling mask with key `sampling_mask_key` into
+    a new masked k-space with key `target_kspace_key`.
     """
 
-    def __init__(self) -> None:
-        """Inits :class:`ApplyMask`."""
+    def __init__(
+        self,
+        sampling_mask_key: str = "sampling_mask",
+        input_kspace_key: str = "kspace",
+        target_kspace_key: str = "masked_kspace",
+    ) -> None:
+        """Inits :class:`ApplyMask`.
+
+        Parameters
+        ----------
+        sampling_mask_key: str
+            Default: "sampling_mask".
+        input_kspace_key: str
+            Default: "kspace".
+        target_kspace_key: str
+            Default: "masked_kspace".
+        """
         super().__init__()
         self.logger = logging.getLogger(type(self).__name__)
 
+        self.sampling_mask_key = sampling_mask_key
+        self.input_kspace_key = input_kspace_key
+        self.target_kspace_key = target_kspace_key
+
     def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]:
         """Calls :class:`ApplyMask`.
 
-        This assumes that a `sampling_mask` is present in the sample.
+        Applies the mask with key `sampling_mask_key` onto the k-space with key `input_kspace_key`. The result is
+        stored as a tensor with key `target_kspace_key`.
 
         Parameters
         ----------
         sample: Dict[str, Any]
-            Dict sample containing key `kspace`.
+            Dict sample containing keys `sampling_mask_key` and `input_kspace_key`.
 
         Returns
         -------
         Dict[str, Any]
-            Sample with new key `masked_kspace`.
+            Sample with (new) key `target_kspace_key`.
         """
-        kspace = sample["kspace"]
+        if self.input_kspace_key not in sample:
+            raise ValueError(f"Key {self.input_kspace_key} corresponding to `input_kspace_key` not found in sample.")
+        input_kspace = sample[self.input_kspace_key]
 
-        assert "sampling_mask" in sample, "Key 'sampling_mask' not found in sample."
-        sampling_mask = sample["sampling_mask"]
-
-        masked_kspace, _ = T.apply_mask(kspace, sampling_mask)
-        sample["masked_kspace"] = masked_kspace
+        if self.sampling_mask_key not in sample:
+            raise ValueError(f"Key {self.sampling_mask_key} corresponding to `sampling_mask_key` not found in sample.")
+        sampling_mask = sample[self.sampling_mask_key]
 
+        target_kspace, _ = T.apply_mask(input_kspace, sampling_mask)
+        sample[self.target_kspace_key] = target_kspace
         return sample
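A minimal usage sketch of the reworked `ApplyMask` transform (illustrative only, not part of the diff; it assumes the patch has been applied and uses made-up tensor shapes):

```python
import torch

from direct.data.mri_transforms import ApplyMask  # available once this patch is applied

sample = {
    "kspace": torch.randn(6, 320, 320, 2),  # assumed shape (coil, height, width, complex=2)
    "sampling_mask": torch.randint(0, 2, (1, 320, 320, 1)).bool(),
}
transform = ApplyMask(
    sampling_mask_key="sampling_mask",
    input_kspace_key="kspace",
    target_kspace_key="masked_kspace",
)
sample = transform(sample)
print(sample["masked_kspace"].shape)  # torch.Size([6, 320, 320, 2])
```

The configurable keys are the point of the refactor: the same transform can now mask any k-space entry in the sample, not just the hard-coded `"kspace"`.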
@@ -298,6 +311,109 @@ def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]:
         return sample
 
 
+class ComputeZeroPadding(DirectModule):
+    r"""Computes zero padding present in multi-coil kspace input.
+
+    Zero-padding is computed from multi-coil kspace with no signal contribution, i.e. its magnitude
+    is very close to zero:
+
+    .. math ::
+
+        \text{padding} = \sum_{i=1}^{n_c} |y_i| < \frac{1}{n_x \cdot n_y}
+        \sum_{j=1}^{n_x \cdot n_y} \big\{\sum_{i=1}^{n_c} |y_i|\big\} * \epsilon.
+    """
+
+    def __init__(self, kspace_key: str = "kspace", padding_key: str = "padding", eps: float = 0.0001) -> None:
+        """Inits :class:`ComputeZeroPadding`.
+
+        Parameters
+        ----------
+        kspace_key: str
+            K-space key. Default: "kspace".
+        padding_key: str
+            Target key. Default: "padding".
+        eps: float
+            Epsilon to multiply the sum of signals with. If set too high, probably no padding will be produced.
+            Default: 0.0001.
+        """
+        super().__init__()
+        self.kspace_key = kspace_key
+        self.padding_key = padding_key
+        self.eps = eps
+
+    def __call__(self, sample: Dict[str, Any], coil_dim: int = 0) -> Dict[str, Any]:
+        """Updates sample with a key `padding_key` with value a binary tensor.
+
+        Non-zero entries indicate samples in kspace with key `kspace_key` which have minor contribution, i.e. padding.
+
+        Parameters
+        ----------
+        sample : Dict[str, Any]
+            Dict sample containing key `kspace_key`.
+        coil_dim : int
+            Coil dimension. Default: 0.
+
+        Returns
+        -------
+        sample : Dict[str, Any]
+            Dict sample containing key `padding_key`.
+        """
+
+        kspace = T.modulus(sample[self.kspace_key]).sum(coil_dim)
+        padding = (kspace < torch.mean(kspace) * self.eps).to(kspace.device).unsqueeze(coil_dim).unsqueeze(-1)
+
+        sample[self.padding_key] = padding
+
+        return sample
+
+
+class ApplyZeroPadding(DirectModule):
+    """Applies zero padding present in multi-coil kspace input."""
+
+    def __init__(self, kspace_key: str = "kspace", padding_key: str = "padding") -> None:
+        """Inits :class:`ApplyZeroPadding`.
+
+        Parameters
+        ----------
+        kspace_key: str
+            K-space key. Default: "kspace".
+        padding_key: str
+            Target key. Default: "padding".
+        """
+        super().__init__()
+        self.kspace_key = kspace_key
+        self.padding_key = padding_key
+
+    def __call__(self, sample: Dict[str, Any], coil_dim: int = 0) -> Dict[str, Any]:
+        """Applies zero padding on `kspace_key` with value a binary tensor.
+
+        Parameters
+        ----------
+        sample : Dict[str, Any]
+            Dict sample containing key `kspace_key`.
+        coil_dim : int
+            Coil dimension. Default: 0.
+
+        Returns
+        -------
+        sample : Dict[str, Any]
+            Dict sample containing zero-padded `kspace_key`.
+        """
+
+        sample[self.kspace_key] = T.apply_padding(sample[self.kspace_key], sample[self.padding_key])
+
+        return sample
+
+
+class ReconstructionType(str, Enum):
+    """Reconstruction method for :class:`ComputeImage` transform."""
+
+    rss = "rss"
+    complex = "complex"
+    complex_mod = "complex_mod"
+    sense = "sense"
+    sense_mod = "sense_mod"
+
+
 class ComputeImage(DirectModule):
     """Compute Image transform.
 
@@ -305,7 +421,11 @@ class ComputeImage(DirectModule):
     """
 
     def __init__(
-        self, kspace_key: str, target_key: str, backward_operator: Callable, type_reconstruction: str = "complex"
+        self,
+        kspace_key: str,
+        target_key: str,
+        backward_operator: Callable,
+        type_reconstruction: ReconstructionType = ReconstructionType.rss,
     ) -> None:
         """Inits :class:`ComputeImage`.
 
@@ -318,23 +438,16 @@ def __init__(
         backward_operator: callable
            The backward operator, e.g. some form of inverse FFT (centered or uncentered).
         type_reconstruction: str
-            Type of reconstruction. Can be 'complex', 'sense' or 'rss'. Default: 'complex'.
+            Type of reconstruction. Can be "complex", "complex_mod", "sense", "sense_mod" or "rss". Default: "rss".
         """
         super().__init__()
         self.backward_operator = backward_operator
         self.kspace_key = kspace_key
         self.target_key = target_key
         self.type_reconstruction = type_reconstruction
 
-        if type_reconstruction.lower() not in ["complex", "sense", "rss"]:
-            raise ValueError(
-                f"Only `complex`, `rss` and `sense` are possible choices for `reconstruction_type`. "
-                f"Got {self.type_reconstruction}."
-            )
-
     def __call__(
-        self, sample: Dict[str, Any], coil_dim: int = 0, spatial_dims: Tuple[int, int] = (1, 2)
+        self, sample: Dict[str, Any], coil_dim: int = 0, spatial_dims: Tuple[int, int] = (1, 2), complex_dim: int = -1
     ) -> Dict[str, Any]:
         """Calls :class:`ComputeImage`.
 
@@ -346,27 +459,31 @@ def __call__(
             Coil dimension. Default: 0.
         spatial_dims: (int, int)
             Spatial dimensions corresponding to (height, width). Default: (1, 2).
+        complex_dim: int
+            Complex dimension. Used if `type_reconstruction` is either "complex_mod" or "sense_mod". Default: -1.
 
         Returns
         ----------
         sample: dict
-            Contains key target_key with value a torch.Tensor of shape (*spatial_dims) or (*spatial_dims) if
-            type_reconstruction is 'rss'. 
+ Contains key target_key with value a torch.Tensor of shape (*spatial_dims) if `type_reconstruction` is + "rss", "complex_mod" or "sense_mod", and of shape(*spatial_dims, complex_dim=2) otherwise. """ kspace_data = sample[self.kspace_key] # Get complex-valued data solution image = self.backward_operator(kspace_data, dim=spatial_dims) - if self.type_reconstruction == "complex": + if self.type_reconstruction in [ReconstructionType.complex, ReconstructionType.complex_mod]: sample[self.target_key] = image.sum(coil_dim) - elif self.type_reconstruction.lower() == "rss": + elif self.type_reconstruction == ReconstructionType.rss: sample[self.target_key] = T.root_sum_of_squares(image, dim=coil_dim) - elif self.type_reconstruction == "sense": + else: if "sensitivity_map" not in sample: - raise ValueError("Sensitivity map is required for SENSE reconstruction.") + raise ItemNotFoundException("sensitivity map", "Sensitivity map is required for SENSE reconstruction.") sample[self.target_key] = T.complex_multiplication(T.conjugate(sample["sensitivity_map"]), image).sum( coil_dim ) + if self.type_reconstruction in [ReconstructionType.complex_mod, ReconstructionType.sense_mod]: + sample[self.target_key] = T.modulus(sample[self.target_key], complex_dim) return sample @@ -570,6 +687,43 @@ def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]: return sample +class RenameKeys(DirectModule): + """Rename keys from the sample if present.""" + + def __init__(self, old_keys: List[str], new_keys: List[str]): + """Inits :class:`RenameKeys`. + + Parameters + ---------- + old_keys: List[str] + Key(s) to rename. + new_keys: List[str] + Key(s) to replace old keys. + """ + super().__init__() + self.old_keys = old_keys + self.new_keys = new_keys + + def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Calls :class:`RenameKeys`. + + Parameters + ---------- + sample: Dict[str, Any] + Dictionary to look for keys and rename them. + + Returns + ------- + Dict[str, Any] + Dictionary with renamed specified keys. + """ + for old_key, new_key in zip(self.old_keys, self.new_keys): + if old_key in sample: + sample[new_key] = sample.pop(old_key) + + return sample + + class PadCoilDimension(DirectModule): """Pad the coils by zeros to a given number of coils. @@ -633,45 +787,47 @@ def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]: return sample -class Normalize(DirectModule): - """Normalize the input data either to the percentile or to the maximum.""" +class ComputeScalingFactor(DirectModule): + """Calculates scaling factor. - def __init__(self, normalize_key: str = "masked_kspace", percentile: Union[None, float] = 0.99): - """Inits :class:`Normalize`. + Scaling factor is for the input data based on either to the percentile or to the maximum of `normalize_key`. + """ + + def __init__( + self, + normalize_key: Union[None, str] = "masked_kspace", + percentile: Union[None, float] = 0.99, + scaling_factor_key: str = "scaling_factor", + ): + """Inits :class:`ComputeScalingFactor`. Parameters ---------- - normalize_key: str + normalize_key : str or None Key name to compute the data for. If the maximum has to be computed on the ACS, ensure the reconstruction on the ACS is available (typically `body_coil_image`). Default: "masked_kspace". - percentile: float or None + percentile : float or None Rescale data with the given percentile. If None, the division is done by the maximum. Default: 0.99. + scaling_factor_key : str + Name of how the scaling factor will be stored. Default: 'scaling_factor'. 
""" super().__init__() self.normalize_key = normalize_key self.percentile = percentile - - self._other_keys = [ - "masked_kspace", - "target", - "kspace", - "body_coil_image", # sensitivity_map does not require normalization. - "initial_image", - "initial_kspace", - ] + self.scaling_factor_key = scaling_factor_key def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Calls :class:`Normalize`. + """Calls :class:`ComputeScalingFactor`. Parameters ---------- sample: Dict[str, Any] - Sample to normalize with key "masked_kspace". + Sample with key `normalize_key` to compute scaling_factor. Returns ------- sample: Dict[str, Any] - Sample with normalized values if their respective key is not in `self._other_keys`. + Sample with key `scaling_factor_key`. """ if self.normalize_key == "scaling_factor": # This is a real-valued given number scaling_factor = sample["scaling_factor"] @@ -679,7 +835,6 @@ def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]: scaling_factor = 1.0 else: data = sample[self.normalize_key] - # Compute the maximum and scale the input if self.percentile: tview = -1.0 * T.modulus(data).view(-1) @@ -688,15 +843,60 @@ def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]: else: scaling_factor = T.modulus(data).max() + sample[self.scaling_factor_key] = scaling_factor + return sample + + +class Normalize(DirectModule): + """Normalize the input data.""" + + def __init__(self, scaling_factor_key: str = "scaling_factor", keys_to_normalize: Optional[List[str]] = None): + """Inits :class:`Normalize`. + + Parameters + ---------- + scaling_factor_key : str + Name of scaling factor key expected in sample. Default: 'scaling_factor'. + """ + super().__init__() + self.scaling_factor_key = scaling_factor_key + + self.keys_to_normalize = ( + [ + "masked_kspace", + "target", + "kspace", + "body_coil_image", # sensitivity_map does not require normalization. + "initial_image", + "initial_kspace", + ] + if keys_to_normalize is None + else keys_to_normalize + ) + + def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Calls :class:`Normalize`. + + Parameters + ---------- + sample: Dict[str, Any] + Sample to normalize. + + Returns + ------- + sample: Dict[str, Any] + Sample with normalized values if their respective key is in `keys_to_normalize` and key + `scaling_factor_key` exists in sample. 
+ """ + scaling_factor = sample.get(self.scaling_factor_key, None) # Normalize data - if self.normalize_key: + if scaling_factor: for key in sample.keys(): - if key != self.normalize_key and key not in self._other_keys: + if key not in self.keys_to_normalize: continue sample[key] = sample[key] / scaling_factor - sample["scaling_diff"] = 0.0 - sample["scaling_factor"] = scaling_factor + sample["scaling_diff"] = 0.0 return sample @@ -829,14 +1029,18 @@ def build_mri_transforms( forward_operator: Callable, backward_operator: Callable, mask_func: Optional[Callable], - crop: Optional[Tuple[int]] = None, + crop: Optional[Union[Tuple[int, int], str]] = None, crop_type: Optional[str] = "uniform", - image_center_crop: bool = False, + image_center_crop: bool = True, + padding_eps: float = 0.0001, estimate_sensitivity_maps: bool = True, estimate_body_coil_image: bool = False, sensitivity_maps_gaussian: Optional[float] = None, + delete_acs_mask: bool = True, + delete_kspace: bool = True, + image_recon_type: str = "rss", pad_coils: Optional[int] = None, - scaling_key: str = "scaling_factor", + scaling_key: str = "masked_kspace", use_seed: bool = True, ) -> object: """Build transforms for MRI. @@ -851,31 +1055,48 @@ def build_mri_transforms( Parameters ---------- - forward_operator: Callable + forward_operator : Callable The forward operator, e.g. some form of FFT (centered or uncentered). - backward_operator: Callable + backward_operator : Callable The backward operator, e.g. some form of inverse FFT (centered or uncentered). - mask_func: Callable or None + mask_func : Callable or None A function which creates a sampling mask of the appropriate shape. - crop: Tuple[int] or None - crop_type: Optional[str] - Type of cropping, either "gaussian" or "uniform". Default: "uniform". - image_center_crop: bool - estimate_sensitivity_maps: bool - estimate_body_coil_image: bool - sensitivity_maps_gaussian: float + crop : Tuple[int, int] or str, Optional + If not None, this will transform the "kspace" to an image domain, crop it, and transform it back. + If a tuple of integers is given then it will crop the backprojected kspace to that size. If + "reconstruction_size" is given, then it will crop the backprojected kspace according to it, but + a key "reconstruction_size" must be present in the sample. Default: None. + crop_type : Optional[str] + Type of cropping, either "gaussian" or "uniform". This will be ignored if `crop` is None. Default: "uniform". + image_center_crop : bool + If True the backprojected kspace will be cropped around the center, otherwise randomly. + This will be ignored if `crop` is None. Default: True. + padding_eps: float + Padding epsilon. Default: 0.0001. + estimate_sensitivity_maps : bool + Estimate sensitivity maps using the acs region. Default: True. + estimate_body_coil_image : bool + Estimate body coil image. Default: False. + sensitivity_maps_gaussian : float Optional sigma for gaussian weighting of sensitivity map. - pad_coils: int + delete_acs_mask : bool + If True will delete key `acs_mask`. Default: True. + delete_kspace : bool + If True will delete key `kspace` (fully sampled k-space). Default: True. + image_recon_type : str + Type to reconstruct target image. Default: "rss". + pad_coils : int Number of coils to pad data to. - scaling_key: str - use_seed: bool + scaling_key : str + Key in sample to scale scalable items in sample. Default: "masked_kspace". 
+ use_seed : bool If true, a pseudo-random number based on the filename is computed so that every slice of the volume get the same mask every time. Default: True. Returns ------- object: Callable - A transformation object. + An MRI transformation object. """ # TODO: Use seed @@ -893,12 +1114,14 @@ def build_mri_transforms( ] if mask_func: mri_transforms += [ + ComputeZeroPadding("kspace", "padding", padding_eps), + ApplyZeroPadding("kspace", "padding"), CreateSamplingMask( mask_func, shape=(None if (isinstance(crop, str)) else crop), use_seed=use_seed, return_acs=estimate_sensitivity_maps, - ) + ), ] mri_transforms += [ @@ -907,27 +1130,32 @@ def build_mri_transforms( backward_operator=backward_operator, type_of_map="unit" if not estimate_sensitivity_maps else "rss_estimate", gaussian_sigma=sensitivity_maps_gaussian, - ), - DeleteKeys(keys=["acs_mask"]), + ) + ] + + if delete_acs_mask: + mri_transforms += [DeleteKeys(keys=["acs_mask"])] + + mri_transforms += [ ComputeImage( kspace_key="kspace", target_key="target", backward_operator=backward_operator, - type_reconstruction="rss", + type_reconstruction=image_recon_type, ), - ApplyMask(), + ApplyMask(sampling_mask_key="sampling_mask", input_kspace_key="kspace", target_kspace_key="masked_kspace"), ] if estimate_body_coil_image and mask_func is not None: mri_transforms.append(EstimateBodyCoilImage(mask_func, backward_operator=backward_operator, use_seed=use_seed)) mri_transforms += [ - Normalize( - normalize_key=scaling_key, - percentile=0.99, - ), + ComputeScalingFactor(normalize_key=scaling_key, percentile=0.99, scaling_factor_key="scaling_factor"), + Normalize(), PadCoilDimension(pad_coils=pad_coils, key="masked_kspace"), PadCoilDimension(pad_coils=pad_coils, key="sensitivity_map"), - DeleteKeys(keys=["kspace"]), ] + if delete_kspace: + mri_transforms += [DeleteKeys(keys=["kspace"])] + return Compose(mri_transforms) diff --git a/direct/data/transforms.py b/direct/data/transforms.py index 57d5cef4..6f9ea586 100644 --- a/direct/data/transforms.py +++ b/direct/data/transforms.py @@ -489,6 +489,30 @@ def conjugate(data: torch.Tensor) -> torch.Tensor: return data +def apply_padding( + data: torch.Tensor, + padding: Union[None, torch.Tensor], +) -> torch.Tensor: + """Applies zero padding to `data`. + + Parameters + ---------- + data : torch.Tensor + Batched or not input to be padded of shape (`batch`, *, `height`, `width`, *). + padding : torch.Tensor or None + Binary tensor of shape (`batch`, 1, `height`, `width`, 1). Entries in `padding` with non-zero value + point to samples in `data` that will be zero-padded. If None, `data` will be returned. + + Returns + ------- + data : torch.Tensor + Padded data. 
+ """ + if padding is None: + return data + return torch.where(padding == 1, torch.tensor([0.0], dtype=data.dtype, device=data.device), data) + + def apply_mask( kspace: torch.Tensor, mask_func: Union[Callable, torch.Tensor], @@ -523,7 +547,7 @@ def apply_mask( else: mask = mask_func - masked_kspace = torch.where(mask == 0, torch.tensor([0.0], dtype=kspace.dtype), kspace) + masked_kspace = torch.where(mask == 0, torch.tensor([0.0], dtype=kspace.dtype, device=kspace.device), kspace) if not return_mask: return masked_kspace diff --git a/direct/engine.py b/direct/engine.py index 46e1ee9a..68c5610c 100644 --- a/direct/engine.py +++ b/direct/engine.py @@ -37,7 +37,6 @@ from direct.types import PathOrString from direct.utils import ( communication, - evaluate_dict, normalize_image, prefix_dict_keys, reduce_list_of_dicts, @@ -168,7 +167,7 @@ def predict( num_workers: int = 6, batch_size: int = 1, crop: Optional[str] = None, - ) -> np.ndarray: + ) -> List[np.ndarray]: self.logger.info("Predicting...") torch.cuda.empty_cache() self.ndim = dataset.ndim # type: ignore @@ -289,7 +288,7 @@ def training_loop( validation_func = functools.partial( self.validation_loop, validation_datasets, - loss_fns, + None, experiment_directory, num_workers=num_workers, ) @@ -305,7 +304,6 @@ def training_loop( validation_func(iter_idx) try: iteration_output = self._do_iteration(data, loss_fns, regularizer_fns=regularizer_fns) - output = iteration_output.output_image loss_dict = iteration_output.data_dict except (ProcessKilledException, TrainingException) as e: # If the process is killed, the DoIterationOutput @@ -373,14 +371,7 @@ def training_loop( loss_dict_reduced = communication.reduce_tensor_dict(loss_dict) loss_reduced = sum(loss_dict_reduced.values()) - metrics_dict = evaluate_dict( - metric_fns, - T.modulus_if_complex(output.detach()), - data["target"].detach().to(self.device), - reduction="mean", - ) - metrics_dict_reduced = communication.reduce_tensor_dict(metrics_dict) if metrics_dict else {} - storage.add_scalars(loss=loss_reduced, **loss_dict_reduced, **metrics_dict_reduced) + storage.add_scalars(loss=loss_reduced, **loss_dict_reduced) # Maybe not needed. 
         del data
 
diff --git a/direct/exceptions.py b/direct/exceptions.py
index b55dc871..02a60d88 100644
--- a/direct/exceptions.py
+++ b/direct/exceptions.py
@@ -34,3 +34,13 @@ def __init__(self, message=None):
             self.logger.exception("TrainingException")
         else:
             self.logger.exception(f"TrainingException: {message}")
+
+
+class ItemNotFoundException(DirectException):
+    def __init__(self, item_name, message=None):
+        super().__init__()
+        error_name = "".join([s.capitalize() for s in item_name.split(" ")]) + "Exception"
+        if message:
+            self.logger.exception("%s: %s", error_name, message)
+        else:
+            self.logger.exception(error_name)
diff --git a/direct/functionals/__init__.py b/direct/functionals/__init__.py
index 2b6c845a..dcd13597 100644
--- a/direct/functionals/__init__.py
+++ b/direct/functionals/__init__.py
@@ -1,5 +1,9 @@
 # coding=utf-8
 # Copyright (c) DIRECT Contributors
+
 from direct.functionals.challenges import *
+from direct.functionals.grad import *
+from direct.functionals.nmae import NMAELoss
+from direct.functionals.nmse import *
 from direct.functionals.psnr import *
 from direct.functionals.ssim import *
diff --git a/direct/functionals/grad.py b/direct/functionals/grad.py
new file mode 100644
index 00000000..cf871ea9
--- /dev/null
+++ b/direct/functionals/grad.py
@@ -0,0 +1,203 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
+
+# Code was borrowed and reformatted from https://github.com/kornia/kornia/blob/master/kornia/filters/sobel.py
+# part of "Kornia: an Open Source Differentiable Computer Vision Library for PyTorch" with an Apache License.
+
+from enum import Enum
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+__all__ = ["SobelGradL1Loss", "SobelGradL2Loss"]
+
+
+def get_sobel_kernel2d() -> torch.Tensor:
+    r"""Returns the Sobel kernel matrices :math:`G_{x}` and :math:`G_{y}`:
+
+    .. math::
+
+        G_{x} = \begin{matrix}
+        -1 & 0 & 1 \\
+        -2 & 0 & 2 \\
+        -1 & 0 & 1
+        \end{matrix}, \quad
+        G_{y} = \begin{matrix}
+        -1 & -2 & -1 \\
+        0 & 0 & 0 \\
+        1 & 2 & 1
+        \end{matrix}.
+    """
+    kernel_x: torch.Tensor = torch.tensor([[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]])
+    kernel_y: torch.Tensor = kernel_x.transpose(0, 1)
+    return torch.stack([kernel_x, kernel_y])
+
+
+def normalize_kernel(input: torch.Tensor) -> torch.Tensor:
+    r"""Normalize both derivative kernels.
+
+    Parameters
+    ----------
+    input: torch.Tensor
+
+    Returns
+    -------
+    torch.Tensor
+        Normalized kernel.
+    """
+    norm: torch.Tensor = input.abs().sum(dim=-1).sum(dim=-1)
+    return input / (norm.unsqueeze(-1).unsqueeze(-1))
+
+
+def spatial_gradient(input: torch.Tensor, normalized: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
+    r"""Computes the first order image derivatives in :math:`x` and :math:`y` directions using a Sobel operator.
+
+    Parameters
+    ----------
+    input: torch.Tensor
+        Input image tensor with shape :math:`(B, C, H, W)`.
+    normalized: bool
+        Whether the output is normalized. Default: True.
+
+    Returns
+    -------
+    grad_x, grad_y: (torch.Tensor, torch.Tensor)
+        The derivatives in the :math:`x` and :math:`y` directions of the input, each of the same shape as the input.
+    """
+    if not len(input.shape) == 4:
+        raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}")
+    # allocate kernel
+    kernel: torch.Tensor = get_sobel_kernel2d()
+    if normalized:
+        kernel = normalize_kernel(kernel)
+
+    # prepare kernel
+    b, c, h, w = input.shape
+    tmp_kernel: torch.Tensor = kernel.to(input).detach()
+    tmp_kernel = tmp_kernel.unsqueeze(1).unsqueeze(1)
+
+    # convolve input tensor with sobel kernel
+    kernel_flip: torch.Tensor = tmp_kernel.flip(-3)
+
+    # Pad with "replicate" for spatial dims, but with zeros for channel
+    spatial_pad = [kernel.size(1) // 2, kernel.size(1) // 2, kernel.size(2) // 2, kernel.size(2) // 2]
+    padded_inp: torch.Tensor = F.pad(input.reshape(b * c, 1, h, w), spatial_pad, "replicate")[:, :, None]
+
+    grad = F.conv3d(padded_inp, kernel_flip, padding=0).view(b, c, 2, h, w)
+    grad_x, grad_y = grad[:, :, 0], grad[:, :, 1]
+
+    return (grad_x, grad_y)
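As a quick sanity check of `spatial_gradient` (illustrative only, not part of the diff; it assumes the patch has been applied so the helper is importable), a horizontal ramp image has a constant x-derivative and a vanishing y-derivative:

```python
import torch

from direct.functionals.grad import spatial_gradient  # available once this patch is applied

# Horizontal ramp: pixel value equals its column index.
image = torch.arange(8, dtype=torch.float32).repeat(8, 1)[None, None]  # (B=1, C=1, H=8, W=8)
grad_x, grad_y = spatial_gradient(image, normalized=True)

print(grad_x.shape, grad_y.shape)  # torch.Size([1, 1, 8, 8]) for both
print(grad_y.abs().max().item())   # 0.0: no variation along y
print(grad_x[0, 0, 4, 2:6])        # constant magnitude (~1.0, up to sign) in the interior: the ramp slope
```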
Got: {input.shape}") + # allocate kernel + kernel: torch.Tensor = get_sobel_kernel2d() + if normalized: + kernel = normalize_kernel(kernel) + + # prepare kernel + b, c, h, w = input.shape + tmp_kernel: torch.Tensor = kernel.to(input).detach() + tmp_kernel = tmp_kernel.unsqueeze(1).unsqueeze(1) + + # convolve input tensor with sobel kernel + kernel_flip: torch.Tensor = tmp_kernel.flip(-3) + + # Pad with "replicate for spatial dims, but with zeros for channel + spatial_pad = [kernel.size(1) // 2, kernel.size(1) // 2, kernel.size(2) // 2, kernel.size(2) // 2] + padded_inp: torch.Tensor = F.pad(input.reshape(b * c, 1, h, w), spatial_pad, "replicate")[:, :, None] + + grad = F.conv3d(padded_inp, kernel_flip, padding=0).view(b, c, 2, h, w) + grad_x, grad_y = grad[:, :, 0], grad[:, :, 1] + + return (grad_x, grad_y) + + +class SobelGradLossType(str, Enum): + + l1 = "l1" + l2 = "l2" + + +class SobelGradLoss(nn.Module): + r"""Computes the sum of the l1-loss between the gradient of input and target: + + It returns + + .. math :: + + ||u_x - v_x ||_k^k + ||u_y - v_y||_k^k + + where :math:`u` and :math:`v` denote the input and target images and :math:`k` is 1 if `type_loss`="l1" or 2 if + `type_loss`="l2". The gradients w.r.t. to :math:`x` and :math:`y` directions are computed using the Sobel operators. + """ + + def __init__(self, type_loss: SobelGradLossType, reduction: str = "mean", normalized_grad: bool = True): + """Inits :class:`SobelGradLoss`. + + Parameters + ---------- + type_loss: SobelGradLossType + Type of loss to be used. Can be "l1" or "l2". + reduction: str + Loss reduction. Can be 'mean' or "sum". Default: "mean". + normalized_grad: bool + Whether the computed gradients are normalized. Default: True. + """ + super().__init__() + + self.reduction = reduction + if type_loss == "l1": + self.loss = nn.L1Loss(reduction=reduction) + else: + self.loss = nn.MSELoss(reduction=reduction) + self.normalized_grad = normalized_grad + + def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """Performs forward pass of :class:`SobelGradLoss`. + + Parameters + ---------- + input: torch.Tensor + Input tensor. + target: torch.Tensor + Target tensor. + + Returns + ------- + loss: torch.Tensor + Sum of the l1-loss between the gradient of input and target. + """ + input_grad_x, input_grad_y = spatial_gradient(input, self.normalized_grad) + target_grad_x, target_grad_y = spatial_gradient(target, self.normalized_grad) + return self.loss(input_grad_x, target_grad_x) + self.loss(input_grad_y, target_grad_y) + + +class SobelGradL1Loss(SobelGradLoss): + r"""Computes the sum of the l1-loss between the gradient of input and target: + + It returns + + .. math :: + + ||u_x - v_x ||_1 + ||u_y - v_y||_1 + + where :math:`u` and :math:`v` denote the input and target images. The gradients w.r.t. to :math:`x` and :math:`y` + directions are computed using the Sobel operators. + """ + + def __init__(self, reduction: str = "mean", normalized_grad: bool = True): + """Inits :class:`SobelGradL1Loss`. + + Parameters + ---------- + reduction: str + Loss reduction. Can be 'mean' or "sum". Default: "mean". + normalized_grad: bool + Whether the computed gradients are normalized. Default: True. + """ + super().__init__(SobelGradLossType.l1, reduction, normalized_grad) + + +class SobelGradL2Loss(SobelGradLoss): + r"""Computes the sum of the l1-loss between the gradient of input and target: + + It returns + + .. 
+
+
+class SobelGradL2Loss(SobelGradLoss):
+    r"""Computes the sum of the l2-loss between the gradients of the input and target:
+
+    It returns
+
+    .. math ::
+
+        ||u_x - v_x ||_2^2 + ||u_y - v_y||_2^2
+
+    where :math:`u` and :math:`v` denote the input and target images. The gradients w.r.t. the :math:`x` and :math:`y`
+    directions are computed using the Sobel operators.
+    """
+
+    def __init__(self, reduction: str = "mean", normalized_grad: bool = True):
+        """Inits :class:`SobelGradL2Loss`.
+
+        Parameters
+        ----------
+        reduction: str
+            Loss reduction. Can be "mean" or "sum". Default: "mean".
+        normalized_grad: bool
+            Whether the computed gradients are normalized. Default: True.
+        """
+        super().__init__(SobelGradLossType.l2, reduction, normalized_grad)
diff --git a/direct/functionals/nmae.py b/direct/functionals/nmae.py
new file mode 100644
index 00000000..7e8a3ce9
--- /dev/null
+++ b/direct/functionals/nmae.py
@@ -0,0 +1,43 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
+
+import torch
+import torch.nn as nn
+
+__all__ = ["NMAELoss"]
+
+
+class NMAELoss(nn.Module):
+    r"""Computes the Normalized Mean Absolute Error (NMAE), i.e.:
+
+    .. math::
+        \frac{||u - v||_1}{||u||_1},
+
+    where :math:`u` and :math:`v` denote the target and the input.
+    """
+
+    def __init__(self, reduction="mean"):
+        """Inits :class:`NMAELoss`.
+
+        Parameters
+        ----------
+        reduction: str
+            Specifies the reduction to apply to the output. Can be "none", "mean" or "sum".
+            Note that "mean" and "sum" will yield the same output. Default: "mean".
+        """
+        super().__init__()
+        self.mae_loss = nn.L1Loss(reduction=reduction)
+
+    def forward(self, input: torch.Tensor, target: torch.Tensor):
+        """Forward method of :class:`NMAELoss`.
+
+        Parameters
+        ----------
+        input: torch.Tensor
+            Tensor of shape (*), where * means any number of dimensions.
+        target: torch.Tensor
+            Tensor of same shape as the input.
+        """
+        return self.mae_loss(input, target) / self.mae_loss(
+            torch.zeros_like(target, dtype=target.dtype, device=target.device), target
+        )
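A small numeric check of `NMAELoss` (illustrative only, not part of the diff):

```python
import torch

from direct.functionals.nmae import NMAELoss  # available once this patch is applied

target = torch.tensor([2.0, -2.0, 4.0])  # ||u||_1 = 8
input_ = torch.tensor([3.0, -2.0, 2.0])  # ||u - v||_1 = 1 + 0 + 2 = 3

print(NMAELoss()(input_, target))  # tensor(0.3750), i.e. 3 / 8
```

The `reduction` factor cancels between numerator and denominator, which is why "mean" and "sum" yield the same output.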
Can be "none", "mean" or "sum". + Note that "mean" or "sum" will yield the same output. Default: "mean". + """ + super().__init__() + self.mse_loss = nn.MSELoss(reduction=reduction) + + def forward(self, input: torch.Tensor, target: torch.Tensor): + """Forward method of :class:`NRMSE`. + + Parameters + ---------- + input: torch.Tensor + Tensor of shape (*), where * means any number of dimensions. + target: torch.Tensor + Tensor of same shape as the input. + """ + return torch.sqrt( + self.mse_loss(input, target) + / self.mse_loss(torch.zeros_like(target, dtype=target.dtype, device=target.device), target) + ) diff --git a/direct/nn/cirim/cirim.py b/direct/nn/cirim/cirim.py index 9f5258c6..fbcc7884 100644 --- a/direct/nn/cirim/cirim.py +++ b/direct/nn/cirim/cirim.py @@ -116,7 +116,9 @@ class IndRNNCell(nn.Module): References ---------- - .. [1] Li, S. et al. (2018) ‘Independently Recurrent Neural Network (IndRNN): Building A Longer and Deeper RNN’, Proceedings of the IEEE Computer Society Conference on Computer Vision and Pattern Recognition, (1), pp. 5457–5466. doi: 10.1109/CVPR.2018.00572. + .. [1] Li, S. et al. (2018) ‘Independently Recurrent Neural Network (IndRNN): Building A Longer and Deeper RNN’, + Proceedings of the IEEE Computer Society Conference on Computer Vision and Pattern Recognition, (1), + pp. 5457–5466. doi: 10.1109/CVPR.2018.00572. """ def __init__( @@ -220,7 +222,9 @@ class CIRIM(nn.Module): References ---------- - .. [1] Karkalousos, D. et al. (2021) ‘Assessment of Data Consistency through Cascades of Independently Recurrent Inference Machines for fast and robust accelerated MRI reconstruction’. Available at: https://arxiv.org/abs/2111.15498v1 + .. [1] Karkalousos, D. et al. (2021) ‘Assessment of Data Consistency through Cascades of Independently Recurrent + Inference Machines for fast and robust accelerated MRI reconstruction’. + Available at: https://arxiv.org/abs/2111.15498v1 """ def __init__( @@ -355,7 +359,9 @@ class RIMBlock(nn.Module): References ---------- - .. [1] Karkalousos, D. et al. (2021) ‘Assessment of Data Consistency through Cascades of Independently Recurrent Inference Machines for fast and robust accelerated MRI reconstruction’. Available at: https://arxiv.org/abs/2111.15498v1 + .. [1] Karkalousos, D. et al. (2021) ‘Assessment of Data Consistency through Cascades of Independently + Recurrent Inference Machines for fast and robust accelerated MRI reconstruction’. 
+ Available at: https://arxiv.org/abs/2111.15498v1 """ def __init__( diff --git a/direct/nn/cirim/cirim_engine.py b/direct/nn/cirim/cirim_engine.py index ef248f78..745d2e7b 100644 --- a/direct/nn/cirim/cirim_engine.py +++ b/direct/nn/cirim/cirim_engine.py @@ -10,12 +10,7 @@ from direct.config import BaseConfig from direct.engine import DoIterationOutput from direct.nn.mri_models import MRIModelEngine - -from direct.utils import ( - detach_dict, - dict_to_device, - reduce_list_of_dicts, -) +from direct.utils import detach_dict, dict_to_device, reduce_list_of_dicts class CIRIMEngine(MRIModelEngine): @@ -90,11 +85,26 @@ def _do_iteration( for i, output_image_iter in enumerate(output_image_cascade): for key, value in loss_dict.items(): loss_dict[key] = ( - value + loss_fns[key](output_image_iter, **data, reduction="mean") * iter_loss_weights[i] + value + + loss_fns[key]( + output_image_iter, + data["target"], + reduction="mean", + reconstruction_size=data.get("reconstruction_size", None), + ) + * iter_loss_weights[i] ) for key, value in regularizer_dict.items(): - loss_dict[key] = value + loss_fns[key](output_image_iter, **data) * iter_loss_weights[i] + loss_dict[key] = ( + value + + loss_fns[key]( + output_image_iter, + data["target"], + reconstruction_size=data.get("reconstruction_size", None), + ) + * iter_loss_weights[i] + ) # Total length of the number of cascades and the number of iterations len_output_image = len(output_image) + len(output_image[0]) diff --git a/direct/nn/jointicnet/jointicnet_engine.py b/direct/nn/jointicnet/jointicnet_engine.py index 5db0aa14..bebf3886 100644 --- a/direct/nn/jointicnet/jointicnet_engine.py +++ b/direct/nn/jointicnet/jointicnet_engine.py @@ -1,17 +1,13 @@ # coding=utf-8 # Copyright (c) DIRECT Contributors -from typing import Callable, Dict, Optional +from typing import Any, Callable, Dict, Optional, Tuple import torch from torch import nn -from torch.cuda.amp import autocast -import direct.data.transforms as T from direct.config import BaseConfig -from direct.engine import DoIterationOutput from direct.nn.mri_models import MRIModelEngine -from direct.utils import detach_dict, dict_to_device, reduce_list_of_dicts class JointICNetEngine(MRIModelEngine): @@ -38,75 +34,14 @@ def __init__( **models, ) - self._spatial_dims = (2, 3) + def forward_function(self, data: Dict[str, Any]) -> Tuple[torch.Tensor, None]: - def _do_iteration( - self, - data: Dict[str, torch.Tensor], - loss_fns: Optional[Dict[str, Callable]] = None, - regularizer_fns: Optional[Dict[str, Callable]] = None, - ) -> DoIterationOutput: - - # loss_fns can be done, e.g. 
during validation - if loss_fns is None: - loss_fns = {} - - if regularizer_fns is None: - regularizer_fns = {} - - loss_dicts = [] - regularizer_dicts = [] - - data = dict_to_device(data, self.device) - - # sensitivity_map of shape (batch, coil, height, width, complex=2) - sensitivity_map = data["sensitivity_map"].clone() - data["sensitivity_map"] = self.compute_sensitivity_map(sensitivity_map) - - with autocast(enabled=self.mixed_precision): - - output_image = self.model( - masked_kspace=data["masked_kspace"], - sampling_mask=data["sampling_mask"], - sensitivity_map=data["sensitivity_map"], - ) # shape (batch, height, width, complex=2) - - output_image = T.modulus(output_image) # shape (batch, height, width) - - loss_dict = {k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in loss_fns.keys()} - regularizer_dict = { - k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in regularizer_fns.keys() - } - - for key, value in loss_dict.items(): - loss_dict[key] = value + loss_fns[key]( - output_image, - **data, - reduction="mean", - ) - - for key, value in regularizer_dict.items(): - regularizer_dict[key] = value + regularizer_fns[key]( - output_image, - **data, - ) - - loss = sum(loss_dict.values()) + sum(regularizer_dict.values()) # type: ignore - - if self.model.training: - self._scaler.scale(loss).backward() - - loss_dicts.append(detach_dict(loss_dict)) - regularizer_dicts.append( - detach_dict(regularizer_dict) - ) # Need to detach dict as this is only used for logging. + output_image = self.model( + masked_kspace=data["masked_kspace"], + sampling_mask=data["sampling_mask"], + sensitivity_map=data["sensitivity_map"], + ) # shape (batch, height, width) - # Add the loss dicts. - loss_dict = reduce_list_of_dicts(loss_dicts, mode="sum") - regularizer_dict = reduce_list_of_dicts(regularizer_dicts, mode="sum") + output_kspace = None - return DoIterationOutput( - output_image=output_image, - sensitivity_map=data["sensitivity_map"], - data_dict={**loss_dict, **regularizer_dict}, - ) + return output_image, output_kspace diff --git a/direct/nn/kikinet/kikinet_engine.py b/direct/nn/kikinet/kikinet_engine.py index 1dbf18fd..ef5a9a8b 100644 --- a/direct/nn/kikinet/kikinet_engine.py +++ b/direct/nn/kikinet/kikinet_engine.py @@ -1,21 +1,17 @@ # coding=utf-8 # Copyright (c) DIRECT Contributors -from typing import Callable, Dict, Optional +from typing import Any, Callable, Dict, Optional, Tuple import torch from torch import nn -from torch.cuda.amp import autocast -import direct.data.transforms as T from direct.config import BaseConfig -from direct.engine import DoIterationOutput from direct.nn.mri_models import MRIModelEngine -from direct.utils import detach_dict, dict_to_device, reduce_list_of_dicts class KIKINetEngine(MRIModelEngine): - """XPDNet Engine.""" + """KIKINet Engine.""" def __init__( self, @@ -38,76 +34,15 @@ def __init__( **models, ) - self._spatial_dims = (2, 3) + def forward_function(self, data: Dict[str, Any]) -> Tuple[torch.Tensor, None]: - def _do_iteration( - self, - data: Dict[str, torch.Tensor], - loss_fns: Optional[Dict[str, Callable]] = None, - regularizer_fns: Optional[Dict[str, Callable]] = None, - ) -> DoIterationOutput: - - # loss_fns can be done, e.g. 
during validation - if loss_fns is None: - loss_fns = {} - - if regularizer_fns is None: - regularizer_fns = {} - - loss_dicts = [] - regularizer_dicts = [] - - data = dict_to_device(data, self.device) - - # sensitivity_map of shape (batch, coil, height, width, complex=2) - sensitivity_map = data["sensitivity_map"].clone() - data["sensitivity_map"] = self.compute_sensitivity_map(sensitivity_map) - - with autocast(enabled=self.mixed_precision): - - output_image = self.model( - masked_kspace=data["masked_kspace"], - sampling_mask=data["sampling_mask"], - sensitivity_map=data["sensitivity_map"], - scaling_factor=data["scaling_factor"], - ) # shape (batch, height, width, complex=2) - - output_image = T.modulus(output_image) # shape (batch, height, width) - - loss_dict = {k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in loss_fns.keys()} - regularizer_dict = { - k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in regularizer_fns.keys() - } - - for key, value in loss_dict.items(): - loss_dict[key] = value + loss_fns[key]( - output_image, - **data, - reduction="mean", - ) - - for key, value in regularizer_dict.items(): - regularizer_dict[key] = value + regularizer_fns[key]( - output_image, - **data, - ) - - loss = sum(loss_dict.values()) + sum(regularizer_dict.values()) # type: ignore - - if self.model.training: - self._scaler.scale(loss).backward() - - loss_dicts.append(detach_dict(loss_dict)) - regularizer_dicts.append( - detach_dict(regularizer_dict) - ) # Need to detach dict as this is only used for logging. + output_image = self.model( + masked_kspace=data["masked_kspace"], + sampling_mask=data["sampling_mask"], + sensitivity_map=data["sensitivity_map"], + scaling_factor=data["scaling_factor"], + ) # shape (batch, height, width, complex[=2]) - # Add the loss dicts. - loss_dict = reduce_list_of_dicts(loss_dicts, mode="sum") - regularizer_dict = reduce_list_of_dicts(regularizer_dicts, mode="sum") + output_kspace = None - return DoIterationOutput( - output_image=output_image, - sensitivity_map=data["sensitivity_map"], - data_dict={**loss_dict, **regularizer_dict}, - ) + return output_image, output_kspace diff --git a/direct/nn/lpd/lpd_engine.py b/direct/nn/lpd/lpd_engine.py index aa770b43..3fca9c6b 100644 --- a/direct/nn/lpd/lpd_engine.py +++ b/direct/nn/lpd/lpd_engine.py @@ -1,17 +1,13 @@ # coding=utf-8 # Copyright (c) DIRECT Contributors -from typing import Callable, Dict, Optional +from typing import Any, Callable, Dict, Optional, Tuple import torch from torch import nn -from torch.cuda.amp import autocast -import direct.data.transforms as T from direct.config import BaseConfig -from direct.engine import DoIterationOutput from direct.nn.mri_models import MRIModelEngine -from direct.utils import detach_dict, dict_to_device, reduce_list_of_dicts class LPDNetEngine(MRIModelEngine): @@ -38,75 +34,13 @@ def __init__( **models, ) - self._spatial_dims = (2, 3) + def forward_function(self, data: Dict[str, Any]) -> Tuple[torch.Tensor, None]: - def _do_iteration( - self, - data: Dict[str, torch.Tensor], - loss_fns: Optional[Dict[str, Callable]] = None, - regularizer_fns: Optional[Dict[str, Callable]] = None, - ) -> DoIterationOutput: - - # loss_fns can be done, e.g. 
during validation - if loss_fns is None: - loss_fns = {} - - if regularizer_fns is None: - regularizer_fns = {} - - loss_dicts = [] - regularizer_dicts = [] - - data = dict_to_device(data, self.device) - - # sensitivity_map of shape (batch, coil, height, width, complex=2) - sensitivity_map = data["sensitivity_map"].clone() - data["sensitivity_map"] = self.compute_sensitivity_map(sensitivity_map) - - with autocast(enabled=self.mixed_precision): - - output_image = self.model( - masked_kspace=data["masked_kspace"], - sampling_mask=data["sampling_mask"], - sensitivity_map=data["sensitivity_map"], - ) # shape (batch, height, width, complex=2) - - output_image = T.modulus(output_image) # shape (batch, height, width) + data["sensitivity_map"] = self.compute_sensitivity_map(data["sensitivity_map"]) - loss_dict = {k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in loss_fns.keys()} - regularizer_dict = { - k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in regularizer_fns.keys() - } - - for key, value in loss_dict.items(): - loss_dict[key] = value + loss_fns[key]( - output_image, - **data, - reduction="mean", - ) - - for key, value in regularizer_dict.items(): - regularizer_dict[key] = value + regularizer_fns[key]( - output_image, - **data, - ) - - loss = sum(loss_dict.values()) + sum(regularizer_dict.values()) # type: ignore - - if self.model.training: - self._scaler.scale(loss).backward() - - loss_dicts.append(detach_dict(loss_dict)) - regularizer_dicts.append( - detach_dict(regularizer_dict) - ) # Need to detach dict as this is only used for logging. - - # Add the loss dicts. - loss_dict = reduce_list_of_dicts(loss_dicts, mode="sum") - regularizer_dict = reduce_list_of_dicts(regularizer_dicts, mode="sum") - - return DoIterationOutput( - output_image=output_image, + output_image = self.model( + masked_kspace=data["masked_kspace"], + sampling_mask=data["sampling_mask"], sensitivity_map=data["sensitivity_map"], - data_dict={**loss_dict, **regularizer_dict}, - ) + ) # shape (batch, height, width) + return output_image, None diff --git a/direct/nn/mri_models.py b/direct/nn/mri_models.py index 79ea81b9..3b0e6781 100644 --- a/direct/nn/mri_models.py +++ b/direct/nn/mri_models.py @@ -6,57 +6,37 @@ import gc import pathlib import time -from abc import abstractmethod from collections import defaultdict from os import PathLike -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch from torch import nn +from torch.cuda.amp import autocast from torch.nn import functional as F from torch.utils.data import DataLoader import direct.data.transforms as T from direct.config import BaseConfig from direct.engine import DoIterationOutput, Engine -from direct.functionals import SSIMLoss -from direct.utils import communication, merge_list_of_dicts, multiply_function, reduce_list_of_dicts +from direct.functionals import NMAELoss, NMSELoss, NRMSELoss, SobelGradL1Loss, SobelGradL2Loss, SSIMLoss +from direct.types import TensorOrNone +from direct.utils import ( + communication, + detach_dict, + dict_to_device, + merge_list_of_dicts, + multiply_function, + reduce_list_of_dicts, +) from direct.utils.communication import reduce_tensor_dict -def _crop_volume( - source: torch.Tensor, target: torch.Tensor, resolution: Union[List[int], Tuple[int, ...]] -) -> Tuple[torch.Tensor, torch.Tensor]: - """2D source/target cropper. 
- - Parameters - ---------- - source: torch.Tensor - Has shape (batch, height, width) - target: torch.Tensor - Has shape (batch, height, width) - resolution: list of ints or tuple of ints - Target resolution. - - Returns - ------- - (torch.Tensor, torch.Tensor) - """ - - if not resolution or all(_ == 0 for _ in resolution): - return source.unsqueeze(1), target.unsqueeze(1) # Added channel dimension. - - source_abs = T.center_crop(source, resolution).unsqueeze(1) # Added channel dimension. - target_abs = T.center_crop(target, resolution).unsqueeze(1) # Added channel dimension. - - return source_abs, target_abs - - class MRIModelEngine(Engine): """Engine for MRI models. - Each child class should implement their own :meth:`_do_iteration` method. + Each child class should implement their own :meth:`forward_function`. """ def __init__( @@ -78,7 +58,7 @@ def __init__( model: nn.Module Model. device: str - Device. Can be "cuda" or "cpu". + Device. Can be "cuda:{idx}" or "cpu". forward_operator: Callable, optional The forward operator. Default: None. backward_operator: Callable, optional @@ -97,135 +77,360 @@ def __init__( mixed_precision=mixed_precision, **models, ) - self._complex_dim = -1 + self._spatial_dims = (2, 3) self._coil_dim = 1 + self._complex_dim = -1 + + def forward_function(self, data: Dict[str, Any]) -> Tuple[TensorOrNone, TensorOrNone]: + """This method performs the model's forward method given `data` which contains all tensor inputs. + + Must be implemented by child classes. + """ + raise NotImplementedError("Must be implemented by child class.") - @abstractmethod def _do_iteration( self, - data: Dict[str, torch.Tensor], + data: Dict[str, Any], loss_fns: Optional[Dict[str, Callable]] = None, regularizer_fns: Optional[Dict[str, Callable]] = None, ) -> DoIterationOutput: - """To be implemented by child class. + """Performs forward method and calculates loss functions. - Should output a :meth:`DoIterationOutput` object with `output_image`, `sensitivity_map` and - `data_dict` attributes. + Parameters + ---------- + data : Dict[str, Any] + Data containing keys with values tensors such as k-space, image, sensitivity map, etc. + loss_fns : Optional[Dict[str, Callable]] + Callable loss functions. + regularizer_fns : Optional[Dict[str, Callable]] + Callable regularization functions. + + Returns + ------- + DoIterationOutput + Contains outputs. """ + # loss_fns can be None, e.g. 
during validation + if loss_fns is None: + loss_fns = {} + + if regularizer_fns is None: + regularizer_fns = {} + + data = dict_to_device(data, self.device) + + output_image: TensorOrNone + output_kspace: TensorOrNone + + with autocast(enabled=self.mixed_precision): + + data["sensitivity_map"] = self.compute_sensitivity_map(data["sensitivity_map"]) + + output_image, output_kspace = self.forward_function(data) + output_image = T.modulus_if_complex(output_image, complex_axis=self._complex_dim) + + loss_dict = {k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in loss_fns.keys()} + regularizer_dict = { + k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in regularizer_fns.keys() + } + loss_dict = self.compute_loss_on_data(loss_dict, loss_fns, data, output_image, output_kspace) + regularizer_dict = self.compute_loss_on_data( + regularizer_dict, regularizer_fns, data, output_image, output_kspace + ) + + loss = sum(loss_dict.values()) + sum(regularizer_dict.values()) # type: ignore + + if self.model.training: + self._scaler.scale(loss).backward() + + loss_dict = detach_dict(loss_dict) # Detach dict, only used for logging. + regularizer_dict = detach_dict(regularizer_dict) + + return DoIterationOutput( + output_image=output_image, + sensitivity_map=data["sensitivity_map"], + data_dict={**loss_dict, **regularizer_dict}, + ) + def build_loss(self) -> Dict: - # TODO: Cropper is a processing output tool. - def get_resolution(**data): - """Be careful that this will use the cropping size of the FIRST sample in the batch.""" - return _compute_resolution(self.cfg.training.loss.crop, data.get("reconstruction_size", None)) + def get_resolution(reconstruction_size): + return _compute_resolution(self.cfg.training.loss.crop, reconstruction_size) # type: ignore + + def nmae_loss( + source: torch.Tensor, + target: torch.Tensor, + reduction: str = "mean", + reconstruction_size: Optional[Tuple] = None, + ) -> torch.Tensor: + """Calculate NMAE loss given source and target. + + Parameters + ---------- + source: torch.Tensor + Has shape (batch, *). + target: torch.Tensor + Has shape (batch, *). + reduction: str + Reduction type. Can be "sum" or "mean". + + Returns + ------- + nmae_loss: torch.Tensor + NMAE loss. + """ + if reconstruction_size is not None: + resolution = get_resolution(reconstruction_size) + source, target = _crop_volume(source, target, resolution) + + nmae_loss = NMAELoss(reduction=reduction).forward(source, target) + + return nmae_loss + + def nmse_loss( + source: torch.Tensor, + target: torch.Tensor, + reduction: str = "mean", + reconstruction_size: Optional[Tuple] = None, + ) -> torch.Tensor: + """Calculate NMSE loss given source and target. + + Parameters + ---------- + source: torch.Tensor + Has shape (batch, *). + target: torch.Tensor + Has shape (batch, *). + reduction: str + Reduction type. Can be "sum" or "mean". + + Returns + ------- + nmse_loss: torch.Tensor + NMSE loss. + """ + if reconstruction_size is not None: + resolution = get_resolution(reconstruction_size) + source, target = _crop_volume(source, target, resolution) + nmse_loss = NMSELoss(reduction=reduction).forward(source, target) - # TODO(jt) Ideally this is also configurable: - # - Do in steps (use insertation order) - # Crop -> then loss. + return nmse_loss - def l1_loss(source: torch.Tensor, reduction: str = "mean", **data) -> torch.Tensor: - """Calculate L1 loss given source and target. 
+        def nrmse_loss(
+            source: torch.Tensor,
+            target: torch.Tensor,
+            reduction: str = "mean",
+            reconstruction_size: Optional[Tuple] = None,
+        ) -> torch.Tensor:
+            """Calculate NRMSE loss given source and target.
 
             Parameters
             ----------
             source: torch.Tensor
-                Has shape (batch, [complex=2,] height, width)
+                Has shape (batch, *).
+            target: torch.Tensor
+                Has shape (batch, *).
             reduction: str
                 Reduction type. Can be "sum" or "mean".
-            data: Dict[str, torch.Tensor]
-                Contains key "target" with value a tensor of shape (batch, height, width)
+
+            Returns
+            -------
+            nrmse_loss: torch.Tensor
+                NRMSE loss.
+            """
+            if reconstruction_size is not None:
+                resolution = get_resolution(reconstruction_size)
+                source, target = _crop_volume(source, target, resolution)
+            nrmse_loss = NRMSELoss(reduction=reduction).forward(source, target)
+
+            return nrmse_loss
+
+        def l1_loss(
+            source: torch.Tensor,
+            target: torch.Tensor,
+            reduction: str = "mean",
+            reconstruction_size: Optional[Tuple] = None,
+        ) -> torch.Tensor:
+            """Calculate L1 loss given source image and target.
+
+            Parameters
+            ----------
+            source: torch.Tensor
+                Source tensor of shape (batch, *).
+            target: torch.Tensor
+                Target tensor of shape (batch, *).
+            reduction: str
+                Reduction type. Can be "sum" or "mean".
+            reconstruction_size: Optional[Tuple]
+                Reconstruction size to center crop. Default: None.
 
             Returns
             -------
             l1_loss: torch.Tensor
                 L1 loss.
             """
-            resolution = get_resolution(**data)
-            l1_loss = F.l1_loss(
-                *_crop_volume(
-                    T.modulus_if_complex(source, complex_axis=self._complex_dim), data["target"], resolution
-                ),
-                reduction=reduction,
-            )
+            if reconstruction_size is not None:
+                resolution = get_resolution(reconstruction_size)
+                source, target = _crop_volume(source, target, resolution)
+            l1_loss = F.l1_loss(source, target, reduction=reduction)
 
             return l1_loss
 
-        def l2_loss(source: torch.Tensor, reduction: str = "mean", **data) -> torch.Tensor:
-            """Calculate L2 loss (MSE) given source and target.
+        def l2_loss(
+            source: torch.Tensor,
+            target: torch.Tensor,
+            reduction: str = "mean",
+            reconstruction_size: Optional[Tuple] = None,
+        ) -> torch.Tensor:
+            """Calculate L2 loss (MSE) given source image and target.
 
             Parameters
             ----------
             source: torch.Tensor
-                Has shape (batch, [complex=2,] height, width)
+                Source tensor of shape (batch, *).
+            target: torch.Tensor
+                Target tensor of shape (batch, *).
             reduction: str
                 Reduction type. Can be "sum" or "mean".
-            data: Dict[str, torch.Tensor]
-                Contains key "target" with value a tensor of shape (batch, height, width)
+            reconstruction_size: Optional[Tuple]
+                Reconstruction size to center crop. Default: None.
 
             Returns
             -------
             l2_loss: torch.Tensor
                 L2 loss.
             """
-            resolution = get_resolution(**data)
-            l2_loss = F.mse_loss(
-                *_crop_volume(
-                    T.modulus_if_complex(source, complex_axis=self._complex_dim), data["target"], resolution
-                ),
-                reduction=reduction,
-            )
+            if reconstruction_size is not None:
+                resolution = get_resolution(reconstruction_size)
+                source, target = _crop_volume(source, target, resolution)
+            l2_loss = F.mse_loss(source, target, reduction=reduction)
 
             return l2_loss
 
-        def ssim_loss(source: torch.Tensor, reduction: str = "mean", **data) -> torch.Tensor:
-            """Calculate SSIM loss given source and target.
+        def ssim_loss(
+            source: torch.Tensor,
+            target: torch.Tensor,
+            reduction: str = "mean",
+            reconstruction_size: Optional[Tuple] = None,
+        ) -> torch.Tensor:
+            """Calculate SSIM loss given source image and target image.
 
             Parameters
             ----------
             source: torch.Tensor
-                Has shape (batch, [complex=2,] height, width)
+                Source tensor of shape (batch, height, width, [complex=2]).
+            target: torch.Tensor
+                Target tensor of shape (batch, height, width, [complex=2]).
             reduction: str
                 Reduction type. Can be "sum" or "mean".
-            data: Dict[str, torch.Tensor]
-                Contains key "target" with value a tensor of shape (batch, height, width)
+            reconstruction_size: Optional[Tuple]
+                Reconstruction size to center crop. Default: None.
 
             Returns
             -------
             ssim_loss: torch.Tensor
                 SSIM loss.
             """
-            resolution = get_resolution(**data)
+            resolution = get_resolution(reconstruction_size)
             if reduction != "mean":
                 raise AssertionError(
                     f"SSIM loss can only be computed with reduction == 'mean'."
                     f" Got reduction == {reduction}."
                 )
 
-            source_abs, target_abs = _crop_volume(
-                T.modulus_if_complex(source, complex_axis=self._complex_dim), data["target"], resolution
-            )
+            source_abs, target_abs = _crop_volume(source, target, resolution)
             data_range = torch.tensor([target_abs.max()], device=target_abs.device)
 
             ssim_loss = SSIMLoss().to(source_abs.device).forward(source_abs, target_abs, data_range=data_range)
 
             return ssim_loss
 
+        def grad_l1_loss(
+            source: torch.Tensor,
+            target: torch.Tensor,
+            reduction: str = "mean",
+            reconstruction_size: Optional[Tuple] = None,
+        ) -> torch.Tensor:
+            """Calculate Sobel gradient L1 loss given source image and target image.
+
+            Parameters
+            ----------
+            source: torch.Tensor
+                Source tensor of shape (batch, height, width, [complex=2]).
+            target: torch.Tensor
+                Target tensor of shape (batch, height, width, [complex=2]).
+            reduction: str
+                Reduction type. Can be "sum" or "mean".
+            reconstruction_size: Optional[Tuple]
+                Reconstruction size to center crop. Default: None.
+
+            Returns
+            -------
+            grad_loss: torch.Tensor
+                Sobel grad L1 loss.
+            """
+            resolution = get_resolution(reconstruction_size)
+            source_abs, target_abs = _crop_volume(source, target, resolution)
+            grad_l1_loss = SobelGradL1Loss(reduction).to(source_abs.device).forward(source_abs, target_abs)
+
+            return grad_l1_loss
+
+        def grad_l2_loss(
+            source: torch.Tensor,
+            target: torch.Tensor,
+            reduction: str = "mean",
+            reconstruction_size: Optional[Tuple] = None,
+        ) -> torch.Tensor:
+            """Calculate Sobel gradient L2 loss given source image and target image.
+
+            Parameters
+            ----------
+            source: torch.Tensor
+                Source tensor of shape (batch, height, width, [complex=2]).
+            target: torch.Tensor
+                Target tensor of shape (batch, height, width, [complex=2]).
+            reduction: str
+                Reduction type. Can be "sum" or "mean".
+            reconstruction_size: Optional[Tuple]
+                Reconstruction size to center crop. Default: None.
+
+            Returns
+            -------
+            grad_loss: torch.Tensor
+                Sobel grad L2 loss.
+ """ + resolution = get_resolution(reconstruction_size) + source_abs, target_abs = _crop_volume(source, target, resolution) + grad_l2_loss = SobelGradL2Loss(reduction).to(source_abs.device).forward(source_abs, target_abs) + + return grad_l2_loss + # Build losses loss_dict = {} for curr_loss in self.cfg.training.loss.losses: # type: ignore loss_fn = curr_loss.function - if loss_fn == "l1_loss": + if loss_fn in ["l1_loss", "kspace_l1_loss"]: loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, l1_loss) - elif loss_fn == "l2_loss": + elif loss_fn in ["l2_loss", "kspace_l2_loss"]: loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, l2_loss) elif loss_fn == "ssim_loss": loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, ssim_loss) + elif loss_fn == "grad_l1_loss": + loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, grad_l1_loss) + elif loss_fn == "grad_l2_loss": + loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, grad_l2_loss) + elif loss_fn in ["nmse_loss", "kspace_nmse_loss"]: + loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, nmse_loss) + elif loss_fn in ["nrmse_loss", "kspace_nrmse_loss"]: + loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, nrmse_loss) + elif loss_fn in ["nmae_loss", "kspace_nmae_loss"]: + loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, nmae_loss) else: raise ValueError(f"{loss_fn} not permissible.") return loss_dict def compute_sensitivity_map(self, sensitivity_map: torch.Tensor) -> torch.Tensor: - """Computes sensitivity maps :math:`\{S^k\}_{k=1}^{n_c}` if `sensitivity_model` is available. + r"""Computes sensitivity maps :math:`\{S^k\}_{k=1}^{n_c}` if `sensitivity_model` is available. :math:`\{S^k\}_{k=1}^{n_c}` are normalized such that @@ -288,7 +493,6 @@ def reconstruct_volumes( # type: ignore Yields ------ (curr_volume, [curr_target,] loss_dict_list, filename): torch.Tensor, [torch.Tensor,], dict, pathlib.Path - # TODO(jt): visualization should be a namedtuple or a dict or so """ # pylint: disable=too-many-locals, arguments-differ self.models_to_device() @@ -477,6 +681,79 @@ def compute_model_per_coil(self, model_name: str, data: torch.Tensor) -> torch.T return torch.stack(output, dim=self._coil_dim) + def compute_loss_on_data( + self, + loss_dict: Dict[str, torch.Tensor], + loss_fns: Dict[str, Callable], + data: Dict[str, Any], + output_image: Optional[torch.Tensor] = None, + output_kspace: Optional[torch.Tensor] = None, + ) -> Dict[str, torch.Tensor]: + if output_image is None and output_kspace is None: + raise ValueError("Inputs for `output_image` and `output_kspace` cannot be both None.") + for key, value in loss_dict.items(): + if "kspace" in key: + if output_kspace is not None: + output, target, reconstruction_size = output_kspace, data["kspace"], None + else: + raise ValueError(f"Requested to compute `{key}` loss but received None for `output_kspace`.") + else: + if output_image is not None: + output, target, reconstruction_size = ( + output_image, + data["target"], + data.get("reconstruction_size", None), + ) + else: + raise ValueError(f"Requested to compute `{key}` loss but received None for `output_image`.") + loss_dict[key] = value + loss_fns[key](output, target, "mean", reconstruction_size) + return loss_dict + + def _forward_operator(self, image, sensitivity_map, sampling_mask): + return T.apply_mask( + self.forward_operator( + T.expand_operator(image, sensitivity_map, dim=self._coil_dim), + dim=self._spatial_dims, + ), + sampling_mask, + return_mask=False, + ) + + def 
_backward_operator(self, kspace, sensitivity_map, sampling_mask): + return T.reduce_operator( + self.backward_operator(T.apply_mask(kspace, sampling_mask, return_mask=False), dim=self._spatial_dims), + sensitivity_map, + dim=self._coil_dim, + ) + + +def _crop_volume( + source: torch.Tensor, target: torch.Tensor, resolution: Union[List[int], Tuple[int, ...]] +) -> Tuple[torch.Tensor, torch.Tensor]: + """2D source/target cropper. + + Parameters + ---------- + source: torch.Tensor + Has shape (batch, height, width) + target: torch.Tensor + Has shape (batch, height, width) + resolution: list of ints or tuple of ints + Target resolution. + + Returns + ------- + (torch.Tensor, torch.Tensor) + """ + + if not resolution or all(_ == 0 for _ in resolution): + return source.unsqueeze(1), target.unsqueeze(1) # Added channel dimension. + + source_abs = T.center_crop(source, resolution).unsqueeze(1) # Added channel dimension. + target_abs = T.center_crop(target, resolution).unsqueeze(1) # Added channel dimension. + + return source_abs, target_abs + def _process_output( data: torch.Tensor, @@ -548,7 +825,7 @@ def _compute_resolution( def _get_filename_from_batch(data: dict) -> pathlib.Path: - filenames = data.pop("filename") + filenames = data["filename"] if len(set(filenames)) != 1: raise ValueError( f"Expected a batch during validation to only contain filenames of one case. " f"Got {set(filenames)}." diff --git a/direct/nn/multidomainnet/multidomainnet_engine.py b/direct/nn/multidomainnet/multidomainnet_engine.py index 21980142..df8c5896 100644 --- a/direct/nn/multidomainnet/multidomainnet_engine.py +++ b/direct/nn/multidomainnet/multidomainnet_engine.py @@ -1,7 +1,7 @@ # coding=utf-8 # Copyright (c) DIRECT Contributors -from typing import Callable, Dict, Optional +from typing import Any, Callable, Dict, Optional, Tuple import torch from torch import nn @@ -38,76 +38,16 @@ def __init__( **models, ) - self._spatial_dims = (2, 3) + def forward_function(self, data: Dict[str, Any]) -> Tuple[torch.Tensor, None]: - def _do_iteration( - self, - data: Dict[str, torch.Tensor], - loss_fns: Optional[Dict[str, Callable]] = None, - regularizer_fns: Optional[Dict[str, Callable]] = None, - ) -> DoIterationOutput: - - # loss_fns can be done, e.g. 
during validation - if loss_fns is None: - loss_fns = {} - - if regularizer_fns is None: - regularizer_fns = {} - - loss_dicts = [] - regularizer_dicts = [] - - data = dict_to_device(data, self.device) - - # sensitivity_map of shape (batch, coil, height, width, complex=2) - sensitivity_map = data["sensitivity_map"].clone() - data["sensitivity_map"] = self.compute_sensitivity_map(sensitivity_map) - - with autocast(enabled=self.mixed_precision): - - output_multicoil_image = self.model( - masked_kspace=data["masked_kspace"], - sensitivity_map=data["sensitivity_map"], - ) - - output_image = T.root_sum_of_squares( - output_multicoil_image, self._coil_dim, self._complex_dim - ) # shape (batch, height, width) - - loss_dict = {k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in loss_fns.keys()} - regularizer_dict = { - k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in regularizer_fns.keys() - } - - for key, value in loss_dict.items(): - loss_dict[key] = value + loss_fns[key]( - output_image, - **data, - reduction="mean", - ) - - for key, value in regularizer_dict.items(): - regularizer_dict[key] = value + regularizer_fns[key]( - output_image, - **data, - ) - - loss = sum(loss_dict.values()) + sum(regularizer_dict.values()) # type: ignore - - if self.model.training: - self._scaler.scale(loss).backward() - - loss_dicts.append(detach_dict(loss_dict)) - regularizer_dicts.append( - detach_dict(regularizer_dict) - ) # Need to detach dict as this is only used for logging. - - # Add the loss dicts. - loss_dict = reduce_list_of_dicts(loss_dicts, mode="sum") - regularizer_dict = reduce_list_of_dicts(regularizer_dicts, mode="sum") - - return DoIterationOutput( - output_image=output_image, + output_multicoil_image = self.model( + masked_kspace=data["masked_kspace"], sensitivity_map=data["sensitivity_map"], - data_dict={**loss_dict, **regularizer_dict}, ) + output_image = T.root_sum_of_squares( + output_multicoil_image, self._coil_dim, self._complex_dim + ) # shape (batch, height, width) + + output_kspace = None + + return output_image, output_kspace diff --git a/direct/nn/recurrentvarnet/recurrentvarnet_engine.py b/direct/nn/recurrentvarnet/recurrentvarnet_engine.py index a024323d..56ade0d4 100644 --- a/direct/nn/recurrentvarnet/recurrentvarnet_engine.py +++ b/direct/nn/recurrentvarnet/recurrentvarnet_engine.py @@ -1,17 +1,14 @@ # coding=utf-8 # Copyright (c) DIRECT Contributors -from typing import Callable, Dict, Optional +from typing import Any, Callable, Dict, Optional, Tuple import torch from torch import nn -from torch.cuda.amp import autocast import direct.data.transforms as T from direct.config import BaseConfig -from direct.engine import DoIterationOutput from direct.nn.mri_models import MRIModelEngine -from direct.utils import detach_dict, dict_to_device, reduce_list_of_dicts class RecurrentVarNetEngine(MRIModelEngine): @@ -38,78 +35,18 @@ def __init__( **models, ) - self._spatial_dims = (2, 3) + def forward_function(self, data: Dict[str, Any]) -> Tuple[torch.Tensor, torch.Tensor]: - def _do_iteration( - self, - data: Dict[str, torch.Tensor], - loss_fns: Optional[Dict[str, Callable]] = None, - regularizer_fns: Optional[Dict[str, Callable]] = None, - ) -> DoIterationOutput: - - # loss_fns can be done, e.g. 
during validation - if loss_fns is None: - loss_fns = {} - - if regularizer_fns is None: - regularizer_fns = {} - - loss_dicts = [] - regularizer_dicts = [] - - data = dict_to_device(data, self.device) - - # sensitivity_map of shape (batch, coil, height, width, complex=2) - sensitivity_map = data["sensitivity_map"].clone() - data["sensitivity_map"] = self.compute_sensitivity_map(sensitivity_map) - - with autocast(enabled=self.mixed_precision): - - output_kspace = self.model( - masked_kspace=data["masked_kspace"], - sampling_mask=data["sampling_mask"], - sensitivity_map=data["sensitivity_map"], - ) - - output_image = T.root_sum_of_squares( - self.backward_operator(output_kspace, dim=self._spatial_dims), # type: ignore - dim=self._coil_dim, - ) # shape (batch, height, width) - - loss_dict = {k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in loss_fns.keys()} - regularizer_dict = { - k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in regularizer_fns.keys() - } - - for key, value in loss_dict.items(): - loss_dict[key] = value + loss_fns[key]( - output_image, - **data, - reduction="mean", - ) - - for key, value in regularizer_dict.items(): - regularizer_dict[key] = value + regularizer_fns[key]( - output_image, - **data, - ) - - loss = sum(loss_dict.values()) + sum(regularizer_dict.values()) # type: ignore - - if self.model.training: - self._scaler.scale(loss).backward() - - loss_dicts.append(detach_dict(loss_dict)) - regularizer_dicts.append( - detach_dict(regularizer_dict) - ) # Need to detach dict as this is only used for logging. - - # Add the loss dicts. - loss_dict = reduce_list_of_dicts(loss_dicts, mode="sum") - regularizer_dict = reduce_list_of_dicts(regularizer_dicts, mode="sum") - - return DoIterationOutput( - output_image=output_image, + output_kspace = self.model( + masked_kspace=data["masked_kspace"], + sampling_mask=data["sampling_mask"], sensitivity_map=data["sensitivity_map"], - data_dict={**loss_dict, **regularizer_dict}, ) + output_kspace = T.apply_padding(output_kspace, data.get("padding", None)) + + output_image = T.root_sum_of_squares( + self.backward_operator(output_kspace, dim=self._spatial_dims), + dim=self._coil_dim, + ) # shape (batch, height, width) + + return output_image, output_kspace diff --git a/direct/nn/rim/config.py b/direct/nn/rim/config.py index c6e53642..878f0410 100644 --- a/direct/nn/rim/config.py +++ b/direct/nn/rim/config.py @@ -25,22 +25,3 @@ class RIMConfig(ModelConfig): initializer_dilations: Tuple[int, ...] = (1, 1, 2, 4) initializer_multiscale: int = 1 normalized: bool = False - - -@dataclass -class RIM3dConfig(ModelConfig): - hidden_channels: int = 16 - length: int = 8 - depth: int = 2 - steps: int = 1 - no_parameter_sharing: bool = False - instance_norm: bool = False - dense_connect: bool = False - replication_padding: bool = True - image_initialization: str = "zero_filled" - # learned_initializer: bool = False - # initializer_channels: Tuple[int, ...] = (32, 32, 64, 64) - # initializer_dilations: Tuple[int, ...] 
= (1, 1, 2, 4) - # initializer_multiscale: int = 1 - z_reduction_frequency: int = 0 - kspace_context: int = 2 diff --git a/direct/nn/rim/rim_engine.py b/direct/nn/rim/rim_engine.py index 6deb7df2..a7da9384 100644 --- a/direct/nn/rim/rim_engine.py +++ b/direct/nn/rim/rim_engine.py @@ -7,6 +7,7 @@ from torch import nn from torch.cuda.amp import autocast +import direct.data.transforms as T from direct.config import BaseConfig from direct.engine import DoIterationOutput from direct.nn.mri_models import MRIModelEngine @@ -26,7 +27,25 @@ def __init__( mixed_precision: bool = False, **models: nn.Module, ): - """Inits :class:`RIMEngine.""" + """Inits :class:`RIMEngine`. + + Parameters + ---------- + cfg: BaseConfig + Configuration file. + model: nn.Module + Model. + device: str + Device. Can be "cuda:{idx}" or "cpu". + forward_operator: Callable, optional + The forward operator. Default: None. + backward_operator: Callable, optional + The backward operator. Default: None. + mixed_precision: bool + Use mixed precision. Default: False. + **models: nn.Module + Additional models. + """ super().__init__( cfg, model, @@ -44,51 +63,41 @@ def _do_iteration( regularizer_fns: Optional[Dict[str, Callable]] = None, ) -> DoIterationOutput: - # loss_fns can be done, e.g. during validation if loss_fns is None: loss_fns = {} - if regularizer_fns is None: regularizer_fns = {} # The first input_image in the iteration is the input_image with the mask applied and no first hidden state. - input_image = None - hidden_state = None - output_image = None - loss_dicts = [] - regularizer_dicts = [] + input_image, hidden_state, output_image = None, None, None + loss_dicts, regularizer_dicts = [], [] data = dict_to_device(data, self.device) - # TODO(jt): keys=['sampling_mask', 'sensitivity_map', 'target', 'masked_kspace', 'scaling_factor'] - - # sensitivity_map of shape (batch, coil, height, width, complex=2) - sensitivity_map = data["sensitivity_map"].clone() if "noise_model" in self.models: raise NotImplementedError() - data["sensitivity_map"] = self.compute_sensitivity_map(sensitivity_map) - if self.cfg.model.scale_loglikelihood: # type: ignore scaling_factor = 1.0 * self.cfg.model.scale_loglikelihood / (data["scaling_factor"] ** 2) # type: ignore scaling_factor = scaling_factor.reshape(-1, 1) # shape (batch, complex=1) self.logger.debug(f"Scaling factor is: {scaling_factor}") else: # Needs fixing. 
-            scaling_factor = torch.tensor([1.0]).to(sensitivity_map.device)  # shape (complex=1, )
+            scaling_factor = torch.tensor([1.0]).to(data["sensitivity_map"].device)  # shape (complex=1, )
 
-        for _ in range(self.cfg.model.steps):  # type: ignore
-            with autocast(enabled=self.mixed_precision):
+        with autocast(enabled=self.mixed_precision):
+            # sensitivity_map of shape (batch, coil, height, width, complex=2)
+            data["sensitivity_map"] = self.compute_sensitivity_map(data["sensitivity_map"])
 
+            for _ in range(self.cfg.model.steps):  # type: ignore
                 if input_image is not None:
                     input_image = input_image.permute(0, 2, 3, 1)
 
                 reconstruction_iter, hidden_state = self.model(
-                    **data,
                     input_image=input_image,
                     hidden_state=hidden_state,
                     loglikelihood_scaling=scaling_factor,
-                )
-                # reconstruction_iter: list with tensors of shape (batch, complex=2, height, width)
-                # hidden_state has shape: (batch, num_hidden_channels, height, width, depth)
+                    **data,
+                )  # list with tensors of shape (batch, complex=2, height, width)
+                # hidden_state of shape (batch, num_hidden_channels, height, width, depth)
 
                 output_image = reconstruction_iter[-1].permute(0, 2, 3, 1)  # shape (batch, height, width, complex=2)
 
@@ -99,23 +108,13 @@ def _do_iteration(
                     k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in regularizer_fns.keys()
                 }
 
-                # TODO: This seems too similar not to be able to do this, perhaps a partial can help here
                 for output_image_iter in reconstruction_iter:
-                    output_image_iter = output_image_iter.permute(
-                        0, 2, 3, 1
-                    )  # shape (batch, height, width, complex=2)
-                    for key, value in loss_dict.items():
-                        loss_dict[key] = value + loss_fns[key](
-                            output_image_iter,
-                            **data,
-                            reduction="mean",
-                        )
-
-                    for key, value in regularizer_dict.items():
-                        regularizer_dict[key] = value + regularizer_fns[key](
-                            output_image_iter,
-                            **data,
-                        )
+                    output_image_iter = T.modulus(
+                        output_image_iter.permute(0, 2, 3, 1)
+                    )  # shape (batch, height, width)
+
+                    loss_dict = self.compute_loss_on_data(loss_dict, loss_fns, data, output_image_iter)
+                    regularizer_dict = self.compute_loss_on_data(
+                        regularizer_dict, regularizer_fns, data, output_image_iter
+                    )
 
                 loss_dict = {k: v / len(reconstruction_iter) for k, v in loss_dict.items()}
                 regularizer_dict = {k: v / len(reconstruction_iter) for k, v in regularizer_dict.items()}
 
@@ -124,21 +123,17 @@ def _do_iteration(
 
                 if self.model.training:
                     # TODO(gy): With steps >= 1, calling .backward(retain_grad=False) caused problems.
-                    # Check with Jonas if it's ok.
-
                     if (self.cfg.model.steps > 1) and (_ < self.cfg.model.steps - 1):  # type: ignore
                         self._scaler.scale(loss).backward(retain_graph=True)
                     else:
                         self._scaler.scale(loss).backward()
 
                 # Detach hidden state from computation graph, to ensure loss is only computed per RIM block.
-                hidden_state = hidden_state.detach()  # shape: (batch, num_hidden_channels, [slice,] height, width, depth)
-                input_image = output_image.detach()  # shape (batch, complex=2, [slice,] height, width)
+                hidden_state = hidden_state.detach()  # shape: (batch, num_hidden_channels, height, width, depth)
+                input_image = output_image.detach()  # shape (batch, complex[=2], height, width)
 
                 loss_dicts.append(detach_dict(loss_dict))
-                regularizer_dicts.append(
-                    detach_dict(regularizer_dict)
-                )  # Need to detach dict as this is only used for logging.
+                regularizer_dicts.append(detach_dict(regularizer_dict))  # Detach, only used for logging.
 
         # Add the loss dicts together over RIM steps, divide by the number of steps.
loss_dict = reduce_list_of_dicts(loss_dicts, mode="sum", divisor=self.cfg.model.steps) # type: ignore diff --git a/direct/nn/unet/config.py b/direct/nn/unet/config.py index aa3e46d8..19ccdc95 100644 --- a/direct/nn/unet/config.py +++ b/direct/nn/unet/config.py @@ -14,6 +14,15 @@ class UnetModel2dConfig(ModelConfig): dropout_probability: float = 0.0 +class NormUnetModel2dConfig(ModelConfig): + in_channels: int = 2 + out_channels: int = 2 + num_filters: int = 16 + num_pool_layers: int = 4 + dropout_probability: float = 0.0 + norm_groups: int = 2 + + @dataclass class Unet2dConfig(ModelConfig): num_filters: int = 16 diff --git a/direct/nn/unet/unet_2d.py b/direct/nn/unet/unet_2d.py index 566f9344..51c5c97e 100644 --- a/direct/nn/unet/unet_2d.py +++ b/direct/nn/unet/unet_2d.py @@ -446,7 +446,7 @@ def forward( sensitivity_map=sensitivity_map, ) elif self.image_initialization == "zero_filled": - input_image = self.backward_operator(masked_kspace).sum(self._coil_dim) + input_image = self.backward_operator(masked_kspace, dim=self._spatial_dims).sum(self._coil_dim) else: raise ValueError( f"Unknown image_initialization. Expected `sense` or `zero_filled`. " diff --git a/direct/nn/unet/unet_engine.py b/direct/nn/unet/unet_engine.py index e1890846..96022c5f 100644 --- a/direct/nn/unet/unet_engine.py +++ b/direct/nn/unet/unet_engine.py @@ -1,17 +1,14 @@ # coding=utf-8 # Copyright (c) DIRECT Contributors -from typing import Callable, Dict, Optional +from typing import Any, Callable, Dict, Optional, Tuple import torch from torch import nn -from torch.cuda.amp import autocast import direct.data.transforms as T from direct.config import BaseConfig -from direct.engine import DoIterationOutput from direct.nn.mri_models import MRIModelEngine -from direct.utils import detach_dict, dict_to_device, reduce_list_of_dicts class Unet2dEngine(MRIModelEngine): @@ -38,76 +35,16 @@ def __init__( **models, ) - self._spatial_dims = (2, 3) + def forward_function(self, data: Dict[str, Any]) -> Tuple[torch.Tensor, None]: - def _do_iteration( - self, - data: Dict[str, torch.Tensor], - loss_fns: Optional[Dict[str, Callable]] = None, - regularizer_fns: Optional[Dict[str, Callable]] = None, - ) -> DoIterationOutput: - - # loss_fns can be done, e.g. 
during validation - if loss_fns is None: - loss_fns = {} - - if regularizer_fns is None: - regularizer_fns = {} - - loss_dicts = [] - regularizer_dicts = [] - - data = dict_to_device(data, self.device) - - if self.cfg.model.image_initialization == "sense": # type: ignore - # sensitivity_map of shape (batch, coil, height, width, complex=2) - sensitivity_map = data["sensitivity_map"].clone() - data["sensitivity_map"] = self.compute_sensitivity_map(sensitivity_map) - - with autocast(enabled=self.mixed_precision): - - output_image = self.model( - masked_kspace=data["masked_kspace"], - sensitivity_map=data["sensitivity_map"] - if self.cfg.model.image_initialization == "sense" # type: ignore - else None, - ) - output_image = T.modulus(output_image) - - loss_dict = {k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in loss_fns.keys()} - regularizer_dict = { - k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in regularizer_fns.keys() - } - - for key, value in loss_dict.items(): - loss_dict[key] = value + loss_fns[key]( - output_image, - **data, - reduction="mean", - ) - - for key, value in regularizer_dict.items(): - regularizer_dict[key] = value + regularizer_fns[key]( - output_image, - **data, - ) - - loss = sum(loss_dict.values()) + sum(regularizer_dict.values()) # type: ignore - - if self.model.training: - self._scaler.scale(loss).backward() - - loss_dicts.append(detach_dict(loss_dict)) - regularizer_dicts.append( - detach_dict(regularizer_dict) - ) # Need to detach dict as this is only used for logging. + output_image = self.model( + masked_kspace=data["masked_kspace"], + sensitivity_map=data["sensitivity_map"] + if self.cfg.model.image_initialization == "sense" # type: ignore + else None, + ) + output_image = T.modulus(output_image) - # Add the loss dicts. - loss_dict = reduce_list_of_dicts(loss_dicts, mode="sum") - regularizer_dict = reduce_list_of_dicts(regularizer_dicts, mode="sum") + output_kspace = None - return DoIterationOutput( - output_image=output_image, - sensitivity_map=data["sensitivity_map"], - data_dict={**loss_dict, **regularizer_dict}, - ) + return output_image, output_kspace diff --git a/direct/nn/varnet/varnet_engine.py b/direct/nn/varnet/varnet_engine.py index 75f96b61..3e958647 100644 --- a/direct/nn/varnet/varnet_engine.py +++ b/direct/nn/varnet/varnet_engine.py @@ -1,17 +1,14 @@ # coding=utf-8 # Copyright (c) DIRECT Contributors -from typing import Callable, Dict, Optional +from typing import Any, Callable, Dict, Optional, Tuple import torch from torch import nn -from torch.cuda.amp import autocast import direct.data.transforms as T from direct.config import BaseConfig -from direct.engine import DoIterationOutput from direct.nn.mri_models import MRIModelEngine -from direct.utils import detach_dict, dict_to_device, reduce_list_of_dicts class EndToEndVarNetEngine(MRIModelEngine): @@ -38,77 +35,16 @@ def __init__( **models, ) - self._spatial_dims = (2, 3) + def forward_function(self, data: Dict[str, Any]) -> Tuple[torch.Tensor, torch.Tensor]: - def _do_iteration( - self, - data: Dict[str, torch.Tensor], - loss_fns: Optional[Dict[str, Callable]] = None, - regularizer_fns: Optional[Dict[str, Callable]] = None, - ) -> DoIterationOutput: - - # loss_fns can be done, e.g. 
during validation - if loss_fns is None: - loss_fns = {} - - if regularizer_fns is None: - regularizer_fns = {} - - loss_dicts = [] - regularizer_dicts = [] - - data = dict_to_device(data, self.device) - - # sensitivity_map of shape (batch, coil, height, width, complex=2) - sensitivity_map = data["sensitivity_map"].clone() - data["sensitivity_map"] = self.compute_sensitivity_map(sensitivity_map) - - with autocast(enabled=self.mixed_precision): - - output_kspace = self.model( - masked_kspace=data["masked_kspace"], - sampling_mask=data["sampling_mask"], - sensitivity_map=data["sensitivity_map"], - ) - - output_image = T.root_sum_of_squares( - self.backward_operator(output_kspace, dim=self._spatial_dims), dim=self._coil_dim # type: ignore - ) # shape (batch, height, width) - - loss_dict = {k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in loss_fns.keys()} - regularizer_dict = { - k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in regularizer_fns.keys() - } - - for key, value in loss_dict.items(): - loss_dict[key] = value + loss_fns[key]( - output_image, - **data, - reduction="mean", - ) - - for key, value in regularizer_dict.items(): - regularizer_dict[key] = value + regularizer_fns[key]( - output_image, - **data, - ) - - loss = sum(loss_dict.values()) + sum(regularizer_dict.values()) # type: ignore - - if self.model.training: - self._scaler.scale(loss).backward() - - loss_dicts.append(detach_dict(loss_dict)) - regularizer_dicts.append( - detach_dict(regularizer_dict) - ) # Need to detach dict as this is only used for logging. - - # Add the loss dicts. - loss_dict = reduce_list_of_dicts(loss_dicts, mode="sum") - regularizer_dict = reduce_list_of_dicts(regularizer_dicts, mode="sum") - - return DoIterationOutput( - output_image=output_image, + output_kspace = self.model( + masked_kspace=data["masked_kspace"], + sampling_mask=data["sampling_mask"], sensitivity_map=data["sensitivity_map"], - data_dict={**loss_dict, **regularizer_dict}, ) + output_image = T.root_sum_of_squares( + self.backward_operator(output_kspace, dim=self._spatial_dims), + dim=self._coil_dim, + ) # shape (batch, height, width) + + return output_image, output_kspace diff --git a/direct/nn/xpdnet/xpdnet_engine.py b/direct/nn/xpdnet/xpdnet_engine.py index 5b555c75..f76c1af4 100644 --- a/direct/nn/xpdnet/xpdnet_engine.py +++ b/direct/nn/xpdnet/xpdnet_engine.py @@ -1,17 +1,13 @@ # coding=utf-8 # Copyright (c) DIRECT Contributors -from typing import Callable, Dict, Optional +from typing import Any, Callable, Dict, Optional, Tuple import torch from torch import nn -from torch.cuda.amp import autocast -import direct.data.transforms as T from direct.config import BaseConfig -from direct.engine import DoIterationOutput from direct.nn.mri_models import MRIModelEngine -from direct.utils import detach_dict, dict_to_device, reduce_list_of_dicts class XPDNetEngine(MRIModelEngine): @@ -38,76 +34,15 @@ def __init__( **models, ) - self._spatial_dims = (2, 3) + def forward_function(self, data: Dict[str, Any]) -> Tuple[torch.Tensor, None]: - def _do_iteration( - self, - data: Dict[str, torch.Tensor], - loss_fns: Optional[Dict[str, Callable]] = None, - regularizer_fns: Optional[Dict[str, Callable]] = None, - ) -> DoIterationOutput: - - # loss_fns can be done, e.g. 
during validation - if loss_fns is None: - loss_fns = {} - - if regularizer_fns is None: - regularizer_fns = {} - - loss_dicts = [] - regularizer_dicts = [] - - data = dict_to_device(data, self.device) - - # sensitivity_map of shape (batch, coil, height, width, complex=2) - sensitivity_map = data["sensitivity_map"].clone() - data["sensitivity_map"] = self.compute_sensitivity_map(sensitivity_map) - - with autocast(enabled=self.mixed_precision): - - output_image = self.model( - masked_kspace=data["masked_kspace"], - sampling_mask=data["sampling_mask"], - sensitivity_map=data["sensitivity_map"], - scaling_factor=data["scaling_factor"], - ) # shape (batch, height, width, complex=2) - - output_image = T.modulus(output_image) # shape (batch, height, width) - - loss_dict = {k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in loss_fns.keys()} - regularizer_dict = { - k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in regularizer_fns.keys() - } - - for key, value in loss_dict.items(): - loss_dict[key] = value + loss_fns[key]( - output_image, - **data, - reduction="mean", - ) - - for key, value in regularizer_dict.items(): - regularizer_dict[key] = value + regularizer_fns[key]( - output_image, - **data, - ) - - loss = sum(loss_dict.values()) + sum(regularizer_dict.values()) # type: ignore - - if self.model.training: - self._scaler.scale(loss).backward() - - loss_dicts.append(detach_dict(loss_dict)) - regularizer_dicts.append( - detach_dict(regularizer_dict) - ) # Need to detach dict as this is only used for logging. + output_image = self.model( + masked_kspace=data["masked_kspace"], + sampling_mask=data["sampling_mask"], + sensitivity_map=data["sensitivity_map"], + scaling_factor=data["scaling_factor"], + ) # shape (batch, height, width, complex[=2]) - # Add the loss dicts. 
- loss_dict = reduce_list_of_dicts(loss_dicts, mode="sum") - regularizer_dict = reduce_list_of_dicts(regularizer_dicts, mode="sum") + output_kspace = None - return DoIterationOutput( - output_image=output_image, - sensitivity_map=data["sensitivity_map"], - data_dict={**loss_dict, **regularizer_dict}, - ) + return output_image, output_kspace diff --git a/direct/types.py b/direct/types.py index 4faa1504..03156446 100644 --- a/direct/types.py +++ b/direct/types.py @@ -11,3 +11,4 @@ PathOrString = Union[pathlib.Path, str] FileOrUrl = NewType("FileOrUrl", PathOrString) HasStateDict = Union[nn.Module, torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler, GradScaler] +TensorOrNone = Union[None, torch.Tensor] diff --git a/setup.py b/setup.py index 507ff173..d36ab996 100644 --- a/setup.py +++ b/setup.py @@ -57,7 +57,7 @@ def finalize_options(self): "numpy>=1.21.2", "h5py==3.3.0", "omegaconf==2.1.1", - "torch==1.11.0", + "torch>=1.10.2", "torchvision", "scikit-image>=0.19.0", "scikit-learn>=1.0.1", diff --git a/tests/tests_common/test_subsample.py b/tests/tests_common/test_subsample.py index 8fae1036..316619f3 100644 --- a/tests/tests_common/test_subsample.py +++ b/tests/tests_common/test_subsample.py @@ -191,7 +191,7 @@ def test_same_across_volumes_mask_spiral(shape, accelerations): @pytest.mark.parametrize( - "shape, accelerations, center_scales", + "shape, accelerations, center_fractions", [ ([4, 32, 32, 2], [4], [0.08]), ([2, 64, 64, 2], [8, 4], [0.04, 0.08]), @@ -209,10 +209,10 @@ def test_same_across_volumes_mask_spiral(shape, accelerations): tuple(np.random.randint(100000, 1000000, 30)), ], ) -def test_apply_mask_poisson(shape, accelerations, center_scales, seed): +def test_apply_mask_poisson(shape, accelerations, center_fractions, seed): mask_func = VariableDensityPoissonMaskFunc( accelerations=accelerations, - center_scales=center_scales, + center_fractions=center_fractions, ) mask = mask_func(shape[1:], seed=seed) acs_mask = mask_func(shape[1:], seed=seed, return_acs=True) @@ -225,16 +225,16 @@ def test_apply_mask_poisson(shape, accelerations, center_scales, seed): @pytest.mark.parametrize( - "shape, accelerations, center_scales", + "shape, accelerations, center_fractions", [ ([4, 32, 32, 2], [4], [0.08]), ([2, 64, 64, 2], [8, 4], [0.04, 0.08]), ], ) -def test_same_across_volumes_mask_spiral(shape, accelerations, center_scales): +def test_same_across_volumes_mask_spiral(shape, accelerations, center_fractions): mask_func = VariableDensityPoissonMaskFunc( accelerations=accelerations, - center_scales=center_scales, + center_fractions=center_fractions, ) num_slices = shape[0] masks = [mask_func(shape[1:], seed=123) for _ in range(num_slices)] diff --git a/tests/tests_data/test_mri_transforms.py b/tests/tests_data/test_mri_transforms.py index ecefdcdf..06925fe8 100644 --- a/tests/tests_data/test_mri_transforms.py +++ b/tests/tests_data/test_mri_transforms.py @@ -11,8 +11,11 @@ from direct.data.mri_transforms import ( ApplyMask, + ApplyZeroPadding, Compose, ComputeImage, + ComputeScalingFactor, + ComputeZeroPadding, CreateSamplingMask, CropKspace, DeleteKeys, @@ -20,11 +23,13 @@ EstimateSensitivityMap, Normalize, PadCoilDimension, + ReconstructionType, ToTensor, WhitenData, build_mri_transforms, ) from direct.data.transforms import fft2, ifft2 +from direct.exceptions import ItemNotFoundException def create_sample(shape, **kwargs): @@ -77,6 +82,44 @@ def test_Compose(shape): assert torch.allclose(compose_out, kspace) +@pytest.mark.parametrize( + "shape", + [(5, 7, 6), (3, 4, 6, 4)], +) 
+def test_ComputeZeroPadding(shape): + + sample = create_sample(shape + (2,)) + + pad_shape = [1 for _ in range(len(sample["kspace"].shape))] + pad_shape[1:-1] = sample["kspace"].shape[1:-1] + padding = torch.from_numpy(np.random.randn(*pad_shape)).round().bool() + sample["kspace"] = (~padding) * sample["kspace"] + + transform = ComputeZeroPadding() + sample = transform(sample) + + assert torch.allclose(sample["padding"], padding) + + +@pytest.mark.parametrize( + "shape", + [(5, 7, 6), (3, 4, 6, 4)], +) +def test_ApplyZeroPadding(shape): + + sample = create_sample(shape + (2,)) + pad_shape = [1 for _ in range(len(sample["kspace"].shape))] + pad_shape[1:-1] = sample["kspace"].shape[1:-1] + padding = torch.from_numpy(np.random.randn(*pad_shape)).round().bool() + sample.update({"padding": padding}) + + kspace = sample["kspace"] + transform = ApplyZeroPadding() + sample = transform(sample) + + assert torch.allclose(sample["kspace"], (~padding) * kspace) + + @pytest.mark.parametrize( "shape", [(1, 4, 6), (5, 7, 6), (2, None, None), (3, 4, 6, 4)], @@ -87,7 +130,7 @@ def test_Compose(shape): ) @pytest.mark.parametrize( "padding", - [None, [2, 2]], + [None, True], ) @pytest.mark.parametrize( "use_shape", @@ -97,17 +140,15 @@ def test_CreateSamplingMask(shape, return_acs, padding, use_shape): sample = create_sample(shape + (2,)) if padding: - sample.update({"padding_right": padding[0], "padding_left": padding[1]}) + pad_shape = [1 for _ in range(len(sample["kspace"].shape))] + pad_shape[1:-1] = sample["kspace"].shape[1:-1] + sample.update({"padding": torch.from_numpy(np.random.randn(*pad_shape))}) transform = CreateSamplingMask(mask_func=_mask_func, shape=shape[1:] if use_shape else None, return_acs=return_acs) - if padding and len(shape) > 3: - with pytest.raises(ValueError): - sample = transform(sample) - else: - sample = transform(sample) - assert "sampling_mask" in sample - assert tuple(sample["sampling_mask"].shape) == (1,) + sample["kspace"].shape[1:-1] + (1,) - if return_acs: - assert "acs_mask" in sample + sample = transform(sample) + assert "sampling_mask" in sample + assert tuple(sample["sampling_mask"].shape) == (1,) + sample["kspace"].shape[1:-1] + (1,) + if return_acs: + assert "acs_mask" in sample @pytest.mark.parametrize( @@ -118,7 +159,7 @@ def test_ApplyMask(shape): sample = create_sample(shape=shape + (2,)) transform = ApplyMask() # Check error raise when sampling mask not present in sample - with pytest.raises(AssertionError): + with pytest.raises(ValueError): sample = transform(sample) sample.update({"sampling_mask": torch.rand(shape[1:]).round().unsqueeze(0).unsqueeze(-1)}) sample = transform(sample) @@ -203,28 +244,25 @@ def test_CropKspace( ], ) @pytest.mark.parametrize( - "type_recon, complex_output, expect_error", + "type_recon, complex_output", [ - ["complex", True, False], - ["sense", True, False], - ["rss", False, False], - ["invalid", None, True], + [ReconstructionType.complex, True], + [ReconstructionType.complex_mod, False], + [ReconstructionType.sense, True], + [ReconstructionType.sense_mod, False], + [ReconstructionType.rss, False], ], ) -def test_ComputeImage(shape, spatial_dims, type_recon, complex_output, expect_error): +def test_ComputeImage(shape, spatial_dims, type_recon, complex_output): sample = create_sample(shape=shape + (2,)) - if expect_error: - with pytest.raises(ValueError): - transform = ComputeImage("kspace", "target", ifft2, type_reconstruction=type_recon) - else: - transform = ComputeImage("kspace", "target", ifft2, type_reconstruction=type_recon) 
-        if type_recon == "sense":
-            with pytest.raises(ValueError):
-                sample = transform(sample, coil_dim=0, spatial_dims=spatial_dims)
-            sample.update({"sensitivity_map": torch.rand(shape + (2,))})
-        sample = transform(sample, coil_dim=0, spatial_dims=spatial_dims)
-        assert "target" in sample
-        assert sample["target"].shape == (shape[1:] + (2,) if complex_output else shape[1:])
+    transform = ComputeImage("kspace", "target", ifft2, type_reconstruction=type_recon)
+    if type_recon in ["sense", "sense_mod"]:
+        with pytest.raises(ItemNotFoundException):
+            sample = transform(sample, coil_dim=0, spatial_dims=spatial_dims)
+        sample.update({"sensitivity_map": torch.rand(shape + (2,))})
+    sample = transform(sample, coil_dim=0, spatial_dims=spatial_dims)
+    assert "target" in sample
+    assert sample["target"].shape == (shape[1:] + (2,) if complex_output else shape[1:])
 
 
 @pytest.mark.parametrize(
@@ -358,7 +396,11 @@ def test_PadCoilDimension(shape, pad_coils, key):
     "percentile",
     [None, 0.9],
 )
-def test_Normalize(shape, normalize_key, percentile):
+@pytest.mark.parametrize(
+    "norm_keys",
+    [None, ["kspace", "masked_kspace", "target"]],
+)
+def test_Normalize(shape, normalize_key, percentile, norm_keys):
     sample = create_sample(
         shape=shape + (2,),
         masked_kspace=torch.rand(shape + (2,)),
@@ -366,7 +408,12 @@ def test_Normalize(shape, normalize_key, percentile):
         sampling_mask=torch.rand(shape[1:]).round().unsqueeze(0).unsqueeze(-1),
         scaling_factor=torch.rand(1),
     )
-    transform = Normalize(normalize_key, percentile)
+    transform = Compose(
+        [
+            ComputeScalingFactor(normalize_key, percentile, "scaling_factor"),
+            Normalize("scaling_factor", keys_to_normalize=norm_keys),
+        ]
+    )
     sample = transform(sample)
 
     assert "scaling_diff" in sample
diff --git a/tests/tests_data/test_transforms.py b/tests/tests_data/test_transforms.py
index fb007432..e4a4696c 100644
--- a/tests/tests_data/test_transforms.py
+++ b/tests/tests_data/test_transforms.py
@@ -481,3 +481,18 @@ def test_complex_center_crop(shape, crop_shape, contiguous):
     assert all(data.shape == tuple([data.shape[0]] + crop_shape + [2]) for data in data_list)
     if contiguous:
         assert all(data.is_contiguous() for data in data_list)
+
+
+@pytest.mark.parametrize(
+    "shape",
+    [
+        [5, 10, 20, 22],
+        [1, 10, 20, 22],
+    ],
+)
+def test_apply_padding(shape):
+    data = create_input(shape + [2])
+    padding = torch.from_numpy(np.random.randn(shape[0], 1, shape[-2], shape[-1], 1)).round().bool()
+    padded_data = transforms.apply_padding(data, padding)
+
+    assert torch.allclose(data * (~padding), padded_data)
diff --git a/tests/tests_functionals/test_gradloss.py b/tests/tests_functionals/test_gradloss.py
new file mode 100644
index 00000000..4b465c10
--- /dev/null
+++ b/tests/tests_functionals/test_gradloss.py
@@ -0,0 +1,36 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
+
+import numpy as np
+import pytest
+import torch
+from skimage.color import rgb2gray
+from sklearn.datasets import load_sample_image
+
+from direct.functionals import SobelGradL1Loss, SobelGradL2Loss
+
+# Load two images and convert them to grayscale
+flower = rgb2gray(load_sample_image("flower.jpg"))[None].astype(np.float32)
+china = rgb2gray(load_sample_image("china.jpg"))[None].astype(np.float32)
+
+
+@pytest.mark.parametrize("image", [flower, china])
+def test_grad_loss(image):
+    image_batch = []
+    image_noise_batch = []
+
+    for sigma in range(1, 5):
+        noise = sigma * np.random.rand(*image.shape)
+        image_noise = (image + noise).astype(np.float32).clip(0, 255)
+
+        image_batch.append(image)
+        image_noise_batch.append(image_noise)
+
+    image_batch_torch = torch.tensor(image_batch)
+    image_noise_batch_torch = torch.tensor(image_noise_batch)
+
+    grad_loss_l1 = SobelGradL1Loss().forward(image_batch_torch, image_noise_batch_torch)
+    grad_loss_l2 = SobelGradL2Loss().forward(image_batch_torch, image_noise_batch_torch)
+
+    assert grad_loss_l1 >= 0
+    assert grad_loss_l2 >= 0
diff --git a/tests/tests_functionals/test_nmae.py b/tests/tests_functionals/test_nmae.py
new file mode 100644
index 00000000..7c5a12f7
--- /dev/null
+++ b/tests/tests_functionals/test_nmae.py
@@ -0,0 +1,38 @@
+# coding=utf-8
+# Copyright (c) DIRECT Contributors
+
+import numpy as np
+import pytest
+import torch
+from skimage.color import rgb2gray
+from sklearn.datasets import load_sample_image
+
+from direct.functionals.nmae import NMAELoss
+
+# Load two images and convert them to grayscale
+flower = rgb2gray(load_sample_image("flower.jpg"))[None].astype(np.float32)
+china = rgb2gray(load_sample_image("china.jpg"))[None].astype(np.float32)
+
+
+@pytest.mark.parametrize("image", [flower, china])
+def test_nmae(image):
+    image_batch = []
+    image_noise_batch = []
+
+    for sigma in range(1, 5):
+        noise = sigma * np.random.rand(*image.shape)
+        image_noise = (image + noise).astype(np.float32).clip(0, 255)
+
+        image_batch.append(image)
+        image_noise_batch.append(image_noise)
+
+    image_batch = np.stack(image_batch)
+    image_noise_batch = np.stack(image_noise_batch)
+
+    np_nmae = np.abs(image_batch - image_noise_batch).mean() / np.abs(image_batch).mean()
+
+    image_batch_torch = torch.tensor(image_batch)
+    image_noise_batch_torch = torch.tensor(image_noise_batch)
+
+    nmae_batch = NMAELoss().forward(image_noise_batch_torch, image_batch_torch)
+    assert np.allclose(nmae_batch, np_nmae, atol=5e-4)
diff --git a/tests/tests_functionals/test_nmse.py b/tests/tests_functionals/test_nmse.py
index 8ee18c90..6422d772 100644
--- a/tests/tests_functionals/test_nmse.py
+++ b/tests/tests_functionals/test_nmse.py
@@ -9,6 +9,7 @@
 from sklearn.datasets import load_sample_image
 
 from direct.functionals.challenges import fastmri_nmse
+from direct.functionals.nmse import NMSELoss, NRMSELoss
 
 # Load two images and convert them to grayscale
 flower = rgb2gray(load_sample_image("flower.jpg"))[None].astype(np.float32)
@@ -16,7 +17,7 @@
 
 
 @pytest.mark.parametrize("image", [flower, china])
-def test_fastmri_nmse(image):
+def test_nmse(image):
     image_batch = []
     image_noise_batch = []
     single_image_nmse = []
@@ -39,7 +40,12 @@ def test_nmse(image):
     image_batch_torch = torch.tensor(image_batch)
     image_noise_batch_torch = torch.tensor(image_noise_batch)
-
+    # Test fastmri_nmse
     fastmri_nmse_batch = fastmri_nmse(image_batch_torch, image_noise_batch_torch)
-
     assert np.allclose(fastmri_nmse_batch, skimage_nmse, atol=5e-4)
+    # Test NMSE loss
+    nmse_batch = NMSELoss().forward(image_noise_batch_torch, image_batch_torch)
+    assert np.allclose(nmse_batch, skimage_nmse, atol=5e-4)
+    # Test NRMSE loss
+    nrmse_batch = NRMSELoss().forward(image_noise_batch_torch, image_batch_torch)
+    assert np.allclose(nrmse_batch**2, skimage_nmse, atol=5e-4)
diff --git a/tests/tests_nn/test_jointicnet_engine.py b/tests/tests_nn/test_jointicnet_engine.py
index d97e3b3b..a4850669 100644
--- a/tests/tests_nn/test_jointicnet_engine.py
+++ b/tests/tests_nn/test_jointicnet_engine.py
@@ -33,7 +33,7 @@ def create_sample(shape, **kwargs):
 
 
 @pytest.mark.parametrize("shape", [(4, 3, 10, 16, 2), (5, 1, 10, 12, 2)])
-@pytest.mark.parametrize("loss_fns", [["l1_loss", "ssim_loss", "l2_loss"]])
+@pytest.mark.parametrize("loss_fns", [["l1_loss",
"ssim_loss", "l2_loss", "nrmse_loss", "nmae_loss", "grad_l1_loss"]]) @pytest.mark.parametrize("num_iter", [2, 3]) def test_jointicnet_engine( shape, diff --git a/tests/tests_nn/test_recurrentvarnet_engine.py b/tests/tests_nn/test_recurrentvarnet_engine.py index a94fcdb7..7ecc9226 100644 --- a/tests/tests_nn/test_recurrentvarnet_engine.py +++ b/tests/tests_nn/test_recurrentvarnet_engine.py @@ -15,10 +15,11 @@ def create_sample(shape, **kwargs): sample = dict() - sample["masked_kspace"] = torch.from_numpy(np.random.randn(*shape)).float() + sample["kspace"] = torch.from_numpy(np.random.randn(*shape)).float() sample["sensitivity_map"] = torch.from_numpy(np.random.randn(*shape)).float() - sample["sampling_mask"] = torch.from_numpy(np.random.randn(1, shape[1], shape[2], 1)).float() - sample["target"] = torch.from_numpy(np.random.randn(shape[0], shape[1], shape[2])).float() + sample["sampling_mask"] = torch.from_numpy(np.random.randn(shape[0], 1, shape[2], shape[3], 1)).float() + sample["masked_kspace"] = sample["kspace"] * sample["sampling_mask"] + sample["target"] = torch.from_numpy(np.random.randn(shape[0], shape[2], shape[3])).float() sample["scaling_factor"] = torch.tensor([1.0]) for k, v in locals()["kwargs"].items(): sample[k] = v @@ -31,7 +32,21 @@ def create_sample(shape, **kwargs): ) @pytest.mark.parametrize( "loss_fns", - [["l1_loss", "ssim_loss", "l2_loss"]], + [ + [ + "l1_loss", + "ssim_loss", + "l2_loss", + "grad_l1_loss", + "grad_l2_loss", + "nrmse_loss", + "nmae_loss", + "nmse_loss", + "kspace_l1_loss", + "kspace_nrmse_loss", + ], + ["ssim_loss", "non_permissible_loss"], + ], ) @pytest.mark.parametrize( "num_steps", @@ -61,10 +76,13 @@ def test_recurrentvarnet_engine(shape, loss_fns, num_steps): # Test _do_iteration function with a single data batch data = create_sample( shape, - sampling_mask=torch.from_numpy(np.random.randn(1, 1, shape[2], shape[3], 1)).float(), target=torch.from_numpy(np.random.randn(shape[0], shape[2], shape[3])).float(), scaling_factor=torch.ones(shape[0]), ) - loss_fns = engine.build_loss() - out = engine._do_iteration(data, loss_fns) - assert out.output_image.shape == (shape[0],) + tuple(shape[2:-1]) + if "non_permissible_loss" in loss_fns: + with pytest.raises(ValueError): + loss_fns = engine.build_loss() + else: + loss_fns = engine.build_loss() + out = engine._do_iteration(data, loss_fns) + assert out.output_image.shape == (shape[0],) + tuple(shape[2:-1])