Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sampling masks #140

Merged
merged 41 commits into from
Dec 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
afceed4
Missing f in fstring
Jun 29, 2021
fa4872d
Adding radial mask
Jun 29, 2021
76ba218
Adding radial mask
Jun 29, 2021
06dd56d
Adding radial mask
Jun 29, 2021
d7970b3
Adding radial mask
Jun 29, 2021
fd874fe
Radial subsampling update
Jun 30, 2021
e904d62
Fix evaluate to use last volume's last slice
georgeyiasemis Jul 5, 2021
ecba93c
Black fix
georgeyiasemis Jul 6, 2021
b4f2c21
Merge branch 'fix-evaluate-function' into add-sampling-masks
georgeyiasemis Jul 6, 2021
5c29170
Change toy base.yaml
georgeyiasemis Jul 6, 2021
aef3d86
direct/data/tests/
georgeyiasemis Jul 6, 2021
f753aed
Minor naming fix
georgeyiasemis Jul 6, 2021
543257b
Minor naming fix
georgeyiasemis Jul 6, 2021
c22d0c2
Fix h5slicedata logger
georgeyiasemis Jul 13, 2021
8e0c50b
Radial mask fixed to return acs
georgeyiasemis Jul 13, 2021
84e2ab5
Radial mask fixed to return acs
georgeyiasemis Jul 13, 2021
4dc7b5b
Fix to overwrite standard loading of config and add predict function
georgeyiasemis Jul 22, 2021
b3b9a90
Fix subsampling seed & test it
georgeyiasemis Jul 27, 2021
3666f7f
Removed hashed code
georgeyiasemis Jul 27, 2021
b1f61f0
Added rectilinear masking for calgary campinas dataset
georgeyiasemis Jul 27, 2021
7cd0084
Cleaned calgary project and added scripts for radial subsampling project
Aug 4, 2021
ce025cc
Removing not needed project
Aug 4, 2021
d78e285
Fixed evaluation function
georgeyiasemis Aug 26, 2021
b7a0e4f
Merge branch 'fix-evaluate-function' into add-sampling-masks
georgeyiasemis Aug 26, 2021
846a65d
Rebase from main
georgeyiasemis Sep 3, 2021
5220236
Merge remote-tracking branch 'origin/main' into add-sampling-mask
georgeyiasemis Dec 13, 2021
b28aa51
Remove calgary-campinas rectilinear masks. Can use fastMRI random mask
georgeyiasemis Dec 13, 2021
8a9bb40
Fix black
georgeyiasemis Dec 13, 2021
4ca542d
Fix black
georgeyiasemis Dec 13, 2021
a5a5f13
Rename
georgeyiasemis Dec 13, 2021
4bd15ff
Changed maskfunc
georgeyiasemis Dec 13, 2021
3754b64
Replace calgary-rectilinear masks with fastmrirandom
georgeyiasemis Dec 13, 2021
3102637
Rename
georgeyiasemis Dec 13, 2021
a2027d7
Merge pull request #145 from directgroup/main
georgeyiasemis Dec 15, 2021
a665e0a
Remove named tensors comments
georgeyiasemis Dec 15, 2021
89ca492
Minor fixes
georgeyiasemis Dec 15, 2021
8cdbe22
Merge pull request #147 from directgroup/main
georgeyiasemis Dec 16, 2021
395e534
Minor fixes
georgeyiasemis Dec 16, 2021
f8f6a28
Update reference style
georgeyiasemis Dec 16, 2021
9a54c26
Changed subsampling_scheme type
georgeyiasemis Dec 16, 2021
c9fdd21
isort fixes
georgeyiasemis Dec 16, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
277 changes: 267 additions & 10 deletions direct/common/subsample.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,21 @@

import contextlib
import logging
import pathlib
from abc import abstractmethod
from enum import Enum
from typing import List, Optional, Tuple

import numpy as np
import torch

