diff --git a/direct/common/subsample.py b/direct/common/subsample.py index b25496cc..91b90acb 100644 --- a/direct/common/subsample.py +++ b/direct/common/subsample.py @@ -7,19 +7,21 @@ import contextlib import logging -import pathlib from abc import abstractmethod +from enum import Enum from typing import List, Optional, Tuple import numpy as np import torch +import direct.data.transforms as T from direct.environment import DIRECT_CACHE_DIR from direct.types import Number from direct.utils import str_to_class from direct.utils.io import download_url logger = logging.getLogger(__name__) +GOLDEN_RATIO = (1 + np.sqrt(5)) / 2 @contextlib.contextmanager @@ -104,7 +106,7 @@ def __call__(self, data, seed=None, return_acs=False): ndarray """ self.rng.seed(seed) - mask = self.mask_func(data, return_acs=return_acs) # pylint: disable = E1123 + mask = self.mask_func(data, seed=seed, return_acs=return_acs) # pylint: disable = E1123 return mask @@ -145,7 +147,7 @@ def mask_func(self, shape, return_acs=False, seed=None): Parameters ---------- - data : iterable[int] + shape : iterable[int] The shape of the mask to be created. The shape should at least 3 dimensions. Samples are drawn along the second last dimension. seed : int (optional) @@ -230,7 +232,7 @@ def mask_func(self, shape, return_acs=False, seed=None): Parameters ---------- - data : iterable[int] + shape : iterable[int] The shape of the mask to be created. The shape should at least 3 dimensions. Samples are drawn along the second last dimension. seed : int (optional) @@ -314,7 +316,28 @@ def circular_centered_mask(shape, radius): mask = ((dist_from_center <= radius) * np.ones(shape)).astype(bool) return mask[np.newaxis, ..., np.newaxis] - def mask_func(self, shape, return_acs=False): + def mask_func(self, shape, return_acs=False, seed=None): + """ + Downloads and loads pre=computed Poisson masks. + Currently supports shapes of 218 x 170/ 218/ 174 and acceleration factors of 5 or 10. + + Parameters + ---------- + + shape : iterable[int] + The shape of the mask to be created. The shape should at least 3 dimensions. + Samples are drawn along the second last dimension. + seed : int (optional) + Seed for the random number generator. Setting the seed ensures the same mask is generated + each time for the same shape. + return_acs : bool + Return the autocalibration signal region as a mask. + + Returns + ------- + torch.Tensor : the sampling mask + + """ shape = tuple(shape)[:-1] if return_acs: return torch.from_numpy(self.circular_centered_mask(shape, 18)) @@ -322,12 +345,14 @@ def mask_func(self, shape, return_acs=False): if shape not in self.shapes: raise ValueError(f"No mask of shape {shape} is available in the CalgaryCampinas dataset.") - acceleration = self.choose_acceleration() - masks = self.masks[acceleration] + with temp_seed(self.rng, seed): + acceleration = self.choose_acceleration() + masks = self.masks[acceleration] + + mask, num_masks = masks[shape] + # Randomly pick one example + choice = self.rng.randint(0, num_masks) - mask, num_masks = masks[shape] - # Randomly pick one example - choice = self.rng.randint(0, num_masks) return torch.from_numpy(mask[choice][np.newaxis, ..., np.newaxis]) def __load_masks(self, acceleration): @@ -352,6 +377,238 @@ def __load_masks(self, acceleration): return output +class CIRCUSSamplingMode(str, Enum): + + circus_radial = "circus-radial" + circus_spiral = "circus-spiral" + + +class CIRCUSMaskFunc(BaseMaskFunc): + """ + Implementation of Cartesian undersampling (radial or spiral) using CIRCUS as shown in [1]_. It creates + radial or spiral masks for Cartesian acquired data on a grid. + + References + ---------- + + .. [1] Liu J, Saloner D. Accelerated MRI with CIRcular Cartesian UnderSampling (CIRCUS): + a variable density Cartesian sampling strategy for compressed sensing and parallel imaging. + Quant Imaging Med Surg. 2014 Feb;4(1):57-67. doi: 10.3978/j.issn.2223-4292.2014.02.01. + PMID: 24649436; PMCID: PMC3947985. + + """ + + def __init__( + self, + accelerations, + subsampling_scheme: CIRCUSSamplingMode, + **kwargs, + ): + super().__init__( + accelerations=accelerations, + center_fractions=tuple(0 for _ in range(len(accelerations))), + uniform_range=False, + ) + if subsampling_scheme not in ["circus-spiral", "circus-radial"]: + raise NotImplementedError( + f"Currently CIRCUSMaskFunc is only implemented for 'circus-radial' or 'circus-spiral' " + f"as a subsampling_scheme. Got subsampling_scheme={subsampling_scheme}." + ) + + self.subsampling_scheme = "circus-radial" if subsampling_scheme is None else subsampling_scheme + + @staticmethod + def get_square_ordered_idxs(square_side_size: int, square_id: int) -> Tuple[Tuple, ...]: + """ + Returns ordered (clockwise) indices of a sub-square of a square matrix. + + Parameters: + ----------- + square_side_size: int + Square side size. Dim of array. + square_id: int + Number of sub-square. Can be 0, ..., square_side_size // 2. + + Returns: + -------- + ordered_idxs: List of tuples. + Indices of each point that belongs to the square_id-th sub-square + starting from top-left point clockwise. + + """ + assert square_id in range(square_side_size // 2) + + ordered_idxs = list() + + for col in range(square_id, square_side_size - square_id): + ordered_idxs.append((square_id, col)) + + for row in range(square_id + 1, square_side_size - (square_id + 1)): + ordered_idxs.append((row, square_side_size - (square_id + 1))) + + for col in range(square_side_size - (square_id + 1), square_id, -1): + ordered_idxs.append((square_side_size - (square_id + 1), col)) + + for row in range(square_side_size - (square_id + 1), square_id, -1): + ordered_idxs.append((row, square_id)) + + return tuple(ordered_idxs) + + def circus_radial_mask(self, shape, acceleration): + """ + Implements CIRCUS radial undersampling. + """ + max_dim = max(shape) - max(shape) % 2 + min_dim = min(shape) - min(shape) % 2 + num_nested_squares = max_dim // 2 + M = int(np.prod(shape) / (acceleration * (max_dim / 2 - (max_dim - min_dim) * (1 + min_dim / max_dim) / 4))) + + mask = np.zeros((max_dim, max_dim), dtype=np.float32) + + t = self.rng.randint(low=0, high=1e4, size=1, dtype=int).item() + + for square_id in range(num_nested_squares): + ordered_indices = self.get_square_ordered_idxs( + square_side_size=max_dim, + square_id=square_id, + ) + # J: size of the square, J=2,…,N, i.e., the number of points along one side of the square + J = 2 * (num_nested_squares - square_id) + # K: total number of points along the perimeter of the square K=4·J-4; + K = 4 * (J - 1) + + for m in range(M): + indices_idx = int(np.floor(np.mod((m + t * M) / GOLDEN_RATIO, 1) * K)) + mask[ordered_indices[indices_idx]] = 1.0 + + pad = ((shape[0] % 2, 0), (shape[1] % 2, 0)) + + mask = np.pad(mask, pad, constant_values=0) + mask = T.center_crop(torch.from_numpy(mask.astype(bool)), shape) + + return mask + + def circus_spiral_mask(self, shape, acceleration): + """ + Implements CIRCUS spiral undersampling. + """ + max_dim = max(shape) - max(shape) % 2 + min_dim = min(shape) - min(shape) % 2 + + num_nested_squares = max_dim // 2 + + M = int(np.prod(shape) / (acceleration * (max_dim / 2 - (max_dim - min_dim) * (1 + min_dim / max_dim) / 4))) + + mask = np.zeros((max_dim, max_dim), dtype=np.float32) + + c = self.rng.uniform(low=1.1, high=1.3, size=1).item() + + for square_id in range(num_nested_squares): + + ordered_indices = self.get_square_ordered_idxs( + square_side_size=max_dim, + square_id=square_id, + ) + + # J: size of the square, J=2,…,N, i.e., the number of points along one side of the square + J = 2 * (num_nested_squares - square_id) + # K: total number of points along the perimeter of the square K=4·J-4; + K = 4 * (J - 1) + + for m in range(M): + i = np.floor(np.mod(m / GOLDEN_RATIO, 1) * K) + indices_idx = int(np.mod((i + np.ceil(J ** c) - 1), K)) + + mask[ordered_indices[indices_idx]] = 1.0 + + pad = ((shape[0] % 2, 0), (shape[1] % 2, 0)) + + mask = np.pad(mask, pad) + mask = T.center_crop(torch.from_numpy(mask.astype(bool)), shape) + + return mask + + @staticmethod + def circular_centered_mask(mask, eps=0.1): + shape = mask.shape + center = np.asarray(shape) // 2 + Y, X = np.ogrid[: shape[0], : shape[1]] + Y, X = torch.tensor(Y), torch.tensor(X) + radius = 1 + + # Finds the maximum (unmasked) disk in mask given a tolerance. + while True: + # Creates a disk with R=radius and finds intersection with mask + disk = (Y - center[0]) ** 2 + (X - center[1]) ** 2 <= radius ** 2 + intersection = disk & mask + ratio = disk.sum() / intersection.sum() + if ratio > 1.0 + eps: + return intersection + radius += eps + + def mask_func(self, shape, return_acs=False, seed=None): + + if len(shape) < 3: + raise ValueError("Shape should have 3 or more dimensions") + + with temp_seed(self.rng, seed): + num_rows = shape[-3] + num_cols = shape[-2] + acceleration = self.choose_acceleration()[1] + + if self.subsampling_scheme == "circus-radial": + mask = self.circus_radial_mask( + shape=(num_rows, num_cols), + acceleration=acceleration, + ) + elif self.subsampling_scheme == "circus-spiral": + mask = self.circus_spiral_mask( + shape=(num_rows, num_cols), + acceleration=acceleration, + ) + + if return_acs: + return self.circular_centered_mask(mask).unsqueeze(0).unsqueeze(-1) + + return mask.unsqueeze(0).unsqueeze(-1) + + +class RadialMaskFunc(CIRCUSMaskFunc): + """ + Computes radial masks for Cartesian data. + + """ + + def __init__( + self, + accelerations, + **kwargs, + ): + super().__init__( + accelerations=accelerations, + subsampling_scheme=CIRCUSSamplingMode.circus_radial, + **kwargs, + ) + + +class SpiralMaskFunc(CIRCUSMaskFunc): + """ + Computes spiral masks for Cartesian data. + + """ + + def __init__( + self, + accelerations, + **kwargs, + ): + super().__init__( + accelerations=accelerations, + subsampling_scheme=CIRCUSSamplingMode.circus_spiral, + **kwargs, + ) + + class DictionaryMaskFunc(BaseMaskFunc): def __init__(self, data_dictionary, **kwargs): # noqa super().__init__(accelerations=None) diff --git a/direct/common/subsample_config.py b/direct/common/subsample_config.py index 9c7a7669..4585cacc 100644 --- a/direct/common/subsample_config.py +++ b/direct/common/subsample_config.py @@ -11,10 +11,10 @@ @dataclass class MaskingConfig(BaseConfig): name: str = MISSING - accelerations: Tuple[int, ...] = (4,) # Ideally Union[float, int]. - center_fractions: Optional[Tuple[float, ...]] = (0.08,) # Ideally Optional[Tuple[float, ...]] + accelerations: Tuple[int, ...] = (5,) # Ideally Union[float, int]. + center_fractions: Optional[Tuple[float, ...]] = (0.1,) # Ideally Optional[Tuple[float, ...]] uniform_range: bool = False image_center_crop: bool = False - val_accelerations: Tuple[int, ...] = (4, 8) - val_center_fractions: Optional[Tuple[float, ...]] = (0.08, 0.04) + val_accelerations: Tuple[int, ...] = (5, 10) + val_center_fractions: Optional[Tuple[float, ...]] = (0.1, 0.05) diff --git a/direct/common/tests/test_subsample.py b/direct/common/tests/test_subsample.py index 17b6d3a6..1a432d2a 100644 --- a/direct/common/tests/test_subsample.py +++ b/direct/common/tests/test_subsample.py @@ -1,15 +1,17 @@ # coding=utf-8 # Copyright (c) DIRECT Contributors +"""Tests for the direct.common.subsample module""" + # Code and comments can be shared with code of FastMRI under the same MIT license: # https://github.com/facebookresearch/fastMRI/ -# The code can have been adjusted to our needs. +# The code has been adjusted to our needs. import numpy as np import pytest import torch -from direct.common.subsample import FastMRIRandomMaskFunc +from direct.common.subsample import FastMRIRandomMaskFunc, RadialMaskFunc, SpiralMaskFunc @pytest.mark.parametrize( @@ -53,3 +55,121 @@ def test_fastmri_random_mask_low_freqs(center_fracs, accelerations, batch_size, if np.all(mask[pad : pad + num_low_freqs].numpy() == 1): num_low_freqs_matched = True assert num_low_freqs_matched + + +@pytest.mark.parametrize( + "shape, center_fractions, accelerations", + [ + ([4, 32, 32, 2], [0.08], [4]), + ([2, 64, 64, 2], [0.04, 0.08], [8, 4]), + ], +) +def test_apply_mask_fastmri(shape, center_fractions, accelerations): + mask_func = FastMRIRandomMaskFunc( + center_fractions=center_fractions, + accelerations=accelerations, + uniform_range=False, + ) + mask = mask_func(shape[1:], seed=123) + acs_mask = mask_func(shape[1:], seed=123, return_acs=True) + expected_mask_shape = (1, shape[1], shape[2], 1) + + assert mask.max() == 1 + assert mask.min() == 0 + assert mask.shape == expected_mask_shape + assert np.allclose(mask & acs_mask, acs_mask) + + +@pytest.mark.parametrize( + "shape, center_fractions, accelerations", + [ + ([4, 32, 32, 2], [0.08], [4]), + ([2, 64, 64, 2], [0.04, 0.08], [8, 4]), + ], +) +def test_same_across_volumes_mask_fastmri(shape, center_fractions, accelerations): + mask_func = FastMRIRandomMaskFunc( + center_fractions=center_fractions, + accelerations=accelerations, + uniform_range=False, + ) + num_slices = shape[0] + masks = [mask_func(shape[1:], seed=123) for _ in range(num_slices)] + + assert all(np.allclose(masks[_], masks[_ + 1]) for _ in range(num_slices - 1)) + + +@pytest.mark.parametrize( + "shape, accelerations", + [ + ([4, 32, 32, 2], [4]), + ([2, 64, 64, 2], [8, 4]), + ], +) +def test_apply_mask_radial(shape, accelerations): + mask_func = RadialMaskFunc( + accelerations=accelerations, + ) + mask = mask_func(shape[1:], seed=123) + acs_mask = mask_func(shape[1:], seed=123, return_acs=True) + expected_mask_shape = (1, shape[1], shape[2], 1) + + assert mask.max() == 1 + assert mask.min() == 0 + assert mask.shape == expected_mask_shape + assert np.allclose(mask & acs_mask, acs_mask) + + +@pytest.mark.parametrize( + "shape, accelerations", + [ + ([4, 32, 32, 2], [4]), + ([2, 64, 64, 2], [8, 4]), + ], +) +def test_same_across_volumes_mask_radial(shape, accelerations): + mask_func = RadialMaskFunc( + accelerations=accelerations, + ) + num_slices = shape[0] + masks = [mask_func(shape[1:], seed=123) for _ in range(num_slices)] + + assert all(np.allclose(masks[_], masks[_ + 1]) for _ in range(num_slices - 1)) + + +@pytest.mark.parametrize( + "shape, accelerations", + [ + ([4, 32, 32, 2], [4]), + ([2, 64, 64, 2], [8, 4]), + ], +) +def test_apply_mask_spiral(shape, accelerations): + mask_func = SpiralMaskFunc( + accelerations=accelerations, + ) + mask = mask_func(shape[1:], seed=123) + acs_mask = mask_func(shape[1:], seed=123, return_acs=True) + expected_mask_shape = (1, shape[1], shape[2], 1) + + assert mask.max() == 1 + assert mask.min() == 0 + assert mask.shape == expected_mask_shape + assert np.allclose(mask & acs_mask, acs_mask) + + +@pytest.mark.parametrize( + "shape, accelerations", + [ + ([4, 32, 32, 2], [4]), + ([2, 64, 64, 2], [8, 4]), + ], +) +def test_same_across_volumes_mask_spiral(shape, accelerations): + mask_func = SpiralMaskFunc( + accelerations=accelerations, + ) + num_slices = shape[0] + masks = [mask_func(shape[1:], seed=123) for _ in range(num_slices)] + + assert all(np.allclose(masks[_], masks[_ + 1]) for _ in range(num_slices - 1)) diff --git a/direct/data/datasets.py b/direct/data/datasets.py index 0bf0a668..431c514d 100644 --- a/direct/data/datasets.py +++ b/direct/data/datasets.py @@ -101,7 +101,7 @@ def __init__( if self.text_description: self.logger.info(f"Dataset description: {self.text_description}.") - self.generator: Callable = FakeMRIData( + self.fake_data: Callable = FakeMRIData( ndim=len(self.spatial_shape), blobs_n_samples=kwargs.get("blobs_n_samples", None), blobs_cluster_std=kwargs.get("blobs_cluster_std", None), @@ -164,7 +164,7 @@ def __len__(self): def __getitem__(self, idx: int) -> Dict[str, Any]: filename, slice_no, sample_seed = self.data[idx] - sample = self.generator( + sample = self.fake_data( sample_size=1, num_coils=self.num_coils, spatial_shape=self.spatial_shape, @@ -402,7 +402,6 @@ def __getitem__(self, idx: int) -> Dict[str, Any]: kspace[:, int(np.ceil(num_z * self.sampling_rate_slice_encode)) :, :] = 0.0 + 0.0 * 1j # Downstream code expects the coils to be at the first axis. - # TODO: When named tensor support is more solid, this could be circumvented. sample["kspace"] = np.ascontiguousarray(kspace.transpose(2, 0, 1)) if self.transform: diff --git a/direct/data/fake.py b/direct/data/fake.py index f9ce2fcc..0a5231d2 100644 --- a/direct/data/fake.py +++ b/direct/data/fake.py @@ -33,7 +33,7 @@ def __init__( """ if ndim not in [2, 3]: - raise NotImplementedError(f"Currently FakeMRIDataGenerator is not implemented for {ndim}D data.") + raise NotImplementedError(f"Currently FakeMRIData is not implemented for {ndim}D data.") self.ndim = ndim self.blobs_n_samples = blobs_n_samples diff --git a/direct/data/tests/test_fake.py b/direct/data/tests/test_fake.py index 92823f00..c645e9f9 100644 --- a/direct/data/tests/test_fake.py +++ b/direct/data/tests/test_fake.py @@ -25,11 +25,11 @@ "spatial_shape", [(32, 32), (10, 32, 32), [10, 32, 32]], ) -def test_generator(size, num_coils, spatial_shape): +def test_fake(size, num_coils, spatial_shape): - generator = FakeMRIData(ndim=len(spatial_shape)) + fake_data = FakeMRIData(ndim=len(spatial_shape)) - samples = generator(size, num_coils, spatial_shape) + samples = fake_data(size, num_coils, spatial_shape) keys = ["kspace", "reconstruction_rss", "attrs"] assert all(_ in samples[0].keys() for _ in keys) diff --git a/direct/data/tests/test_transforms.py b/direct/data/tests/test_transforms.py index bbf1c86b..fc298015 100644 --- a/direct/data/tests/test_transforms.py +++ b/direct/data/tests/test_transforms.py @@ -1,48 +1,23 @@ # coding=utf-8 # Copyright (c) DIRECT Contributors -"""Tests for the direct.data.transforms module""" -# Some of this code is written by Facebook for the FastMRI challenge and is licensed under the MIT license. -# The code has been heavily edited, but some parts could still be recognized. +"""Tests for the direct.data.transforms module""" import numpy as np import pytest import torch -from direct.common.subsample import FastMRIRandomMaskFunc from direct.data import transforms from direct.data.transforms import tensor_to_complex_numpy def create_input(shape): - # data = np.arange(np.product(shape)).reshape(shape).copy() data = np.random.randn(*shape).copy() data = torch.from_numpy(data).float() return data -@pytest.mark.parametrize( - "shape, center_fractions, accelerations", - [ - ([4, 32, 32, 2], [0.08], [4]), - ([2, 64, 64, 2], [0.04, 0.08], [8, 4]), - ], -) -def test_apply_mask_fastmri(shape, center_fractions, accelerations): - mask_func = FastMRIRandomMaskFunc( - center_fractions=center_fractions, - accelerations=accelerations, - uniform_range=False, - ) - mask = mask_func(shape[1:], seed=123) - expected_mask_shape = (1, shape[1], shape[2], 1) - - assert mask.max() == 1 - assert mask.min() == 0 - assert mask.shape == expected_mask_shape - - @pytest.mark.parametrize( "shape, dim", [ diff --git a/direct/environment.py b/direct/environment.py index 146d090e..41390cad 100644 --- a/direct/environment.py +++ b/direct/environment.py @@ -325,13 +325,16 @@ def setup_testing_environment( device, machine_rank, mixed_precision, + cfg_file=None, debug=False, ): - - cfg_filename = base_directory / run_name / "config.yaml" + if cfg_file is None: + cfg_filename = base_directory / run_name / "config.yaml" + else: + cfg_filename = cfg_file if not cfg_filename.exists(): - raise OSError(f"Config file {cfg_filename} does not exist.") + raise FileNotFoundError(f"Config file {cfg_filename} does not exist.") env = setup_common_environment( run_name, @@ -356,10 +359,13 @@ def setup_inference_environment( device, machine_rank, mixed_precision, + cfg_file=None, debug=False, ): - env = setup_testing_environment(run_name, base_directory, device, machine_rank, mixed_precision, debug=debug) + env = setup_testing_environment( + run_name, base_directory, device, machine_rank, mixed_precision, cfg_file, debug=debug + ) out_env = namedtuple( "environment", diff --git a/direct/inference.py b/direct/inference.py index 7476b6a5..a0cab418 100644 --- a/direct/inference.py +++ b/direct/inference.py @@ -28,6 +28,7 @@ def setup_inference_save_to_h5( device, num_workers: int, machine_rank: int, + cfg_file=None, process_per_chunk: Optional[int] = None, volume_processing_func: Callable = None, mixed_precision: bool = False, @@ -48,6 +49,7 @@ def setup_inference_save_to_h5( device : num_workers : machine_rank : + cfg_file : process_per_chunk : volume_processing_func : mixed_precision : @@ -57,7 +59,9 @@ def setup_inference_save_to_h5( ------- None """ - env = setup_inference_environment(run_name, base_directory, device, machine_rank, mixed_precision, debug=debug) + env = setup_inference_environment( + run_name, base_directory, device, machine_rank, mixed_precision, cfg_file, debug=debug + ) dataset_cfg, transforms = get_inference_settings(env) diff --git a/direct/nn/didn/didn.py b/direct/nn/didn/didn.py index 91d5947e..0642fff3 100644 --- a/direct/nn/didn/didn.py +++ b/direct/nn/didn/didn.py @@ -8,8 +8,15 @@ class Subpixel(nn.Module): """ - Subpixel convolution layer for up-scaling of low resolution features at super-resolution as implemented - in https://ieeexplore.ieee.org/document/9025411. + Subpixel convolution layer for up-scaling of low resolution features at super-resolution as implemented in [1]_. + + References + ---------- + + .. [1] Yu, Songhyun, et al. “Deep Iterative Down-Up CNN for Image Denoising.” + 2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition Workshops + (CVPRW), 2019, pp. 2095–103. IEEE Xplore, https://doi.org/10.1109/CVPRW.2019.00262. + """ def __init__(self, in_channels, out_channels, upscale_factor, kernel_size, padding=0): @@ -25,7 +32,15 @@ def forward(self, x): class ReconBlock(nn.Module): """ - Reconstruction Block of DIDN model as implemented in https://ieeexplore.ieee.org/document/9025411. + Reconstruction Block of DIDN model as implemented in [1]_. + + References + ---------- + + .. [1] Yu, Songhyun, et al. “Deep Iterative Down-Up CNN for Image Denoising.” + 2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition Workshops + (CVPRW), 2019, pp. 2095–103. IEEE Xplore, https://doi.org/10.1109/CVPRW.2019.00262. + """ def __init__(self, in_channels, num_convs): @@ -55,7 +70,15 @@ def forward(self, input): class DUB(nn.Module): """ - Down-up block (DUB) for DIDN model as implemented in https://ieeexplore.ieee.org/document/9025411. + Down-up block (DUB) for DIDN model as implemented in [1]_. + + References + ---------- + + .. [1] Yu, Songhyun, et al. “Deep Iterative Down-Up CNN for Image Denoising.” + 2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition Workshops + (CVPRW), 2019, pp. 2095–103. IEEE Xplore, https://doi.org/10.1109/CVPRW.2019.00262. + """ def __init__( @@ -151,8 +174,15 @@ def forward(self, x): class DIDN(nn.Module): """ - Deep Iterative Down-up convolutional Neural network (DIDN) implementation as in - https://ieeexplore.ieee.org/document/9025411. + Deep Iterative Down-up convolutional Neural network (DIDN) implementation as in [1]_. + + References + ---------- + + .. [1] Yu, Songhyun, et al. “Deep Iterative Down-Up CNN for Image Denoising.” + 2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition Workshops + (CVPRW), 2019, pp. 2095–103. IEEE Xplore, https://doi.org/10.1109/CVPRW.2019.00262. + """ def __init__( diff --git a/direct/nn/kikinet/kikinet.py b/direct/nn/kikinet/kikinet.py index 621a8f35..6198dbcc 100644 --- a/direct/nn/kikinet/kikinet.py +++ b/direct/nn/kikinet/kikinet.py @@ -16,9 +16,15 @@ class KIKINet(nn.Module): """ - Based on KIKINet implementation as in "KIKI-net: cross-domain convolutional neural networks for - reconstructing undersampled magnetic resonance images" by Taejoon Eo et all. Modified to work with - multicoil kspace data. + Based on KIKINet implementation [1]_. Modified to work with multicoil kspace data. + + References + ---------- + + .. [1] Eo, Taejoon, et al. “KIKI-Net: Cross-Domain Convolutional Neural Networks for Reconstructing + Undersampled Magnetic Resonance Images.” Magnetic Resonance in Medicine, vol. 80, no. 5, Nov. 2018, + pp. 2188–201. PubMed, https://doi.org/10.1002/mrm.27201. + """ def __init__( diff --git a/direct/nn/lpd/lpd.py b/direct/nn/lpd/lpd.py index 46d5f2d2..b4d33a25 100644 --- a/direct/nn/lpd/lpd.py +++ b/direct/nn/lpd/lpd.py @@ -85,7 +85,14 @@ def forward(self, f, backward_h): class LPDNet(nn.Module): """ - Learned Primal Dual network implementation as in https://arxiv.org/abs/1707.06474. + Learned Primal Dual network implementation inspired by [1]_. + + References + ---------- + + .. [1] Adler, Jonas, and Ozan Öktem. “Learned Primal-Dual Reconstruction.” IEEE Transactions on Medical Imaging, + vol. 37, no. 6, June 2018, pp. 1322–32. arXiv.org, https://doi.org/10.1109/TMI.2018.2799231. + """ def __init__( diff --git a/direct/nn/mwcnn/mwcnn.py b/direct/nn/mwcnn/mwcnn.py index d6a27c73..cd800a93 100644 --- a/direct/nn/mwcnn/mwcnn.py +++ b/direct/nn/mwcnn/mwcnn.py @@ -11,7 +11,14 @@ class DWT(nn.Module): """ - 2D Discrete Wavelet Transform as implemented in https://arxiv.org/abs/1805.07071. + 2D Discrete Wavelet Transform as implemented in [1]_. + + References + ---------- + + .. [1] Liu, Pengju, et al. “Multi-Level Wavelet-CNN for Image Restoration.” ArXiv:1805.07071 [Cs], May 2018. + arXiv.org, http://arxiv.org/abs/1805.07071. + """ def __init__(self): @@ -35,7 +42,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class IWT(nn.Module): """ - 2D Inverse Wavelet Transform as implemented in https://arxiv.org/abs/1805.07071. + 2D Inverse Wavelet Transform as implemented in [1]_. + + References + ---------- + + .. [1] Liu, Pengju, et al. “Multi-Level Wavelet-CNN for Image Restoration.” ArXiv:1805.07071 [Cs], May 2018. + arXiv.org, http://arxiv.org/abs/1805.07071. + """ def __init__(self): @@ -65,7 +79,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class ConvBlock(nn.Module): """ - Convolution Block for MWCNN as implemented in https://arxiv.org/abs/1805.07071. + Convolution Block for MWCNN as implemented in [1]_. + + References + ---------- + + .. [1] Liu, Pengju, et al. “Multi-Level Wavelet-CNN for Image Restoration.” ArXiv:1805.07071 [Cs], May 2018. + arXiv.org, http://arxiv.org/abs/1805.07071. + """ def __init__( @@ -104,7 +125,14 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: class DilatedConvBlock(nn.Module): """ - Double dilated Convolution Block fpr MWCNN as implemented in https://arxiv.org/abs/1805.07071. + Double dilated Convolution Block fpr MWCNN as implemented in [1]_. + + References + ---------- + + .. [1] Liu, Pengju, et al. “Multi-Level Wavelet-CNN for Image Restoration.” ArXiv:1805.07071 [Cs], May 2018. + arXiv.org, http://arxiv.org/abs/1805.07071. + """ def __init__( @@ -159,7 +187,14 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: class MWCNN(nn.Module): """ - Multi-level Wavelet CNN (MWCNN) implementation as implemented in https://arxiv.org/abs/1805.07071. + Multi-level Wavelet CNN (MWCNN) implementation as implemented in [1]_. + + References + ---------- + + .. [1] Liu, Pengju, et al. “Multi-Level Wavelet-CNN for Image Restoration.” ArXiv:1805.07071 [Cs], May 2018. + arXiv.org, http://arxiv.org/abs/1805.07071. + """ def __init__( diff --git a/direct/nn/recurrentvarnet/recurrentvarnet.py b/direct/nn/recurrentvarnet/recurrentvarnet.py index 26278bb1..2c4eaead 100644 --- a/direct/nn/recurrentvarnet/recurrentvarnet.py +++ b/direct/nn/recurrentvarnet/recurrentvarnet.py @@ -14,9 +14,16 @@ class RecurrentInit(nn.Module): """ - Recurrent State Initializer (RSI) module of Recurrent Variational Network as presented in - https://arxiv.org/abs/2111.09639. The RSI module learns to initialize the recurrent hidden state h_0, - input of the first RecurrentVarNet Block of the RecurrentVarNet. + Recurrent State Initializer (RSI) module of Recurrent Variational Network as presented in [1]_. + The RSI module learns to initialize the recurrent hidden state h_0, input of the first RecurrentVarNet + Block of the RecurrentVarNet. + + References + ---------- + + .. [1] Yiasemis, George, et al. “Recurrent Variational Network: A Deep Learning Inverse Problem Solver + Applied to the Task of Accelerated MRI Reconstruction.” ArXiv:2111.09639 [Physics], Nov. 2021. + arXiv.org, http://arxiv.org/abs/2111.09639. """ @@ -85,7 +92,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class RecurrentVarNet(nn.Module): """ - Recurrent Variational Network implementation as presented in https://arxiv.org/abs/2111.09639. + Recurrent Variational Network implementation as presented in [1]_. + + References + ---------- + + .. [1] Yiasemis, George, et al. “Recurrent Variational Network: A Deep Learning Inverse Problem Solver + Applied to the Task of Accelerated MRI Reconstruction.” ArXiv:2111.09639 [Physics], Nov. 2021. + arXiv.org, http://arxiv.org/abs/2111.09639. + """ def __init__( @@ -266,7 +281,15 @@ def forward( class RecurrentVarNetBlock(nn.Module): """ - Recurrent Variational Network Block as presented in https://arxiv.org/abs/2111.09639. + Recurrent Variational Network Block as presented in [1]_. + + References + ---------- + + .. [1] Yiasemis, George, et al. “Recurrent Variational Network: A Deep Learning Inverse Problem Solver + Applied to the Task of Accelerated MRI Reconstruction.” ArXiv:2111.09639 [Physics], Nov. 2021. + arXiv.org, http://arxiv.org/abs/2111.09639. + """ def __init__( diff --git a/direct/nn/rim/rim.py b/direct/nn/rim/rim.py index 0167c6b9..96a1b2d1 100644 --- a/direct/nn/rim/rim.py +++ b/direct/nn/rim/rim.py @@ -111,7 +111,13 @@ class RIMInit(nn.Module): Learned initializer for RIM, based on multi-scale context aggregation with dilated convolutions, that replaces zero initializer for the RIM hidden vector. - Inspired by "Multi-Scale Context Aggregation by Dilated Convolutions" (https://arxiv.org/abs/1511.07122) + Inspired by [1]_. + + References + ---------- + + .. [1] Yu, Fisher, and Vladlen Koltun. “Multi-Scale Context Aggregation by Dilated Convolutions.” + ArXiv:1511.07122 [Cs], Apr. 2016. arXiv.org, http://arxiv.org/abs/1511.07122. """ def __init__( @@ -179,7 +185,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class RIM(nn.Module): """ - Recurrent Inference Machine Module as in https://arxiv.org/abs/1706.04008. + Recurrent Inference Machine Module as in [1]_. + + References + ---------- + + .. [1] Putzky, Patrick, and Max Welling. “Recurrent Inference Machines for Solving Inverse Problems.” + ArXiv:1706.04008 [Cs], June 2017. arXiv.org, http://arxiv.org/abs/1706.04008. + """ def __init__( diff --git a/direct/nn/unet/unet_2d.py b/direct/nn/unet/unet_2d.py index 18d27c88..b54b6603 100644 --- a/direct/nn/unet/unet_2d.py +++ b/direct/nn/unet/unet_2d.py @@ -112,12 +112,14 @@ def __repr__(self): class UnetModel2d(nn.Module): """ - PyTorch implementation of a U-Net model. + PyTorch implementation of a U-Net model based on [1]_. - This is based on: - Olaf Ronneberger, Philipp Fischer, and Thomas Brox. U-net: Convolutional networks - for biomedical image segmentation. In International Conference on Medical image - computing and computer-assisted intervention, pages 234–241. Springer, 2015. + References + ---------- + + .. [1] Ronneberger, Olaf, et al. “U-Net: Convolutional Networks for Biomedical Image Segmentation.” + Medical Image Computing and Computer-Assisted Intervention – MICCAI 2015, edited by Nassir Navab et al., + Springer International Publishing, 2015, pp. 234–41. Springer Link, https://doi.org/10.1007/978-3-319-24574-4_28. """ def __init__( diff --git a/direct/nn/varnet/varnet.py b/direct/nn/varnet/varnet.py index 00ade925..0956fc6a 100644 --- a/direct/nn/varnet/varnet.py +++ b/direct/nn/varnet/varnet.py @@ -12,7 +12,14 @@ class EndToEndVarNet(nn.Module): """ - End-to-End Variational Network as in https://arxiv.org/abs/2004.06688. + End-to-End Variational Network based on [1]_. + + References + ---------- + + .. [1] Sriram, Anuroop, et al. “End-to-End Variational Networks for Accelerated MRI Reconstruction.” + ArXiv:2004.06688 [Cs, Eess], Apr. 2020. arXiv.org, http://arxiv.org/abs/2004.06688. + """ def __init__( diff --git a/direct/nn/xpdnet/xpdnet.py b/direct/nn/xpdnet/xpdnet.py index 4f2f8c74..e9221dd5 100644 --- a/direct/nn/xpdnet/xpdnet.py +++ b/direct/nn/xpdnet/xpdnet.py @@ -14,7 +14,14 @@ class XPDNet(CrossDomainNetwork): """ - XPDNet as implemented in https://arxiv.org/abs/2010.07290. + XPDNet as implemented in [1]_. + + References + ---------- + + .. [1] Ramzi, Zaccharie, et al. “XPDNet for MRI Reconstruction: An Application to the 2020 FastMRI Challenge.” + ArXiv:2010.07290 [Physics, Stat], July 2021. arXiv.org, http://arxiv.org/abs/2010.07290. + """ def __init__( diff --git a/projects/calgary_campinas/predict_test.py b/projects/calgary_campinas/predict_test.py index a02fef0a..b8cb392a 100644 --- a/projects/calgary_campinas/predict_test.py +++ b/projects/calgary_campinas/predict_test.py @@ -77,9 +77,6 @@ def _get_transforms(masks_dict, env): torch.set_num_threads(1) os.environ["OMP_NUM_THREADS"] = "1" - # Remove warnings from named tensors being experimental - os.environ["PYTHONWARNINGS"] = "ignore" - epilog = f""" Examples: Run on single machine: diff --git a/projects/calgary_campinas/predict_val.py b/projects/calgary_campinas/predict_val.py index e812873a..0acb1d70 100644 --- a/projects/calgary_campinas/predict_val.py +++ b/projects/calgary_campinas/predict_val.py @@ -14,7 +14,7 @@ from direct.inference import build_inference_transforms, setup_inference_save_to_h5 from direct.utils import set_all_seeds -from .utils import volume_post_processing_func as calgary_campinas_post_processing_func +from utils import volume_post_processing_func as calgary_campinas_post_processing_func logger = logging.getLogger(__name__) @@ -32,9 +32,6 @@ def _get_transforms(validation_index, env): torch.set_num_threads(1) os.environ["OMP_NUM_THREADS"] = "1" - # Remove warnings from named tensors being experimental - os.environ["PYTHONWARNINGS"] = "ignore" - epilog = f""" Examples: Run on single machine: diff --git a/projects/calgary_campinas/compute_metrics.py b/projects/spie_radial_subsampling/compute_metrics.py similarity index 86% rename from projects/calgary_campinas/compute_metrics.py rename to projects/spie_radial_subsampling/compute_metrics.py index a35afdb9..73f571e9 100644 --- a/projects/calgary_campinas/compute_metrics.py +++ b/projects/spie_radial_subsampling/compute_metrics.py @@ -4,14 +4,18 @@ import argparse import glob import json +import logging import os import pathlib +import sys import h5py from direct.data.transforms import * from direct.functionals.challenges import * +logger = logging.getLogger(__name__) + def _get_filenames_from_lists(path_to_lst): names = [] @@ -31,11 +35,6 @@ def _get_file_from_h5(pred_filename, target_filename): target_kspace = np.array(target["kspace"][50:-50]) target_kspace = to_tensor(target_kspace[..., ::2] + 1j * target_kspace[..., 1::2]) - # TODO(gy): Needed? - # sampling_rate_slice_encode = 0.85 - # num_z = target_kspace.shape[1] - # target_kspace[:, int(np.ceil(num_z * sampling_rate_slice_encode)):, :] = 0.0 + 0.0 * 1j - target_rec = _get_reconstruction(target_kspace.permute(0, 3, 1, 2, 4)) return pred_rec, target_rec @@ -81,7 +80,6 @@ def _get_metrics(pred_rec, target_rec): args = parser.parse_args() filenames = _get_filenames_from_lists(args.filenames_filter) - metrics = dict() for filename in filenames: @@ -90,11 +88,16 @@ def _get_metrics(pred_rec, target_rec): target_filename = pathlib.Path(pathlib.PurePath(args.target_data_root, filename)) if pred_filename.exists() and target_filename.exists(): - + logger.info(f"Computing metrics for {filename}...") pred_rec, target_rec = _get_file_from_h5(pred_filename, target_filename) metrics[filename.name] = _get_metrics(pred_rec, target_rec) + else: + logger.info(f"Filename {filename} not found in both, target and predicted directories. Skipping...") if len(metrics) > 0: - with open(args.name + ".json", "w") as f: + logger.info(f"Saving metrics for {len(metrics)} filenames.") + with open(args.name + ".json", "a") as f: f.write(json.dumps(metrics, indent=4, sort_keys=True)) + else: + logger.info("No metrics were computed.") diff --git a/projects/spie_radial_subsampling/configs/base_radial.yaml b/projects/spie_radial_subsampling/configs/base_radial.yaml new file mode 100644 index 00000000..0f82c288 --- /dev/null +++ b/projects/spie_radial_subsampling/configs/base_radial.yaml @@ -0,0 +1,117 @@ +physics: + forward_operator: fft2(centered=False) + backward_operator: ifft2(centered=False) +training: + datasets: + # Two datasets, only difference is the shape, so the data can be collated for larger batches + - name: CalgaryCampinas + lists: + - ../lists/train/12x218x170_train.lst + transforms: + crop: null + estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + image_center_crop: false + masking: + name: Radial + accelerations: [5, 10] + crop_outer_slices: true + - name: CalgaryCampinas + lists: + - ../lists/train/12x218x180_train.lst + transforms: + crop: null + estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + image_center_crop: false + masking: + name: Radial + accelerations: [5, 10] + crop_outer_slices: true + batch_size: 4 # This is the batch size per GPU! + optimizer: Adam + lr: 0.0001 + weight_decay: 0.0 + lr_step_size: 50000 + lr_gamma: 0.2 + lr_warmup_iter: 1000 + num_iterations: 1000000 + gradient_steps: 1 + gradient_clipping: 0.0 + gradient_debug: false + checkpointer: + checkpoint_steps: 500 + validation_steps: 500 # With batch size 4 and 4 GPUs this is about 7300 iterations, or ~1 epoch. + loss: + crop: null + losses: + - function: l1_loss + multiplier: 1.0 + - function: ssim_loss + multiplier: 1.0 +validation: + datasets: + # Twice the same dataset but a different acceleration factor + - name: CalgaryCampinas + lists: + - ../lists/val/12x218x170_val.lst + - ../lists/val/12x218x180_val.lst + transforms: + crop: null + estimate_sensitivity_maps: true + scaling_key: masked_kspace + masking: + name: Radial + accelerations: [5] + crop_outer_slices: true + text_description: 5x # Description for logging + - name: CalgaryCampinas + lists: + - ../lists/val/12x218x170_val.lst + - ../lists/val/12x218x180_val.lst + transforms: + crop: null + estimate_sensitivity_maps: true + scaling_key: masked_kspace + masking: + name: Radial + accelerations: [10] + crop_outer_slices: true + text_description: 10x # Description for logging + crop: null # This sets the cropping for the DoIterationOutput + metrics: # These are obtained from direct.functionals + - calgary_campinas_psnr + - calgary_campinas_ssim + - calgary_campinas_vif +model: + model_name: rim.rim.RIM + hidden_channels: 128 + image_initialization: sense # This uses the computed sensitivity map to create a zero-filled reconstruction + length: 16 + depth: 2 + steps: 1 + no_parameter_sharing: false + instance_norm: false + dense_connect: false + replication_padding: true +additional_models: + sensitivity_model: + model_name: unet.unet_2d.UnetModel2d + in_channels: 2 + out_channels: 2 + num_filters: 8 + num_pool_layers: 4 + dropout_probability: 0.0 +logging: + tensorboard: + num_images: 4 +inference: + batch_size: 8 + dataset: + name: CalgaryCampinas + crop_outer_slices: true + text_description: inference + transforms: + crop: null + estimate_sensitivity_maps: true + scaling_key: masked_kspace diff --git a/projects/spie_radial_subsampling/configs/base_rectilinear.yaml b/projects/spie_radial_subsampling/configs/base_rectilinear.yaml new file mode 100644 index 00000000..eb4f861a --- /dev/null +++ b/projects/spie_radial_subsampling/configs/base_rectilinear.yaml @@ -0,0 +1,121 @@ +physics: + forward_operator: fft2(centered=False) + backward_operator: ifft2(centered=False) +training: + datasets: + # Two datasets, only difference is the shape, so the data can be collated for larger batches + - name: CalgaryCampinas + lists: + - ../lists/train/12x218x170_train.lst + transforms: + crop: null + estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + image_center_crop: false + masking: + name: FastMRIRandom + accelerations: [5, 10] + center_fractions: [0.1, 0.05] + crop_outer_slices: true + - name: CalgaryCampinas + lists: + - ../lists/train/12x218x180_train.lst + transforms: + crop: null + estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + image_center_crop: false + masking: + name: FastMRIRandom + accelerations: [5, 10] + center_fractions: [0.1, 0.05] + crop_outer_slices: true + batch_size: 4 # This is the batch size per GPU! + optimizer: Adam + lr: 0.0001 + weight_decay: 0.0 + lr_step_size: 50000 + lr_gamma: 0.2 + lr_warmup_iter: 1000 + num_iterations: 1000000 + gradient_steps: 1 + gradient_clipping: 0.0 + gradient_debug: false + checkpointer: + checkpoint_steps: 500 + validation_steps: 500 # With batch size 4 and 4 GPUs this is about 7300 iterations, or ~1 epoch. + loss: + crop: null + losses: + - function: l1_loss + multiplier: 1.0 + - function: ssim_loss + multiplier: 1.0 +validation: + datasets: + # Twice the same dataset but a different acceleration factor + - name: CalgaryCampinas + lists: + - ../lists/val/12x218x170_val.lst + - ../lists/val/12x218x180_val.lst + transforms: + crop: null + estimate_sensitivity_maps: true + scaling_key: masked_kspace + masking: + name: FastMRIRandom + accelerations: [5] + center_fractions: [0.1] + crop_outer_slices: true + text_description: 5x # Description for logging + - name: CalgaryCampinas + lists: + - ../lists/val/12x218x170_val.lst + - ../lists/val/12x218x180_val.lst + transforms: + crop: null + estimate_sensitivity_maps: true + scaling_key: masked_kspace + masking: + name: FastMRIRandom + accelerations: [10] + center_fractions: [0.05] + crop_outer_slices: true + text_description: 10x # Description for logging + crop: null # This sets the cropping for the DoIterationOutput + metrics: # These are obtained from direct.functionals + - calgary_campinas_psnr + - calgary_campinas_ssim + - calgary_campinas_vif +model: + model_name: rim.rim.RIM + hidden_channels: 128 + image_initialization: sense # This uses the computed sensitivity map to create a zero-filled reconstruction + length: 16 + depth: 2 + steps: 1 + no_parameter_sharing: false + instance_norm: false + dense_connect: false + replication_padding: true +additional_models: + sensitivity_model: + model_name: unet.unet_2d.UnetModel2d + in_channels: 2 + out_channels: 2 + num_filters: 8 + num_pool_layers: 4 + dropout_probability: 0.0 +logging: + tensorboard: + num_images: 4 +inference: + batch_size: 8 + dataset: + name: CalgaryCampinas + crop_outer_slices: true + text_description: inference + transforms: + crop: null + estimate_sensitivity_maps: true + scaling_key: masked_kspace diff --git a/projects/spie_radial_subsampling/configs/inference/10x/base_radial.yaml b/projects/spie_radial_subsampling/configs/inference/10x/base_radial.yaml new file mode 100644 index 00000000..63d99514 --- /dev/null +++ b/projects/spie_radial_subsampling/configs/inference/10x/base_radial.yaml @@ -0,0 +1,120 @@ +physics: + forward_operator: fft2(centered=False) + backward_operator: ifft2(centered=False) +training: + datasets: + # Two datasets, only difference is the shape, so the data can be collated for larger batches + - name: CalgaryCampinas + lists: + - ../lists/train/12x218x170_train.lst + transforms: + crop: null + estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + image_center_crop: false + masking: + name: Radial + accelerations: [5, 10] + crop_outer_slices: true + - name: CalgaryCampinas + lists: + - ../lists/train/12x218x180_train.lst + transforms: + crop: null + estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + image_center_crop: false + masking: + name: Radial + accelerations: [5, 10] + crop_outer_slices: true + batch_size: 4 # This is the batch size per GPU! + optimizer: Adam + lr: 0.0001 + weight_decay: 0.0 + lr_step_size: 50000 + lr_gamma: 0.2 + lr_warmup_iter: 1000 + num_iterations: 1000000 + gradient_steps: 1 + gradient_clipping: 0.0 + gradient_debug: false + checkpointer: + checkpoint_steps: 500 + validation_steps: 500 # With batch size 4 and 4 GPUs this is about 7300 iterations, or ~1 epoch. + loss: + crop: null + losses: + - function: l1_loss + multiplier: 1.0 + - function: ssim_loss + multiplier: 1.0 +validation: + datasets: + # Twice the same dataset but a different acceleration factor + - name: CalgaryCampinas + lists: + - ../lists/val/12x218x170_val.lst + - ../lists/val/12x218x180_val.lst + transforms: + crop: null + estimate_sensitivity_maps: true + scaling_key: masked_kspace + masking: + name: Radial + accelerations: [5] + crop_outer_slices: true + text_description: 5x # Description for logging + - name: CalgaryCampinas + lists: + - ../lists/val/12x218x170_val.lst + - ../lists/val/12x218x180_val.lst + transforms: + crop: null + estimate_sensitivity_maps: true + scaling_key: masked_kspace + masking: + name: Radial + accelerations: [10] + crop_outer_slices: true + text_description: 10x # Description for logging + crop: null # This sets the cropping for the DoIterationOutput + metrics: # These are obtained from direct.functionals + - calgary_campinas_psnr + - calgary_campinas_ssim + - calgary_campinas_vif +model: + model_name: rim.rim.RIM + hidden_channels: 128 + image_initialization: sense # This uses the computed sensitivity map to create a zero-filled reconstruction + length: 16 + depth: 2 + steps: 1 + no_parameter_sharing: false + instance_norm: false + dense_connect: false + replication_padding: true +additional_models: + sensitivity_model: + model_name: unet.unet_2d.UnetModel2d + in_channels: 2 + out_channels: 2 + num_filters: 8 + num_pool_layers: 4 + dropout_probability: 0.0 +logging: + tensorboard: + num_images: 4 +inference: + batch_size: 8 + dataset: + name: CalgaryCampinas + transforms: + crop: null + estimate_sensitivity_maps: true + scaling_key: masked_kspace + masking: + name: Radial + accelerations: [10] + crop_outer_slices: true + text_description: inference-10x diff --git a/projects/spie_radial_subsampling/configs/inference/10x/base_rectilinear.yaml b/projects/spie_radial_subsampling/configs/inference/10x/base_rectilinear.yaml new file mode 100644 index 00000000..9cb02852 --- /dev/null +++ b/projects/spie_radial_subsampling/configs/inference/10x/base_rectilinear.yaml @@ -0,0 +1,125 @@ +physics: + forward_operator: fft2(centered=False) + backward_operator: ifft2(centered=False) +training: + datasets: + # Two datasets, only difference is the shape, so the data can be collated for larger batches + - name: CalgaryCampinas + lists: + - ../lists/train/12x218x170_train.lst + transforms: + crop: null + estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + image_center_crop: false + masking: + name: FastMRIRandom + accelerations: [5, 10] + center_fractions: [0.1, 0.05] + crop_outer_slices: true + - name: CalgaryCampinas + lists: + - ../lists/train/12x218x180_train.lst + transforms: + crop: null + estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + image_center_crop: false + masking: + name: FastMRIRandom + accelerations: [5, 10] + center_fractions: [0.1, 0.05] + crop_outer_slices: true + batch_size: 4 # This is the batch size per GPU! + optimizer: Adam + lr: 0.0001 + weight_decay: 0.0 + lr_step_size: 50000 + lr_gamma: 0.2 + lr_warmup_iter: 1000 + num_iterations: 1000000 + gradient_steps: 1 + gradient_clipping: 0.0 + gradient_debug: false + checkpointer: + checkpoint_steps: 500 + validation_steps: 500 # With batch size 4 and 4 GPUs this is about 7300 iterations, or ~1 epoch. + loss: + crop: null + losses: + - function: l1_loss + multiplier: 1.0 + - function: ssim_loss + multiplier: 1.0 +validation: + datasets: + # Twice the same dataset but a different acceleration factor + - name: CalgaryCampinas + lists: + - ../lists/val/12x218x170_val.lst + - ../lists/val/12x218x180_val.lst + transforms: + crop: null + estimate_sensitivity_maps: true + scaling_key: masked_kspace + masking: + name: FastMRIRandom + accelerations: [5] + center_fractions: [0.1] + crop_outer_slices: true + text_description: 5x # Description for logging + - name: CalgaryCampinas + lists: + - ../lists/val/12x218x170_val.lst + - ../lists/val/12x218x180_val.lst + transforms: + crop: null + estimate_sensitivity_maps: true + scaling_key: masked_kspace + masking: + name: FastMRIRandom + accelerations: [10] + center_fractions: [0.05] + crop_outer_slices: true + text_description: 10x # Description for logging + crop: null # This sets the cropping for the DoIterationOutput + metrics: # These are obtained from direct.functionals + - calgary_campinas_psnr + - calgary_campinas_ssim + - calgary_campinas_vif +model: + model_name: rim.rim.RIM + hidden_channels: 128 + image_initialization: sense # This uses the computed sensitivity map to create a zero-filled reconstruction + length: 16 + depth: 2 + steps: 1 + no_parameter_sharing: false + instance_norm: false + dense_connect: false + replication_padding: true +additional_models: + sensitivity_model: + model_name: unet.unet_2d.UnetModel2d + in_channels: 2 + out_channels: 2 + num_filters: 8 + num_pool_layers: 4 + dropout_probability: 0.0 +logging: + tensorboard: + num_images: 4 +inference: + batch_size: 8 + dataset: + name: CalgaryCampinas + transforms: + crop: null + estimate_sensitivity_maps: true + scaling_key: masked_kspace + masking: + name: FastMRIRandom + accelerations: [10] + center_fractions: [0.05] + crop_outer_slices: true + text_description: inference-10x diff --git a/projects/spie_radial_subsampling/configs/inference/5x/base_radial.yaml b/projects/spie_radial_subsampling/configs/inference/5x/base_radial.yaml new file mode 100644 index 00000000..3f31e3c5 --- /dev/null +++ b/projects/spie_radial_subsampling/configs/inference/5x/base_radial.yaml @@ -0,0 +1,120 @@ +physics: + forward_operator: fft2(centered=False) + backward_operator: ifft2(centered=False) +training: + datasets: + # Two datasets, only difference is the shape, so the data can be collated for larger batches + - name: CalgaryCampinas + lists: + - ../lists/train/12x218x170_train.lst + transforms: + crop: null + estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + image_center_crop: false + masking: + name: Radial + accelerations: [5, 10] + crop_outer_slices: true + - name: CalgaryCampinas + lists: + - ../lists/train/12x218x180_train.lst + transforms: + crop: null + estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + image_center_crop: false + masking: + name: Radial + accelerations: [5, 10] + crop_outer_slices: true + batch_size: 4 # This is the batch size per GPU! + optimizer: Adam + lr: 0.0001 + weight_decay: 0.0 + lr_step_size: 50000 + lr_gamma: 0.2 + lr_warmup_iter: 1000 + num_iterations: 1000000 + gradient_steps: 1 + gradient_clipping: 0.0 + gradient_debug: false + checkpointer: + checkpoint_steps: 500 + validation_steps: 500 # With batch size 4 and 4 GPUs this is about 7300 iterations, or ~1 epoch. + loss: + crop: null + losses: + - function: l1_loss + multiplier: 1.0 + - function: ssim_loss + multiplier: 1.0 +validation: + datasets: + # Twice the same dataset but a different acceleration factor + - name: CalgaryCampinas + lists: + - ../lists/val/12x218x170_val.lst + - ../lists/val/12x218x180_val.lst + transforms: + crop: null + estimate_sensitivity_maps: true + scaling_key: masked_kspace + masking: + name: Radial + accelerations: [5] + crop_outer_slices: true + text_description: 5x # Description for logging + - name: CalgaryCampinas + lists: + - ../lists/val/12x218x170_val.lst + - ../lists/val/12x218x180_val.lst + transforms: + crop: null + estimate_sensitivity_maps: true + scaling_key: masked_kspace + masking: + name: Radial + accelerations: [10] + crop_outer_slices: true + text_description: 10x # Description for logging + crop: null # This sets the cropping for the DoIterationOutput + metrics: # These are obtained from direct.functionals + - calgary_campinas_psnr + - calgary_campinas_ssim + - calgary_campinas_vif +model: + model_name: rim.rim.RIM + hidden_channels: 128 + image_initialization: sense # This uses the computed sensitivity map to create a zero-filled reconstruction + length: 16 + depth: 2 + steps: 1 + no_parameter_sharing: false + instance_norm: false + dense_connect: false + replication_padding: true +additional_models: + sensitivity_model: + model_name: unet.unet_2d.UnetModel2d + in_channels: 2 + out_channels: 2 + num_filters: 8 + num_pool_layers: 4 + dropout_probability: 0.0 +logging: + tensorboard: + num_images: 4 +inference: + batch_size: 8 + dataset: + name: CalgaryCampinas + transforms: + crop: null + estimate_sensitivity_maps: true + scaling_key: masked_kspace + masking: + name: Radial + accelerations: [5] + crop_outer_slices: true + text_description: inference-5x diff --git a/projects/spie_radial_subsampling/configs/inference/5x/base_rectilinear.yaml b/projects/spie_radial_subsampling/configs/inference/5x/base_rectilinear.yaml new file mode 100644 index 00000000..180e835b --- /dev/null +++ b/projects/spie_radial_subsampling/configs/inference/5x/base_rectilinear.yaml @@ -0,0 +1,125 @@ +physics: + forward_operator: fft2(centered=False) + backward_operator: ifft2(centered=False) +training: + datasets: + # Two datasets, only difference is the shape, so the data can be collated for larger batches + - name: CalgaryCampinas + lists: + - ../lists/train/12x218x170_train.lst + transforms: + crop: null + estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + image_center_crop: false + masking: + name: FastMRIRandom + accelerations: [5, 10] + center_fractions: [0.1, 0.05] + crop_outer_slices: true + - name: CalgaryCampinas + lists: + - ../lists/train/12x218x180_train.lst + transforms: + crop: null + estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS + scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum + image_center_crop: false + masking: + name: FastMRIRandom + accelerations: [5, 10] + center_fractions: [0.1, 0.05] + crop_outer_slices: true + batch_size: 4 # This is the batch size per GPU! + optimizer: Adam + lr: 0.0001 + weight_decay: 0.0 + lr_step_size: 50000 + lr_gamma: 0.2 + lr_warmup_iter: 1000 + num_iterations: 1000000 + gradient_steps: 1 + gradient_clipping: 0.0 + gradient_debug: false + checkpointer: + checkpoint_steps: 500 + validation_steps: 500 # With batch size 4 and 4 GPUs this is about 7300 iterations, or ~1 epoch. + loss: + crop: null + losses: + - function: l1_loss + multiplier: 1.0 + - function: ssim_loss + multiplier: 1.0 +validation: + datasets: + # Twice the same dataset but a different acceleration factor + - name: CalgaryCampinas + lists: + - ../lists/val/12x218x170_val.lst + - ../lists/val/12x218x180_val.lst + transforms: + crop: null + estimate_sensitivity_maps: true + scaling_key: masked_kspace + masking: + name: FastMRIRandom + accelerations: [5] + center_fractions: [0.1] + crop_outer_slices: true + text_description: 5x # Description for logging + - name: CalgaryCampinas + lists: + - ../lists/val/12x218x170_val.lst + - ../lists/val/12x218x180_val.lst + transforms: + crop: null + estimate_sensitivity_maps: true + scaling_key: masked_kspace + masking: + name: FastMRIRandom + accelerations: [10] + center_fractions: [0.05] + crop_outer_slices: true + text_description: 10x # Description for logging + crop: null # This sets the cropping for the DoIterationOutput + metrics: # These are obtained from direct.functionals + - calgary_campinas_psnr + - calgary_campinas_ssim + - calgary_campinas_vif +model: + model_name: rim.rim.RIM + hidden_channels: 128 + image_initialization: sense # This uses the computed sensitivity map to create a zero-filled reconstruction + length: 16 + depth: 2 + steps: 1 + no_parameter_sharing: false + instance_norm: false + dense_connect: false + replication_padding: true +additional_models: + sensitivity_model: + model_name: unet.unet_2d.UnetModel2d + in_channels: 2 + out_channels: 2 + num_filters: 8 + num_pool_layers: 4 + dropout_probability: 0.0 +logging: + tensorboard: + num_images: 4 +inference: + batch_size: 8 + dataset: + name: CalgaryCampinas + transforms: + crop: null + estimate_sensitivity_maps: true + scaling_key: masked_kspace + masking: + name: FastMRIRandom + accelerations: [5] + center_fractions: [0.1] + crop_outer_slices: true + text_description: inference-5x diff --git a/projects/spie_radial_subsampling/lists/test/12x218x170_test.lst b/projects/spie_radial_subsampling/lists/test/12x218x170_test.lst new file mode 100644 index 00000000..1d176ec5 --- /dev/null +++ b/projects/spie_radial_subsampling/lists/test/12x218x170_test.lst @@ -0,0 +1,11 @@ +e14437s5_P49152.7.h5 +e15862s13_P40960.7.h5 +e14477s5_P34816.7.h5 +e14351s3_P29184.7.h5 +e15195s3_P39424.7.h5 +e14140s3_P52224.7.h5 +e14692s5_P14848.7.h5 +e14264s3_P08192.7.h5 +e15197s3_P53760.7.h5 +e15183s3_P52224.7.h5 +e16673s13_P31744.7.h5 diff --git a/projects/spie_radial_subsampling/lists/test/12x218x180_test.lst b/projects/spie_radial_subsampling/lists/test/12x218x180_test.lst new file mode 100644 index 00000000..097ccd5e --- /dev/null +++ b/projects/spie_radial_subsampling/lists/test/12x218x180_test.lst @@ -0,0 +1,2 @@ +e16882s4_P38912.7.h5 +e15521s3_P33280.7.h5 diff --git a/projects/spie_radial_subsampling/lists/test/test.lst b/projects/spie_radial_subsampling/lists/test/test.lst new file mode 100644 index 00000000..4a6f9d14 --- /dev/null +++ b/projects/spie_radial_subsampling/lists/test/test.lst @@ -0,0 +1,13 @@ +e14437s5_P49152.7.h5 +e15862s13_P40960.7.h5 +e14477s5_P34816.7.h5 +e14351s3_P29184.7.h5 +e15195s3_P39424.7.h5 +e14140s3_P52224.7.h5 +e14692s5_P14848.7.h5 +e14264s3_P08192.7.h5 +e15197s3_P53760.7.h5 +e15183s3_P52224.7.h5 +e16673s13_P31744.7.h5 +e16882s4_P38912.7.h5 +e15521s3_P33280.7.h5 diff --git a/projects/spie_radial_subsampling/lists/train/12x218x170_train.lst b/projects/spie_radial_subsampling/lists/train/12x218x170_train.lst new file mode 100644 index 00000000..2f0662d0 --- /dev/null +++ b/projects/spie_radial_subsampling/lists/train/12x218x170_train.lst @@ -0,0 +1,34 @@ +e14396s3_P52224.7.h5 +e14537s3_P14336.7.h5 +e14195s3_P03584.7.h5 +e15494s3_P24064.7.h5 +e15581s3_P27136.7.h5 +e14441s5_P76800.7.h5 +e16882s14_P46080.7.h5 +e14134s3_P06656.7.h5 +e15198s3_P61952.7.h5 +e14423s3_P29696.7.h5 +e14141s3_P58880.7.h5 +e14268s3_P26112.7.h5 +e15598s3_P54784.7.h5 +e14377s6_P33280.7.h5 +e14302s3_P52224.7.h5 +e14352s3_P35840.7.h5 +e14507s5_P41472.7.h5 +e14508s5_P48128.7.h5 +e15135s4_P48640.7.h5 +e14532s5_P74752.7.h5 +e15828s13_P65536.7.h5 +e15802s13_P50176.7.h5 +e14427s3_P76288.7.h5 +e14292s3_P85504.7.h5 +e14089s3_P53248.7.h5 +e14304s3_P64000.7.h5 +e14487s5_P47616.7.h5 +e14530s5_P60416.7.h5 +e14296s4_P09216.7.h5 +e14369s5_P40960.7.h5 +e15184s3_P58880.7.h5 +e14378s5_P40448.7.h5 +e14191s3_P58368.7.h5 +e15523s3_P52224.7.h5 diff --git a/projects/spie_radial_subsampling/lists/train/12x218x180_train.lst b/projects/spie_radial_subsampling/lists/train/12x218x180_train.lst new file mode 100644 index 00000000..7fc37468 --- /dev/null +++ b/projects/spie_radial_subsampling/lists/train/12x218x180_train.lst @@ -0,0 +1,6 @@ +e15802s3_P42496.7.h5 +e15578s13_P08192.7.h5 +e15828s3_P57856.7.h5 +e16971s3_P23040.7.h5 +e16972s3_P31232.7.h5 +e15862s3_P33792.7.h5 diff --git a/projects/spie_radial_subsampling/lists/val/12x218x170_val.lst b/projects/spie_radial_subsampling/lists/val/12x218x170_val.lst new file mode 100644 index 00000000..d717b078 --- /dev/null +++ b/projects/spie_radial_subsampling/lists/val/12x218x170_val.lst @@ -0,0 +1,11 @@ +e14313s5_P37888.7.h5 +e14258s3_P76800.7.h5 +e14691s3_P06656.7.h5 +e14531s6_P68096.7.h5 +e14280s3_P44032.7.h5 +e14498s5_P60928.7.h5 +e14120s11_P66048.7.h5 +e14110s3_P59904.7.h5 +e14583s3_P21504.7.h5 +e14542s5_P52224.7.h5 +e14584s5_P30208.7.h5 diff --git a/projects/spie_radial_subsampling/lists/val/12x218x180_val.lst b/projects/spie_radial_subsampling/lists/val/12x218x180_val.lst new file mode 100644 index 00000000..8eb4f08d --- /dev/null +++ b/projects/spie_radial_subsampling/lists/val/12x218x180_val.lst @@ -0,0 +1,3 @@ +e15652s4_P45056.7.h5 +e15652s14_P51712.7.h5 +e16673s3_P24576.7.h5 diff --git a/projects/spie_radial_subsampling/plot_zoomed.py b/projects/spie_radial_subsampling/plot_zoomed.py new file mode 100644 index 00000000..3328b5fb --- /dev/null +++ b/projects/spie_radial_subsampling/plot_zoomed.py @@ -0,0 +1,74 @@ +import matplotlib.patches as patches +from mpl_toolkits.axes_grid1.inset_locator import mark_inset, zoomed_inset_axes + + +def zoom_in_rectangle(img, ax, zoom, rectangle_xy, rectangle_width, rectangle_height, **kwargs): + """ + Parameters: + ----------- + img: array-like + The image data. + ax: Axes + Axes to place the inset axes. + zoom: float + Scaling factor of the data axes. zoom > 1 will enlargen the coordinates (i.e., "zoomed in"), + while zoom < 1 will shrink the coordinates (i.e., "zoomed out"). + rectangle_xy: (float or int, float or int) + The anchor point of the rectangle to be zoomed. + rectangle_width: float or int + Rectangle to be zoomed width. + rectangle_height: float or int + Rectangle to be zoomed height. + + Other Parameters: + ----------------- + cmap: str or Colormap, default 'gray' + The Colormap instance or registered colormap name used to map scalar data to colors. + zoomed_inset_loc: int or str, default: 'upper right' + Location to place the inset axes. + zoomed_inset_lw: float or None, default 1 + Zoomed inset axes linewidth. + zoomed_inset_col: float or None, default black + Zoomed inset axes color. + mark_inset_loc1: int or str, default is 1 + First location to place line connecting box and inset axes. + mark_inset_loc2: int or str, default is 3 + Second location to place line connecting box and inset axes. + mark_inset_lw: float or None, default None + Linewidth of lines connecting box and inset axes. + mark_inset_ec: color or None + Color of lines connecting box and inset axes. + + """ + axins = zoomed_inset_axes(ax, zoom, loc=kwargs.get("zoomed_inset_loc", 1)) + + rect = patches.Rectangle(xy=rectangle_xy, width=rectangle_width, height=rectangle_height) + x1, x2 = rect.get_x(), rect.get_x() + rect.get_width() + y1, y2 = rect.get_y(), rect.get_y() + rect.get_height() + + axins.set_xlim(x1, x2) + axins.set_ylim(y1, y2) + + mark_inset( + ax, + axins, + loc1=kwargs.get("mark_inset_loc1", 1), + loc2=kwargs.get("mark_inset_loc2", 3), + lw=kwargs.get("mark_inset_lw", None), + ec=kwargs.get("mark_inset_ec", "1.0"), + ) + + axins.imshow( + img, + cmap=kwargs.get("cmap", "gray"), + origin="lower", + vmin=kwargs.get("vmin", None), + vmax=kwargs.get("vmax", None), + ) + + for axis in ["top", "bottom", "left", "right"]: + axins.spines[axis].set_linewidth(kwargs.get("zoomed_inset_lw", 1)) + axins.spines[axis].set_color(kwargs.get("zoomed_inset_col", "k")) + + axins.set_xticklabels([]) + axins.set_yticklabels([]) diff --git a/projects/spie_radial_subsampling/predict_test.py b/projects/spie_radial_subsampling/predict_test.py new file mode 100644 index 00000000..1a6a9d6f --- /dev/null +++ b/projects/spie_radial_subsampling/predict_test.py @@ -0,0 +1,123 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors +import functools +import logging +import os +import pathlib +import sys + +import numpy as np +import torch + +import direct.launch +from direct.common.subsample import build_masking_function +from direct.environment import Args +from direct.inference import build_inference_transforms, setup_inference_save_to_h5 +from direct.utils import set_all_seeds + +logger = logging.getLogger(__name__) + + +def _calgary_volume_post_processing_func(volume): + volume = volume / np.sqrt(np.prod(volume.shape[1:])) + return volume + + +def _get_transforms(env): + dataset_cfg = env.cfg.inference.dataset + mask_func = build_masking_function(**dataset_cfg.transforms.masking) + transforms = build_inference_transforms(env, mask_func, dataset_cfg) + return dataset_cfg, transforms + + +if __name__ == "__main__": + # This sets MKL threads to 1. + # DataLoader can otherwise bring a lot of difficulties when computing CPU FFTs in the transforms. + torch.set_num_threads(1) + os.environ["OMP_NUM_THREADS"] = "1" + + epilog = f""" + Examples: + Run on single machine: + $ {sys.argv[0]} data_root output_directory --checkpoint --name [--other-flags] + Run on multiple machines: + (machine0)$ {sys.argv[0]} data_root output_directory --checkpoint --name --machine-rank 0 --num-machines 2 --dist-url [--other-flags] + (machine1)$ {sys.argv[0]} data_root output_directory --checkpoint --name --machine-rank 1 --num-machines 2 --dist-url [--other-flags] + """ + + parser = Args(epilog=epilog) + parser.add_argument("data_root", type=pathlib.Path, help="Path to the data directory.") + parser.add_argument("output_directory", type=pathlib.Path, help="Path to the DoIterationOutput directory.") + parser.add_argument( + "experiment_directory", + type=pathlib.Path, + help="Path to the directory with checkpoints and config.", + ) + parser.add_argument( + "--checkpoint", + type=int, + required=True, + help="Number of an existing checkpoint.", + ) + parser.add_argument( + "--filenames-filter", + type=pathlib.Path, + help="Path to list of filenames to parse.", + ) + parser.add_argument( + "--name", + dest="name", + help="Run name if this is different experiment directory.", + required=False, + type=str, + default="", + ) + parser.add_argument( + "--cfg", + dest="cfg_file", + help="Config file for inference. " + "Only use it to overwrite the standard loading of the config in the project directory.", + required=False, + type=pathlib.Path, + ) + parser.add_argument( + "--use-orthogonal-normalization", + dest="use_orthogonal_normalization", + help="If set, an orthogonal normalization (e.g. ortho in numpy.fft) will be used. " + "The Calgary-Campinas challenge does not use this, therefore the volumes will be" + " normalized to their expected outputs.", + default="store_true", + ) + + args = parser.parse_args() + set_all_seeds(args.seed) + + setup_inference_save_to_h5 = functools.partial( + setup_inference_save_to_h5, + functools.partial(_get_transforms), + ) + + volume_post_processing_func = None + if not args.use_orthogonal_normalization: + volume_post_processing_func = _calgary_volume_post_processing_func + + direct.launch.launch( + setup_inference_save_to_h5, + args.num_machines, + args.num_gpus, + args.machine_rank, + args.dist_url, + args.name, + args.data_root, + args.experiment_directory, + args.output_directory, + args.filenames_filter, + args.checkpoint, + args.device, + args.num_workers, + args.machine_rank, + args.cfg_file, + volume_post_processing_func, + args.mixed_precision, + args.debug, + ) diff --git a/projects/spie_radial_subsampling/predict_val.py b/projects/spie_radial_subsampling/predict_val.py new file mode 100644 index 00000000..209c63ce --- /dev/null +++ b/projects/spie_radial_subsampling/predict_val.py @@ -0,0 +1,128 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors +import functools +import logging +import os +import pathlib +import sys + +import numpy as np +import torch + +import direct.launch +from direct.common.subsample import build_masking_function +from direct.environment import Args +from direct.inference import build_inference_transforms, setup_inference_save_to_h5 +from direct.utils import set_all_seeds + +logger = logging.getLogger(__name__) + + +def _calgary_volume_post_processing_func(volume): + volume = volume / np.sqrt(np.prod(volume.shape[1:])) + return volume + + +def _get_transforms(validation_index, env): + dataset_cfg = env.cfg.validation.datasets[validation_index] + mask_func = build_masking_function(**dataset_cfg.transforms.masking) + transforms = build_inference_transforms(env, mask_func, dataset_cfg) + return dataset_cfg, transforms + + +if __name__ == "__main__": + # This sets MKL threads to 1. + # DataLoader can otherwise bring a lot of difficulties when computing CPU FFTs in the transforms. + torch.set_num_threads(1) + os.environ["OMP_NUM_THREADS"] = "1" + + epilog = f""" + Examples: + Run on single machine: + $ {sys.argv[0]} data_root output_directory --checkpoint --name [--other-flags] + Run on multiple machines: + (machine0)$ {sys.argv[0]} data_root output_directory --checkpoint --name --machine-rank 0 --num-machines 2 --dist-url [--other-flags] + (machine1)$ {sys.argv[0]} data_root output_directory --checkpoint --name --machine-rank 1 --num-machines 2 --dist-url [--other-flags] + """ + + parser = Args(epilog=epilog) + parser.add_argument("data_root", type=pathlib.Path, help="Path to the data directory.") + parser.add_argument("output_directory", type=pathlib.Path, help="Path to the DoIterationOutput directory.") + parser.add_argument( + "experiment_directory", + type=pathlib.Path, + help="Path to the directory with checkpoints and config.", + ) + parser.add_argument( + "--checkpoint", + type=int, + required=True, + help="Number of an existing checkpoint.", + ) + parser.add_argument( + "--validation-index", + type=int, + required=True, + help="This is the index of the validation set in the config, e.g., 0 will select the first validation set.", + ) + parser.add_argument( + "--filenames-filter", + type=pathlib.Path, + help="Path to list of filenames to parse.", + ) + parser.add_argument( + "--name", + dest="name", + help="Run name if this is different experiment directory.", + required=False, + type=str, + default="", + ) + parser.add_argument( + "--cfg", + dest="cfg_file", + help="Config file for inference. " + "Only use it to overwrite the standard loading of the config in the project directory.", + required=False, + type=pathlib.Path, + ) + parser.add_argument( + "--use-orthogonal-normalization", + dest="use_orthogonal_normalization", + help="If set, an orthogonal normalization (e.g. ortho in numpy.fft) will be used. " + "The Calgary-Campinas challenge does not use this, therefore the volumes will be" + " normalized to their expected outputs.", + default="store_true", + ) + + args = parser.parse_args() + set_all_seeds(args.seed) + + setup_inference_save_to_h5 = functools.partial( + setup_inference_save_to_h5, + functools.partial(_get_transforms, args.validation_index), + ) + volume_post_processing_func = None + if not args.use_orthogonal_normalization: + volume_post_processing_func = _calgary_volume_post_processing_func + + direct.launch.launch( + setup_inference_save_to_h5, + args.num_machines, + args.num_gpus, + args.machine_rank, + args.dist_url, + args.name, + args.data_root, + args.experiment_directory, + args.output_directory, + args.filenames_filter, + args.checkpoint, + args.device, + args.num_workers, + args.machine_rank, + args.cfg_file, + volume_post_processing_func, + args.mixed_precision, + args.debug, + ) diff --git a/projects/toy/base.yaml b/projects/toy/base.yaml index de683a49..49e2d759 100644 --- a/projects/toy/base.yaml +++ b/projects/toy/base.yaml @@ -1,18 +1,12 @@ -# This model is a reproduction with some small changes of our winning algorithm in the Calgary-Campinas challenge -# at MIDL2020. It has an improved sensitivity estimation and better metric logging capabilities based on DIRECT v0.2 -# features. physics: forward_operator: fft2(centered=False) backward_operator: ifft2(centered=False) training: datasets: - # Two datasets, only difference is the shape, so the data can be collated for larger batches - name: FakeMRIBlobs -# lists: -# - ./train.lst - sample_size: 20 - num_coils: 16 - spatial_shape: [32, 40] + sample_size: 3 + num_coils: 8 + spatial_shape: [11, 32, 40] transforms: crop: null estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS @@ -24,7 +18,6 @@ training: center_fractions: [0.2] seed: [0] seed: 10 -# crop_outer_slices: true batch_size: 4 # This is the batch size per GPU! optimizer: Adam lr: 0.0001 @@ -32,13 +25,13 @@ training: lr_step_size: 50000 lr_gamma: 0.2 lr_warmup_iter: 1000 - num_iterations: 100 + num_iterations: 40 gradient_steps: 1 gradient_clipping: 0.0 gradient_debug: false checkpointer: checkpoint_steps: 500 - validation_steps: 2 # With batch size 4 and 4 GPUs this is about 7300 iterations, or ~1 epoch. + validation_steps: 20 loss: crop: null losses: @@ -48,13 +41,10 @@ training: multiplier: 1.0 validation: datasets: - # Twice the same dataset but a different acceleration factor - name: FakeMRIBlobs -# lists: -# - ./val.lst - sample_size: 10 - num_coils: 16 - spatial_shape: [32, 40] + sample_size: 3 + num_coils: 8 + spatial_shape: [11, 32, 40] transforms: crop: null estimate_sensitivity_maps: true @@ -64,7 +54,6 @@ validation: accelerations: [5] center_fractions: [0.2] seed: [0] -# crop_outer_slices: true text_description: 5x # Description for logging seed: 11 batch_size: 4 @@ -87,7 +76,6 @@ inference: batch_size: 4 dataset: name: FakeMRIBlobs -# crop_outer_slices: true text_description: inference transforms: crop: null