Skip to content

Commit

Permalink
Add non-Cartesian masking functions (#140)
Browse files Browse the repository at this point in the history
* Adding non-Cartesian masking functions for cartesian data (simulating non-Cartesian trajectory on data acquired on a grid)
* Scripts for reproducing experiments as presented in "Deep MRI Reconstruction with radial subsampling" (https://arxiv.org/abs/2108.07619)
  • Loading branch information
georgeyiasemis authored Dec 16, 2021
1 parent 2fbbbb2 commit 1a1f93d
Show file tree
Hide file tree
Showing 38 changed files with 1,727 additions and 118 deletions.
277 changes: 267 additions & 10 deletions direct/common/subsample.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,21 @@

import contextlib
import logging
import pathlib
from abc import abstractmethod
from enum import Enum
from typing import List, Optional, Tuple

import numpy as np
import torch

import direct.data.transforms as T
from direct.environment import DIRECT_CACHE_DIR
from direct.types import Number
from direct.utils import str_to_class
from direct.utils.io import download_url

logger = logging.getLogger(__name__)
GOLDEN_RATIO = (1 + np.sqrt(5)) / 2


@contextlib.contextmanager
Expand Down Expand Up @@ -104,7 +106,7 @@ def __call__(self, data, seed=None, return_acs=False):
ndarray
"""
self.rng.seed(seed)
mask = self.mask_func(data, return_acs=return_acs) # pylint: disable = E1123
mask = self.mask_func(data, seed=seed, return_acs=return_acs) # pylint: disable = E1123
return mask


Expand Down Expand Up @@ -145,7 +147,7 @@ def mask_func(self, shape, return_acs=False, seed=None):
Parameters
----------
data : iterable[int]
shape : iterable[int]
The shape of the mask to be created. The shape should at least 3 dimensions.
Samples are drawn along the second last dimension.
seed : int (optional)
Expand Down Expand Up @@ -230,7 +232,7 @@ def mask_func(self, shape, return_acs=False, seed=None):
Parameters
----------
data : iterable[int]
shape : iterable[int]
The shape of the mask to be created. The shape should at least 3 dimensions.
Samples are drawn along the second last dimension.
seed : int (optional)
Expand Down Expand Up @@ -314,20 +316,43 @@ def circular_centered_mask(shape, radius):
mask = ((dist_from_center <= radius) * np.ones(shape)).astype(bool)
return mask[np.newaxis, ..., np.newaxis]

def mask_func(self, shape, return_acs=False):
def mask_func(self, shape, return_acs=False, seed=None):
"""
Downloads and loads pre=computed Poisson masks.
Currently supports shapes of 218 x 170/ 218/ 174 and acceleration factors of 5 or 10.
Parameters
----------
shape : iterable[int]
The shape of the mask to be created. The shape should at least 3 dimensions.
Samples are drawn along the second last dimension.
seed : int (optional)
Seed for the random number generator. Setting the seed ensures the same mask is generated
each time for the same shape.
return_acs : bool
Return the autocalibration signal region as a mask.
Returns
-------
torch.Tensor : the sampling mask
"""
shape = tuple(shape)[:-1]
if return_acs:
return torch.from_numpy(self.circular_centered_mask(shape, 18))

if shape not in self.shapes:
raise ValueError(f"No mask of shape {shape} is available in the CalgaryCampinas dataset.")

acceleration = self.choose_acceleration()
masks = self.masks[acceleration]
with temp_seed(self.rng, seed):
acceleration = self.choose_acceleration()
masks = self.masks[acceleration]

mask, num_masks = masks[shape]
# Randomly pick one example
choice = self.rng.randint(0, num_masks)

mask, num_masks = masks[shape]
# Randomly pick one example
choice = self.rng.randint(0, num_masks)
return torch.from_numpy(mask[choice][np.newaxis, ..., np.newaxis])

def __load_masks(self, acceleration):
Expand All @@ -352,6 +377,238 @@ def __load_masks(self, acceleration):
return output


class CIRCUSSamplingMode(str, Enum):

circus_radial = "circus-radial"
circus_spiral = "circus-spiral"


class CIRCUSMaskFunc(BaseMaskFunc):
"""
Implementation of Cartesian undersampling (radial or spiral) using CIRCUS as shown in [1]_. It creates
radial or spiral masks for Cartesian acquired data on a grid.
References
----------
.. [1] Liu J, Saloner D. Accelerated MRI with CIRcular Cartesian UnderSampling (CIRCUS):
a variable density Cartesian sampling strategy for compressed sensing and parallel imaging.
Quant Imaging Med Surg. 2014 Feb;4(1):57-67. doi: 10.3978/j.issn.2223-4292.2014.02.01.
PMID: 24649436; PMCID: PMC3947985.
"""

def __init__(
self,
accelerations,
subsampling_scheme: CIRCUSSamplingMode,
**kwargs,
):
super().__init__(
accelerations=accelerations,
center_fractions=tuple(0 for _ in range(len(accelerations))),
uniform_range=False,
)
if subsampling_scheme not in ["circus-spiral", "circus-radial"]:
raise NotImplementedError(
f"Currently CIRCUSMaskFunc is only implemented for 'circus-radial' or 'circus-spiral' "
f"as a subsampling_scheme. Got subsampling_scheme={subsampling_scheme}."
)

self.subsampling_scheme = "circus-radial" if subsampling_scheme is None else subsampling_scheme

@staticmethod
def get_square_ordered_idxs(square_side_size: int, square_id: int) -> Tuple[Tuple, ...]:
"""
Returns ordered (clockwise) indices of a sub-square of a square matrix.
Parameters:
-----------
square_side_size: int
Square side size. Dim of array.
square_id: int
Number of sub-square. Can be 0, ..., square_side_size // 2.
Returns:
--------
ordered_idxs: List of tuples.
Indices of each point that belongs to the square_id-th sub-square
starting from top-left point clockwise.
"""
assert square_id in range(square_side_size // 2)

ordered_idxs = list()

for col in range(square_id, square_side_size - square_id):
ordered_idxs.append((square_id, col))

for row in range(square_id + 1, square_side_size - (square_id + 1)):
ordered_idxs.append((row, square_side_size - (square_id + 1)))

for col in range(square_side_size - (square_id + 1), square_id, -1):
ordered_idxs.append((square_side_size - (square_id + 1), col))

for row in range(square_side_size - (square_id + 1), square_id, -1):
ordered_idxs.append((row, square_id))

return tuple(ordered_idxs)

def circus_radial_mask(self, shape, acceleration):
"""
Implements CIRCUS radial undersampling.
"""
max_dim = max(shape) - max(shape) % 2
min_dim = min(shape) - min(shape) % 2
num_nested_squares = max_dim // 2
M = int(np.prod(shape) / (acceleration * (max_dim / 2 - (max_dim - min_dim) * (1 + min_dim / max_dim) / 4)))

mask = np.zeros((max_dim, max_dim), dtype=np.float32)

t = self.rng.randint(low=0, high=1e4, size=1, dtype=int).item()

for square_id in range(num_nested_squares):
ordered_indices = self.get_square_ordered_idxs(
square_side_size=max_dim,
square_id=square_id,
)
# J: size of the square, J=2,…,N, i.e., the number of points along one side of the square
J = 2 * (num_nested_squares - square_id)
# K: total number of points along the perimeter of the square K=4·J-4;
K = 4 * (J - 1)

for m in range(M):
indices_idx = int(np.floor(np.mod((m + t * M) / GOLDEN_RATIO, 1) * K))
mask[ordered_indices[indices_idx]] = 1.0

pad = ((shape[0] % 2, 0), (shape[1] % 2, 0))

mask = np.pad(mask, pad, constant_values=0)
mask = T.center_crop(torch.from_numpy(mask.astype(bool)), shape)

return mask

def circus_spiral_mask(self, shape, acceleration):
"""
Implements CIRCUS spiral undersampling.
"""
max_dim = max(shape) - max(shape) % 2
min_dim = min(shape) - min(shape) % 2

num_nested_squares = max_dim // 2

M = int(np.prod(shape) / (acceleration * (max_dim / 2 - (max_dim - min_dim) * (1 + min_dim / max_dim) / 4)))

mask = np.zeros((max_dim, max_dim), dtype=np.float32)

c = self.rng.uniform(low=1.1, high=1.3, size=1).item()

for square_id in range(num_nested_squares):

ordered_indices = self.get_square_ordered_idxs(
square_side_size=max_dim,
square_id=square_id,
)

# J: size of the square, J=2,…,N, i.e., the number of points along one side of the square
J = 2 * (num_nested_squares - square_id)
# K: total number of points along the perimeter of the square K=4·J-4;
K = 4 * (J - 1)

for m in range(M):
i = np.floor(np.mod(m / GOLDEN_RATIO, 1) * K)
indices_idx = int(np.mod((i + np.ceil(J ** c) - 1), K))

mask[ordered_indices[indices_idx]] = 1.0

pad = ((shape[0] % 2, 0), (shape[1] % 2, 0))

mask = np.pad(mask, pad)
mask = T.center_crop(torch.from_numpy(mask.astype(bool)), shape)

return mask

@staticmethod
def circular_centered_mask(mask, eps=0.1):
shape = mask.shape
center = np.asarray(shape) // 2
Y, X = np.ogrid[: shape[0], : shape[1]]
Y, X = torch.tensor(Y), torch.tensor(X)
radius = 1

# Finds the maximum (unmasked) disk in mask given a tolerance.
while True:
# Creates a disk with R=radius and finds intersection with mask
disk = (Y - center[0]) ** 2 + (X - center[1]) ** 2 <= radius ** 2
intersection = disk & mask
ratio = disk.sum() / intersection.sum()
if ratio > 1.0 + eps:
return intersection
radius += eps

def mask_func(self, shape, return_acs=False, seed=None):

if len(shape) < 3:
raise ValueError("Shape should have 3 or more dimensions")

with temp_seed(self.rng, seed):
num_rows = shape[-3]
num_cols = shape[-2]
acceleration = self.choose_acceleration()[1]

if self.subsampling_scheme == "circus-radial":
mask = self.circus_radial_mask(
shape=(num_rows, num_cols),
acceleration=acceleration,
)
elif self.subsampling_scheme == "circus-spiral":
mask = self.circus_spiral_mask(
shape=(num_rows, num_cols),
acceleration=acceleration,
)

if return_acs:
return self.circular_centered_mask(mask).unsqueeze(0).unsqueeze(-1)

return mask.unsqueeze(0).unsqueeze(-1)


class RadialMaskFunc(CIRCUSMaskFunc):
"""
Computes radial masks for Cartesian data.
"""

def __init__(
self,
accelerations,
**kwargs,
):
super().__init__(
accelerations=accelerations,
subsampling_scheme=CIRCUSSamplingMode.circus_radial,
**kwargs,
)


class SpiralMaskFunc(CIRCUSMaskFunc):
"""
Computes spiral masks for Cartesian data.
"""

def __init__(
self,
accelerations,
**kwargs,
):
super().__init__(
accelerations=accelerations,
subsampling_scheme=CIRCUSSamplingMode.circus_spiral,
**kwargs,
)


class DictionaryMaskFunc(BaseMaskFunc):
def __init__(self, data_dictionary, **kwargs): # noqa
super().__init__(accelerations=None)
Expand Down
8 changes: 4 additions & 4 deletions direct/common/subsample_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
@dataclass
class MaskingConfig(BaseConfig):
name: str = MISSING
accelerations: Tuple[int, ...] = (4,) # Ideally Union[float, int].
center_fractions: Optional[Tuple[float, ...]] = (0.08,) # Ideally Optional[Tuple[float, ...]]
accelerations: Tuple[int, ...] = (5,) # Ideally Union[float, int].
center_fractions: Optional[Tuple[float, ...]] = (0.1,) # Ideally Optional[Tuple[float, ...]]
uniform_range: bool = False
image_center_crop: bool = False

val_accelerations: Tuple[int, ...] = (4, 8)
val_center_fractions: Optional[Tuple[float, ...]] = (0.08, 0.04)
val_accelerations: Tuple[int, ...] = (5, 10)
val_center_fractions: Optional[Tuple[float, ...]] = (0.1, 0.05)
Loading

0 comments on commit 1a1f93d

Please sign in to comment.