import direct.data.transforms as T
from direct.environment import DIRECT_CACHE_DIR
from direct.types import Number
from direct.utils import str_to_class
from direct.utils.io import download_url

logger = logging.getLogger(__name__)
GOLDEN_RATIO = (1 + np.sqrt(5)) / 2


@contextlib.contextmanager
Expand Down Expand Up @@ -104,7 +106,7 @@ def __call__(self, data, seed=None, return_acs=False):
ndarray
"""
self.rng.seed(seed)
mask = self.mask_func(data, return_acs=return_acs) # pylint: disable = E1123
mask = self.mask_func(data, seed=seed, return_acs=return_acs) # pylint: disable = E1123
return mask


Expand Down Expand Up @@ -145,7 +147,7 @@ def mask_func(self, shape, return_acs=False, seed=None):
Parameters
----------

data : iterable[int]
shape : iterable[int]
The shape of the mask to be created. The shape should at least 3 dimensions.
Samples are drawn along the second last dimension.
seed : int (optional)
Expand Down Expand Up @@ -230,7 +232,7 @@ def mask_func(self, shape, return_acs=False, seed=None):
Parameters
----------

data : iterable[int]
shape : iterable[int]
The shape of the mask to be created. The shape should at least 3 dimensions.
Samples are drawn along the second last dimension.
seed : int (optional)
Expand Down Expand Up @@ -314,20 +316,43 @@ def circular_centered_mask(shape, radius):
mask = ((dist_from_center <= radius) * np.ones(shape)).astype(bool)
return mask[np.newaxis, ..., np.newaxis]

def mask_func(self, shape, return_acs=False):
def mask_func(self, shape, return_acs=False, seed=None):
"""
Downloads and loads pre=computed Poisson masks.
Currently supports shapes of 218 x 170/ 218/ 174 and acceleration factors of 5 or 10.

Parameters
----------

shape : iterable[int]
The shape of the mask to be created. The shape should at least 3 dimensions.
Samples are drawn along the second last dimension.
seed : int (optional)
Seed for the random number generator. Setting the seed ensures the same mask is generated
each time for the same shape.
return_acs : bool
Return the autocalibration signal region as a mask.

Returns
-------
torch.Tensor : the sampling mask

"""
shape = tuple(shape)[:-1]
if return_acs:
return torch.from_numpy(self.circular_centered_mask(shape, 18))

if shape not in self.shapes:
raise ValueError(f"No mask of shape {shape} is available in the CalgaryCampinas dataset.")

acceleration = self.choose_acceleration()
masks = self.masks[acceleration]
with temp_seed(self.rng, seed):
acceleration = self.choose_acceleration()
masks = self.masks[acceleration]

mask, num_masks = masks[shape]
# Randomly pick one example
choice = self.rng.randint(0, num_masks)

mask, num_masks = masks[shape]
# Randomly pick one example
choice = self.rng.randint(0, num_masks)
return torch.from_numpy(mask[choice][np.newaxis, ..., np.newaxis])

def __load_masks(self, acceleration):
Expand All @@ -352,6 +377,238 @@ def __load_masks(self, acceleration):
return output


class CIRCUSSamplingMode(str, Enum):

circus_radial = "circus-radial"
circus_spiral = "circus-spiral"


class CIRCUSMaskFunc(BaseMaskFunc):
"""
Implementation of Cartesian undersampling (radial or spiral) using CIRCUS as shown in [1]_. It creates
radial or spiral masks for Cartesian acquired data on a grid.

References
----------

.. [1] Liu J, Saloner D. Accelerated MRI with CIRcular Cartesian UnderSampling (CIRCUS):
a variable density Cartesian sampling strategy for compressed sensing and parallel imaging.
Quant Imaging Med Surg. 2014 Feb;4(1):57-67. doi: 10.3978/j.issn.2223-4292.2014.02.01.
PMID: 24649436; PMCID: PMC3947985.

"""

def __init__(
self,
accelerations,
subsampling_scheme: CIRCUSSamplingMode,
**kwargs,
):
super().__init__(
accelerations=accelerations,
center_fractions=tuple(0 for _ in range(len(accelerations))),
uniform_range=False,
)
if subsampling_scheme not in ["circus-spiral", "circus-radial"]:
raise NotImplementedError(
f"Currently CIRCUSMaskFunc is only implemented for 'circus-radial' or 'circus-spiral' "
f"as a subsampling_scheme. Got subsampling_scheme={subsampling_scheme}."
)

self.subsampling_scheme = "circus-radial" if subsampling_scheme is None else subsampling_scheme

@staticmethod
def get_square_ordered_idxs(square_side_size: int, square_id: int) -> Tuple[Tuple, ...]:
"""
Returns ordered (clockwise) indices of a sub-square of a square matrix.

Parameters:
-----------
square_side_size: int
Square side size. Dim of array.
square_id: int
Number of sub-square. Can be 0, ..., square_side_size // 2.

Returns:
--------
ordered_idxs: List of tuples.
Indices of each point that belongs to the square_id-th sub-square
starting from top-left point clockwise.

"""
assert square_id in range(square_side_size // 2)

ordered_idxs = list()

for col in range(square_id, square_side_size - square_id):
ordered_idxs.append((square_id, col))

for row in range(square_id + 1, square_side_size - (square_id + 1)):
ordered_idxs.append((row, square_side_size - (square_id + 1)))

for col in range(square_side_size - (square_id + 1), square_id, -1):
ordered_idxs.append((square_side_size - (square_id + 1), col))

for row in range(square_side_size - (square_id + 1), square_id, -1):
ordered_idxs.append((row, square_id))

return tuple(ordered_idxs)

def circus_radial_mask(self, shape, acceleration):
"""
Implements CIRCUS radial undersampling.
"""
max_dim = max(shape) - max(shape) % 2
min_dim = min(shape) - min(shape) % 2
num_nested_squares = max_dim // 2
M = int(np.prod(shape) / (acceleration * (max_dim / 2 - (max_dim - min_dim) * (1 + min_dim / max_dim) / 4)))

mask = np.zeros((max_dim, max_dim), dtype=np.float32)

t = self.rng.randint(low=0, high=1e4, size=1, dtype=int).item()

for square_id in range(num_nested_squares):
ordered_indices = self.get_square_ordered_idxs(
square_side_size=max_dim,
square_id=square_id,
)
# J: size of the square, J=2,…,N, i.e., the number of points along one side of the square
J = 2 * (num_nested_squares - square_id)
# K: total number of points along the perimeter of the square K=4·J-4;
K = 4 * (J - 1)

for m in range(M):
indices_idx = int(np.floor(np.mod((m + t * M) / GOLDEN_RATIO, 1) * K))
mask[ordered_indices[indices_idx]] = 1.0

pad = ((shape[0] % 2, 0), (shape[1] % 2, 0))

mask = np.pad(mask, pad, constant_values=0)
mask = T.center_crop(torch.from_numpy(mask.astype(bool)), shape)

return mask

def circus_spiral_mask(self, shape, acceleration):
"""
Implements CIRCUS spiral undersampling.
"""
max_dim = max(shape) - max(shape) % 2
min_dim = min(shape) - min(shape) % 2

num_nested_squares = max_dim // 2

M = int(np.prod(shape) / (acceleration * (max_dim / 2 - (max_dim - min_dim) * (1 + min_dim / max_dim) / 4)))

mask = np.zeros((max_dim, max_dim), dtype=np.float32)

c = self.rng.uniform(low=1.1, high=1.3, size=1).item()

for square_id in range(num_nested_squares):

ordered_indices = self.get_square_ordered_idxs(
square_side_size=max_dim,
square_id=square_id,
)

# J: size of the square, J=2,…,N, i.e., the number of points along one side of the square
J = 2 * (num_nested_squares - square_id)
# K: total number of points along the perimeter of the square K=4·J-4;
K = 4 * (J - 1)

for m in range(M):
i = np.floor(np.mod(m / GOLDEN_RATIO, 1) * K)
indices_idx = int(np.mod((i + np.ceil(J ** c) - 1), K))

mask[ordered_indices[indices_idx]] = 1.0

pad = ((shape[0] % 2, 0), (shape[1] % 2, 0))

mask = np.pad(mask, pad)
mask = T.center_crop(torch.from_numpy(mask.astype(bool)), shape)

return mask

@staticmethod
def circular_centered_mask(mask, eps=0.1):
shape = mask.shape
center = np.asarray(shape) // 2
Y, X = np.ogrid[: shape[0], : shape[1]]
Y, X = torch.tensor(Y), torch.tensor(X)
radius = 1

# Finds the maximum (unmasked) disk in mask given a tolerance.
while True:
# Creates a disk with R=radius and finds intersection with mask
disk = (Y - center[0]) ** 2 + (X - center[1]) ** 2 <= radius ** 2
intersection = disk & mask
ratio = disk.sum() / intersection.sum()
if ratio > 1.0 + eps:
return intersection
radius += eps

def mask_func(self, shape, return_acs=False, seed=None):

if len(shape) < 3:
raise ValueError("Shape should have 3 or more dimensions")

with temp_seed(self.rng, seed):
num_rows = shape[-3]
num_cols = shape[-2]
acceleration = self.choose_acceleration()[1]

if self.subsampling_scheme == "circus-radial":
mask = self.circus_radial_mask(
shape=(num_rows, num_cols),
acceleration=acceleration,
)
elif self.subsampling_scheme == "circus-spiral":
mask = self.circus_spiral_mask(
shape=(num_rows, num_cols),
acceleration=acceleration,
)

if return_acs:
return self.circular_centered_mask(mask).unsqueeze(0).unsqueeze(-1)

return mask.unsqueeze(0).unsqueeze(-1)


class RadialMaskFunc(CIRCUSMaskFunc):
"""
Computes radial masks for Cartesian data.

"""

def __init__(
self,
accelerations,
**kwargs,
):
super().__init__(
accelerations=accelerations,
subsampling_scheme=CIRCUSSamplingMode.circus_radial,
**kwargs,
)


class SpiralMaskFunc(CIRCUSMaskFunc):
"""
Computes spiral masks for Cartesian data.

"""

def __init__(
self,
accelerations,
**kwargs,
):
super().__init__(
accelerations=accelerations,
subsampling_scheme=CIRCUSSamplingMode.circus_spiral,
**kwargs,
)


class DictionaryMaskFunc(BaseMaskFunc):
def __init__(self, data_dictionary, **kwargs): # noqa
super().__init__(accelerations=None)
Expand Down
8 changes: 4 additions & 4 deletions direct/common/subsample_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
@dataclass
class MaskingConfig(BaseConfig):
name: str = MISSING
accelerations: Tuple[int, ...] = (4,) # Ideally Union[float, int].
center_fractions: Optional[Tuple[float, ...]] = (0.08,) # Ideally Optional[Tuple[float, ...]]
accelerations: Tuple[int, ...] = (5,) # Ideally Union[float, int].
center_fractions: Optional[Tuple[float, ...]] = (0.1,) # Ideally Optional[Tuple[float, ...]]
uniform_range: bool = False
image_center_crop: bool = False

val_accelerations: Tuple[int, ...] = (4, 8)
val_center_fractions: Optional[Tuple[float, ...]] = (0.08, 0.04)
val_accelerations: Tuple[int, ...] = (5, 10)
val_center_fractions: Optional[Tuple[float, ...]] = (0.1, 0.05)
Loading