diff --git a/.github/workflows/tox.yml b/.github/workflows/tox.yml index a77d20f7..301611ca 100644 --- a/.github/workflows/tox.yml +++ b/.github/workflows/tox.yml @@ -21,5 +21,6 @@ jobs: run: | python -m pip install --upgrade pip setuptools wheel pip install tox tox-gh-actions + python -m pip install -e ".[dev]" - name: Test with tox run: tox diff --git a/MANIFEST.in b/MANIFEST.in index d2200584..0547420e 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -5,6 +5,7 @@ include LICENSE include README.md recursive-include tests * +recursive-include direct *.pyx *.pxd *.pxi *.py *.h *.ini *.npy *.txt *.in *.md recursive-exclude * __pycache__ recursive-exclude * *.py[co] diff --git a/Makefile b/Makefile index 5fd26e88..21291039 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: clean clean-test clean-pyc clean-build docs help +.PHONY: clean clean-test clean-pyc clean-cpy clean-build docs help .DEFAULT_GOAL := help define BROWSER_PYSCRIPT @@ -26,7 +26,7 @@ BROWSER := python -c "$$BROWSER_PYSCRIPT" help: @python -c "$$PRINT_HELP_PYSCRIPT" < $(MAKEFILE_LIST) -clean: clean-build clean-pyc clean-test clean-docs ## remove all build, test, coverage, docs and Python artifacts +clean: clean-build clean-pyc clean-cpy clean-test clean-docs ## remove all build, test, coverage, docs and Python and cython artifacts clean-build: ## remove build artifacts rm -fr build/ @@ -41,6 +41,11 @@ clean-pyc: ## remove Python file artifacts find . -name '*~' -exec rm -f {} + find . -name '__pycache__' -exec rm -fr {} + +clean-cpy: ## remove cython file artifacts + find . -name '*.c' -exec rm -f {} + + find . -name '*.cpp' -exec rm -f {} + + find . -name '*.so' -exec rm -f {} + + clean-test: ## remove test and coverage artifacts rm -fr .tox/ rm -f .coverage diff --git a/direct/common/_poisson.pyx b/direct/common/_poisson.pyx new file mode 100644 index 00000000..5b1ad183 --- /dev/null +++ b/direct/common/_poisson.pyx @@ -0,0 +1,123 @@ +#cython: cdivision=True +#cython: boundscheck=False +#cython: nonecheck=False +#cython: wraparound=False +#cython: overflowcheck=False +#cython: unraisable_tracebacks=False + +import numpy as np +cimport numpy as cnp +from libc.math cimport cos, pi, sin +from libc.stdlib cimport RAND_MAX, rand, srand + +cnp.import_array() + + +cdef double random_uniform() nogil: + """Produces a random number in (0, 1).""" + cdef double r = rand() + return r / RAND_MAX + + +cdef int randint(int upper) nogil: + """Produces a random integer in {0, 1, ..., upper-1}.""" + return int(random_uniform() * (upper)) + + +cdef inline Py_ssize_t fmax(Py_ssize_t one, Py_ssize_t two) nogil: + """Max(a, b).""" + return one if one > two else two + + +cdef inline Py_ssize_t fmin(Py_ssize_t one, Py_ssize_t two) nogil: + """Min(a, b).""" + return one if one < two else two + + +def poisson( + int nx, + int ny, + int max_attempts, + cnp.ndarray[cnp.int_t, ndim=2, mode='c'] mask, + cnp.ndarray[cnp.float64_t, ndim=2, mode='c'] radius_x, + cnp.ndarray[cnp.float64_t, ndim=2, mode='c'] radius_y, + int seed +): + """ + Notes + ----- + + * Code inspired and modified from [1]_ with BSD-3 licence, Copyright (c) 2016, Frank Ong, Copyright (c) 2016, + The Regents of the University of California [2]_. + + References + ---------- + + .. [1] https://github.com/mikgroup/sigpy/blob/1817ff849d34d7cbbbcb503a1b310e7d8f95c242/sigpy/mri/samp.py#L158 + .. [2] https://github.com/mikgroup/sigpy/blob/master/LICENSE + """ + + cdef int x, y, num_actives, i, k + cdef float rx, ry, v, t, qx, qy, distance + cdef Py_ssize_t startx, endx, starty, endy, px, py + + # initialize active list + cdef cnp.ndarray[cnp.int_t, ndim=1, mode='c'] pxs = np.empty(nx * ny, dtype=int) + cdef cnp.ndarray[cnp.int_t, ndim=1, mode='c'] pys = np.empty(nx * ny, dtype=int) + + srand(seed) + + with nogil: + + pxs[0] = randint(nx) + pys[0] = randint(ny) + + num_actives = 1 + + while num_actives > 0: + # Select a sample from active list + i = randint(num_actives) + px = pxs[i] + py = pys[i] + rx = radius_x[px, py] + ry = radius_y[px, py] + + # Attempt to generate point + done = False + k = 0 + + while not done and k < max_attempts: + + # Generate point randomly from r and 2 * r + v = random_uniform() + 1 + t = 2 * pi * random_uniform() + qx = px + v * rx * cos(t) + qy = py + v * ry * sin(t) + + # Reject if outside grid or close to other points + if qx >= 0 and qx < nx and qy >= 0 and qy < ny: + startx = fmax(int(qx - rx), 0) + endx = fmin(int(qx + rx + 1), nx) + starty = fmax(int(qy - ry), 0) + endy = fmin(int(qy + ry + 1), ny) + + done = True + for x in range(startx, endx): + for y in range(starty, endy): + distance = ((qx - x) / radius_x[x, y]) ** 2 + ((qy - y) / (radius_y[x, y])) ** 2 + if (mask[x, y] == 1) and (distance < 1): + done = False + break + + k += 1 + + # Add point if done else remove from active list + if done: + pxs[num_actives] = int(qx) + pys[num_actives] = int(qy) + mask[pxs[num_actives], pys[num_actives]] = 1 + num_actives += 1 + else: + num_actives -= 1 + pxs[i] = pxs[num_actives] + pys[i] = pys[num_actives] diff --git a/direct/common/subsample.py b/direct/common/subsample.py index cb25cd39..9b5988b0 100644 --- a/direct/common/subsample.py +++ b/direct/common/subsample.py @@ -11,17 +11,31 @@ import logging from abc import abstractmethod from enum import Enum -from typing import List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple, Union import numpy as np import torch import direct.data.transforms as T +from direct.common._poisson import poisson as _poisson # pylint: disable=no-name-in-module from direct.environment import DIRECT_CACHE_DIR from direct.types import Number from direct.utils import str_to_class from direct.utils.io import download_url +# pylint: disable=arguments-differ + +__all__ = ( + "FastMRIRandomMaskFunc", + "FastMRIEquispacedMaskFunc", + "FastMRIMagicMaskFunc", + "CalgaryCampinasMaskFunc", + "RadialMaskFunc", + "SpiralMaskFunc", + "VariableDensityPoissonMaskFunc", + "build_masking_function", +) + logger = logging.getLogger(__name__) GOLDEN_RATIO = (1 + np.sqrt(5)) / 2 @@ -41,22 +55,22 @@ class BaseMaskFunc: def __init__( self, - accelerations: Optional[Tuple[Number, ...]], - center_fractions: Optional[Tuple[float, ...]] = None, + accelerations: Union[List[Number], Tuple[Number, ...]], + center_fractions: Optional[Union[List[float], Tuple[float, ...]]] = None, uniform_range: bool = True, ): """ Parameters ---------- - center_fractions: List([float]) + accelerations: Union[List[Number], Tuple[Number, ...]] + Amount of under-sampling_mask. An acceleration of 4 retains 25% of the k-space, the method is given by + mask_type. Has to be the same length as center_fractions if uniform_range is not True. + center_fractions: Optional[Union[List[float], Tuple[float, ...]]] Fraction of low-frequency columns to be retained. If multiple values are provided, then one of these numbers is chosen uniformly each time. If uniform_range - is True, then two values should be given. - accelerations: List([int]) - Amount of under-sampling_mask. An acceleration of 4 retains 25% of the k-space, the method is given by - mask_type. Has to be the same length as center_fractions if uniform_range is True. + is True, then two values should be given. Default: None. uniform_range: bool - If True then an acceleration will be uniformly sampled between the two values. + If True then an acceleration will be uniformly sampled between the two values. Default: True. """ if center_fractions is not None: if len([center_fractions]) != len([accelerations]): @@ -107,11 +121,11 @@ def __call__(self, *args, **kwargs) -> torch.Tensor: return mask -class FastMRIRandomMaskFunc(BaseMaskFunc): +class FastMRIMaskFunc(BaseMaskFunc): def __init__( self, - accelerations: Tuple[Number, ...], - center_fractions: Optional[Tuple[float, ...]] = None, + accelerations: Union[List[Number], Tuple[Number, ...]], + center_fractions: Optional[Union[List[float], Tuple[float, ...]]] = None, uniform_range: bool = False, ): super().__init__( @@ -120,36 +134,86 @@ def __init__( uniform_range=uniform_range, ) - def mask_func(self, shape, return_acs=False, seed=None): - r"""Creates vertical line mask. + @staticmethod + def center_mask_func(num_cols, num_low_freqs): + + # create the mask + mask = np.zeros(num_cols, dtype=bool) + pad = (num_cols - num_low_freqs + 1) // 2 + mask[pad : pad + num_low_freqs] = True - The mask selects a subset of columns from the input k-space data. If the k-space data has N - columns, the mask picks out: + return mask + + @staticmethod + def _reshape_and_broadcast_mask(shape, mask): + num_cols = shape[-2] + num_rows = shape[-3] + + # Reshape the mask + mask_shape = [1 for _ in shape] + mask_shape[-2] = num_cols + mask = mask.reshape(*mask_shape).astype(bool) + mask_shape[-3] = num_rows + + # Add coil axis, make array writable. + mask = np.broadcast_to(mask, mask_shape)[np.newaxis, ...].copy() + + return mask - #. :math:`N_{\text{low freqs}} = (N \times \text{center_fraction})` columns in the center corresponding - to low-frequencies. - #. The other columns are selected uniformly at random with a probability equal to: - :math:`\text{prob} = (N / \text{acceleration} - N_{\text{low freqs}}) / (N - N_{\text{low freqs}})`. - This ensures that the expected number of columns selected is equal to (N / acceleration). - It is possible to use multiple center_fractions and accelerations, in which case one possible - (center_fraction, acceleration) is chosen uniformly at random each time the MaskFunc object is - called. +class FastMRIRandomMaskFunc(FastMRIMaskFunc): + r"""Random vertical line mask function. + + The mask selects a subset of columns from the input k-space data. If the k-space data has :math:`N` columns, + the mask picks out: + + #. :math:`N_{\text{low freqs}} = (N \times \text{center_fraction})` columns in the center corresponding + to low-frequencies. + #. The other columns are selected uniformly at random with a probability equal to: + :math:`\text{prob} = (N / \text{acceleration} - N_{\text{low freqs}}) / (N - N_{\text{low freqs}})`. + This ensures that the expected number of columns selected is equal to (N / acceleration). + + It is possible to use multiple center_fractions and accelerations, in which case one possible + (center_fraction, acceleration) is chosen uniformly at random each time the MaskFunc object is + called. + + For example, if accelerations = [4, 8] and center_fractions = [0.08, 0.04], then there + is a 50% probability that 4-fold acceleration with 8% center fraction is selected and a 50% + probability that 8-fold acceleration with 4% center fraction is selected. + + """ + + def __init__( + self, + accelerations: Union[List[Number], Tuple[Number, ...]], + center_fractions: Optional[Union[List[float], Tuple[float, ...]]] = None, + uniform_range: bool = False, + ): + super().__init__( + accelerations=accelerations, + center_fractions=center_fractions, + uniform_range=uniform_range, + ) - For example, if accelerations = [4, 8] and center_fractions = [0.08, 0.04], then there - is a 50% probability that 4-fold acceleration with 8% center fraction is selected and a 50% - probability that 8-fold acceleration with 4% center fraction is selected. + def mask_func( + self, + shape: Union[List[int], Tuple[int, ...]], + return_acs: bool = False, + seed: Optional[Union[int, Iterable[int]]] = None, + ) -> torch.Tensor: + """Creates vertical line mask. Parameters ---------- - shape: iterable[int] + shape: list or tuple of ints The shape of the mask to be created. The shape should at least 3 dimensions. Samples are drawn along the second last dimension. - seed: int (optional) - Seed for the random number generator. Setting the seed ensures the same mask is generated - each time for the same shape. return_acs: bool Return the autocalibration signal region as a mask. + seed: int or iterable of ints or None (optional) + Seed for the random number generator. Setting the seed ensures the same mask is generated + each time for the same shape. Default: None. + Returns ------- @@ -160,39 +224,47 @@ def mask_func(self, shape, return_acs=False, seed=None): raise ValueError("Shape should have 3 or more dimensions") with temp_seed(self.rng, seed): - num_rows = shape[-3] num_cols = shape[-2] + center_fraction, acceleration = self.choose_acceleration() + num_low_freqs = int(round(num_cols * center_fraction)) + + mask = self.center_mask_func(num_cols, num_low_freqs) + + if return_acs: + return torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask)) # Create the mask - num_low_freqs = int(round(num_cols * center_fraction)) prob = (num_cols / acceleration - num_low_freqs) / (num_cols - num_low_freqs) - mask = self.rng.uniform(size=num_cols) < prob - pad = (num_cols - num_low_freqs + 1) // 2 - mask[pad : pad + num_low_freqs] = True + mask = mask | (self.rng.uniform(size=num_cols) < prob) - # Reshape the mask - mask_shape = [1 for _ in shape] - mask_shape[-2] = num_cols - mask = mask.reshape(*mask_shape).astype(np.int32) - mask_shape[-3] = num_rows + return torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask)) - mask = np.broadcast_to(mask, mask_shape)[np.newaxis, ...].copy() # Add coil axis, make array writable. - # TODO: Think about making this more efficient. - if return_acs: - acs_mask = np.zeros_like(mask) - acs_mask[:, :, pad : pad + num_low_freqs, ...] = 1 - return torch.from_numpy(acs_mask) +class FastMRIEquispacedMaskFunc(FastMRIMaskFunc): + r"""Equispaced vertical line mask function. - return torch.from_numpy(mask) + :class:`FastMRIEquispacedMaskFunc` creates a sub-sampling mask of given shape. The mask selects a subset of columns + from the input k-space data. If the k-space data has N columns, the mask picks out: + #. :math:`N_{\text{low freqs}} = (N \times \text{center_fraction})` columns in the center corresponding + to low-frequencies. + #. The other columns are selected with equal spacing at a proportion that reaches the desired acceleration + rate taking into consideration the number of low frequencies. This ensures that the expected number of + columns selected is equal to :math:`\frac{N}{\text{acceleration}}`. + + It is possible to use multiple center_fractions and accelerations, in which case one possible + (center_fraction, acceleration) is chosen uniformly at random each time the EquispacedMaskFunc object is called. + + Note that this function may not give equispaced samples (documented in + https://github.com/facebookresearch/fastMRI/issues/54), which will require modifications to standard GRAPPA + approaches. Nonetheless, this aspect of the function has been preserved to match the public multicoil data. + """ -class FastMRIEquispacedMaskFunc(BaseMaskFunc): def __init__( self, - accelerations: Tuple[Number, ...], - center_fractions: Optional[Tuple[float, ...]] = None, + accelerations: Union[List[Number], Tuple[Number, ...]], + center_fractions: Optional[Union[List[float], Tuple[float, ...]]] = None, uniform_range: bool = False, ): super().__init__( @@ -201,35 +273,24 @@ def __init__( uniform_range=uniform_range, ) - def mask_func(self, shape, return_acs=False, seed=None): - r"""Creates equispaced vertical line mask. - - FastMRIEquispacedMaskFunc creates a sub-sampling mask of a given shape. The mask selects a subset of columns - from the input k-space data. If the k-space data has N columns, the mask picks out: - - #. :math:`N_{\text{low freqs}} = (N \times \text{center_fraction})` columns in the center corresponding - to low-frequencies. - #. The other columns are selected with equal spacing at a proportion that reaches the desired acceleration - rate taking into consideration the number of low frequencies. This ensures that the expected number of - columns selected is equal to :math:`\frac{N}{\text{acceleration}}`. - - It is possible to use multiple center_fractions and accelerations, in which case one possible - (center_fraction, acceleration) is chosen uniformly at random each time the EquispacedMaskFunc object is called. - - Note that this function may not give equispaced samples (documented in - https://github.com/facebookresearch/fastMRI/issues/54), which will require modifications to standard GRAPPA - approaches. Nonetheless, this aspect of the function has been preserved to match the public multicoil data. + def mask_func( + self, + shape: Union[List[int], Tuple[int, ...]], + return_acs: bool = False, + seed: Optional[Union[int, Iterable[int]]] = None, + ) -> torch.Tensor: + """Creates an vertical equispaced vertical line mask. Parameters ---------- - shape: iterable[int] + shape: list or tuple of ints The shape of the mask to be created. The shape should at least 3 dimensions. Samples are drawn along the second last dimension. - seed: int (optional) - Seed for the random number generator. Setting the seed ensures the same mask is generated - each time for the same shape. return_acs: bool Return the autocalibration signal region as a mask. + seed: int or iterable of ints or None (optional) + Seed for the random number generator. Setting the seed ensures the same mask is generated + each time for the same shape. Default: None. Returns ------- @@ -240,15 +301,15 @@ def mask_func(self, shape, return_acs=False, seed=None): raise ValueError("Shape should have 3 or more dimensions") with temp_seed(self.rng, seed): - center_fraction, acceleration = self.choose_acceleration() num_cols = shape[-2] - num_rows = shape[-3] + + center_fraction, acceleration = self.choose_acceleration() num_low_freqs = int(round(num_cols * center_fraction)) - # create the mask - mask = np.zeros(num_cols, dtype=np.float32) - pad = (num_cols - num_low_freqs + 1) // 2 - mask[pad : pad + num_low_freqs] = True + mask = self.center_mask_func(num_cols, num_low_freqs) + + if return_acs: + return torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask)) # determine acceleration rate by adjusting for the number of low frequencies adjusted_accel = (acceleration * (num_low_freqs - num_cols)) / (num_low_freqs * acceleration - num_cols) @@ -258,20 +319,105 @@ def mask_func(self, shape, return_acs=False, seed=None): accel_samples = np.around(accel_samples).astype(np.uint) mask[accel_samples] = True - # Reshape the mask - mask_shape = [1 for _ in shape] - mask_shape[-2] = num_cols - mask = mask.reshape(*mask_shape).astype(np.int32) - mask_shape[-3] = num_rows + return torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask)) + + +class FastMRIMagicMaskFunc(FastMRIMaskFunc): + """Vertical line mask function as implemented in [1]_. + + :class:`FastMRIMagicMaskFunc` exploits the conjugate symmetry via offset-sampling. It is essentially an + equispaced mask with an offset for the opposite site of the k-space. Since MRI images often exhibit approximate + conjugate k-space symmetry, this mask is generally more efficient than :class:`FastMRIEquispacedMaskFunc`. + + References + ---------- + .. [1] Defazio, Aaron. “Offset Sampling Improves Deep Learning Based Accelerated MRI Reconstructions by + Exploiting Symmetry.” ArXiv:1912.01101 [Cs, Eess], Feb. 2020. arXiv.org, http://arxiv.org/abs/1912.01101. + """ + + def __init__( + self, + accelerations: Union[List[Number], Tuple[Number, ...]], + center_fractions: Optional[Union[List[float], Tuple[float, ...]]] = None, + uniform_range: bool = False, + ): + super().__init__( + accelerations=accelerations, + center_fractions=center_fractions, + uniform_range=uniform_range, + ) + + def mask_func( + self, + shape: Union[List[int], Tuple[int, ...]], + return_acs: bool = False, + seed: Optional[Union[int, Iterable[int]]] = None, + ) -> torch.Tensor: + r"""Creates a vertical equispaced mask that exploits conjugate symmetry. + + + Parameters + ---------- + shape: list or tuple of ints + The shape of the mask to be created. The shape should at least 3 dimensions. + Samples are drawn along the second last dimension. + return_acs: bool + Return the autocalibration signal region as a mask. + seed: int or iterable of ints or None (optional) + Seed for the random number generator. Setting the seed ensures the same mask is generated + each time for the same shape. Default: None. + + Returns + ------- + mask: torch.Tensor + The sampling mask. + """ + if len(shape) < 3: + raise ValueError("Shape should have 3 or more dimensions") + + with temp_seed(self.rng, seed): + num_cols = shape[-2] + + center_fraction, acceleration = self.choose_acceleration() + + num_low_freqs = int(round(num_cols * center_fraction)) + # bound the number of low frequencies between 1 and target columns + target_cols_to_sample = int(round(num_cols / acceleration)) + num_low_freqs = max(min(num_low_freqs, target_cols_to_sample), 1) - mask = np.broadcast_to(mask, mask_shape)[np.newaxis, ...].copy() # Add coil axis, make array writable. + acs_mask = self.center_mask_func(num_cols, num_low_freqs) if return_acs: - acs_mask = np.zeros_like(mask) - acs_mask[:, :, pad : pad + num_low_freqs, ...] = 1 - return torch.from_numpy(acs_mask) + return torch.from_numpy(self._reshape_and_broadcast_mask(shape, acs_mask)) + + # adjust acceleration rate based on target acceleration. + adjusted_target_cols_to_sample = target_cols_to_sample - num_low_freqs + adjusted_acceleration = 0 + if adjusted_target_cols_to_sample > 0: + adjusted_acceleration = int(round(num_cols / adjusted_target_cols_to_sample)) + + offset = self.rng.randint(0, high=adjusted_acceleration) + + if offset % 2 == 0: + offset_pos = offset + 1 + offset_neg = offset + 2 + else: + offset_pos = offset - 1 + 3 + offset_neg = offset - 1 + 0 + + poslen = (num_cols + 1) // 2 + neglen = num_cols - (num_cols + 1) // 2 + mask_positive = np.zeros(poslen, dtype=bool) + mask_negative = np.zeros(neglen, dtype=bool) + + mask_positive[offset_pos::adjusted_acceleration] = True + mask_negative[offset_neg::adjusted_acceleration] = True + mask_negative = np.flip(mask_negative) + + mask = np.fft.fftshift(np.concatenate((mask_positive, mask_negative))) + mask = mask | acs_mask - return torch.from_numpy(mask) + return torch.from_numpy(self._reshape_and_broadcast_mask(shape, mask)) class CalgaryCampinasMaskFunc(BaseMaskFunc): @@ -286,7 +432,8 @@ class CalgaryCampinasMaskFunc(BaseMaskFunc): } # TODO: Configuration improvements, so no **kwargs needed. - def __init__(self, accelerations: Tuple[int, ...], **kwargs): # noqa + # pylint: disable=unused-argument + def __init__(self, accelerations: Union[List[Number], Tuple[Number, ...]], **kwargs): # noqa super().__init__(accelerations=accelerations, uniform_range=False) if not all(_ in [5, 10] for _ in accelerations): @@ -306,22 +453,26 @@ def circular_centered_mask(shape, radius): mask = ((dist_from_center <= radius) * np.ones(shape)).astype(bool) return mask[np.newaxis, ..., np.newaxis] - def mask_func(self, shape, return_acs=False, seed=None): + def mask_func( + self, + shape: Union[List[int], Tuple[int, ...]], + return_acs: bool = False, + seed: Optional[Union[int, Iterable[int]]] = None, + ) -> torch.Tensor: r"""Downloads and loads pre-computed Poisson masks. Currently supports shapes of :math`218 \times 170/174/180` and acceleration factors of `5` or `10`. Parameters ---------- - - shape: iterable[int] + shape: list or tuple of ints The shape of the mask to be created. The shape should at least 3 dimensions. Samples are drawn along the second last dimension. - seed: int (optional) - Seed for the random number generator. Setting the seed ensures the same mask is generated - each time for the same shape. return_acs: bool Return the autocalibration signal region as a mask. + seed: int or iterable of ints or None (optional) + Seed for the random number generator. Setting the seed ensures the same mask is generated + each time for the same shape. Default: None. Returns ------- @@ -387,7 +538,7 @@ class CIRCUSMaskFunc(BaseMaskFunc): def __init__( self, - accelerations, + accelerations: Union[List[Number], Tuple[Number, ...]], subsampling_scheme: CIRCUSSamplingMode, **kwargs, ): @@ -527,7 +678,30 @@ def circular_centered_mask(mask, eps=0.1): return intersection radius += eps - def mask_func(self, shape, return_acs=False, seed=None): + def mask_func( + self, + shape: Union[List[int], Tuple[int, ...]], + return_acs: bool = False, + seed: Optional[Union[int, Iterable[int]]] = None, + ) -> torch.Tensor: + """Produces :class:`CIRCUSMaskFunc` sampling masks. + + Parameters + ---------- + shape: list or tuple of ints + The shape of the mask to be created. The shape should at least 3 dimensions. + Samples are drawn along the second last dimension. + return_acs: bool + Return the autocalibration signal region as a mask. + seed: int or iterable of ints or None (optional) + Seed for the random number generator. Setting the seed ensures the same mask is generated + each time for the same shape. Default: None. + + Returns + ------- + mask: torch.Tensor + The sampling mask. + """ if len(shape) < 3: raise ValueError("Shape should have 3 or more dimensions") @@ -559,7 +733,7 @@ class RadialMaskFunc(CIRCUSMaskFunc): def __init__( self, - accelerations, + accelerations: Union[List[Number], Tuple[Number, ...]], **kwargs, ): super().__init__( @@ -574,7 +748,7 @@ class SpiralMaskFunc(CIRCUSMaskFunc): def __init__( self, - accelerations, + accelerations: Union[List[Number], Tuple[Number, ...]], **kwargs, ): super().__init__( @@ -584,6 +758,202 @@ def __init__( ) +class VariableDensityPoissonMaskFunc(BaseMaskFunc): + """Variable Density Poisson sampling mask function. Based on [1]_. + + Notes + ----- + + * Code inspired and modified from [2]_ with BSD-3 licence, Copyright (c) 2016, Frank Ong, Copyright (c) 2016, + The Regents of the University of California [3]_. + + References + ---------- + + .. [1] Bridson, Robert. “Fast Poisson Disk Sampling in Arbitrary Dimensions.” ACM SIGGRAPH 2007 + Sketches on - SIGGRAPH ’07, ACM Press, 2007, pp. 22-es. DOI.org (Crossref), + https://doi.org/10.1145/1278780.1278807. + .. [2] https://github.com/mikgroup/sigpy/blob/1817ff849d34d7cbbbcb503a1b310e7d8f95c242/sigpy/mri/samp.py#L11 + .. [3] https://github.com/mikgroup/sigpy/blob/master/LICENSE + + """ + + def __init__( + self, + accelerations: Union[List[Number], Tuple[Number, ...]], + center_scales: Union[List[float], Tuple[float, ...]], + crop_corner: Optional[bool] = False, + max_attempts: Optional[int] = 10, + tol: Optional[float] = 0.2, + slopes: Optional[Union[List[float], Tuple[float, ...]]] = None, + ): + """Inits :class:`VariableDensityPoissonMaskFunc`. + + Parameters + ---------- + accelerations: list or tuple of positive numbers + Amount of under-sampling. + center_scales: list or tuple of floats + Must have the same lenght as `accelerations`. Amount of center fully-sampling. + For center_scale='r', then a centered disk area with radius equal to + :math:`R = \sqrt{{n_r}^2 + {n_c}^2} \times r` will be fully sampled, where :math:`n_r` and :math:`n_c` + denote the input shape. + crop_corner: bool, optional + If True mask will be disk. Default: False. + max_attempts: int, optional + Maximum rejection samples. Default: 10. + tol: float, optional + Maximum deviation between the generated mask acceleration and the desired acceleration. Default: 0.2. + slopes: Optional[Union[List[float], Tuple[float, ...]]] + An increasing sequence of non-negative floats (of length 2) to be used + for the generation of the sampling radius. Default: None. + """ + super().__init__( + accelerations=accelerations, + center_fractions=center_scales, + uniform_range=False, + ) + self.crop_corner = crop_corner + self.max_attempts = max_attempts + self.tol = tol + if slopes is not None: + assert ( + slopes[0] >= 0 and slopes[0] < slopes[1] and len(slopes) == 2 + ), f"`slopes` must be an increasing sequence of two non-negative floats. Received {slopes}." + self.slopes = slopes + + def mask_func( + self, + shape: Union[List[int], Tuple[int, ...]], + return_acs: bool = False, + seed: Optional[Union[int, Iterable[int]]] = None, + ) -> torch.Tensor: + """Produces variable Density Poisson sampling masks. + + Parameters + ---------- + shape: list or tuple of ints + The shape of the mask to be created. The shape should at least 3 dimensions. + Samples are drawn along the second last dimension. + return_acs: bool + Return the autocalibration signal region as a mask. + seed: int or iterable of ints or None (optional) + Seed for the random number generator. Setting the seed ensures the same mask is generated + each time for the same shape. Default: None. + + Returns + ------- + mask: torch.Tensor + The sampling mask of shape (1, shape[0], shape[1], 1). + """ + if len(shape) < 3: + raise ValueError("Shape should have 3 or more dimensions") + num_rows, num_cols = shape[:2] + + with temp_seed(self.rng, seed): + center_fraction, acceleration = self.choose_acceleration() + if seed is None: + # cython requires specific seed type so it cannot be None + cython_seed = 0 + elif isinstance(seed, (tuple, list)): + # cython `srand` method takes only integers + cython_seed = int(np.mean(seed)) + elif isinstance(seed, int): + cython_seed = seed + + if return_acs: + return torch.from_numpy( + self.centered_disk_mask((num_rows, num_cols), center_fraction)[np.newaxis, ..., np.newaxis] + ) + mask = self.poisson(num_rows, num_cols, center_fraction, acceleration, cython_seed) + return torch.from_numpy(mask[np.newaxis, ..., np.newaxis]) + + def poisson( + self, + num_rows: int, + num_cols: int, + center_fraction: float, + acceleration: float, + seed: int = 0, + ) -> torch.Tensor: + """Calculates mask by calling the cython `_poisson` method. + + Parameters + ---------- + num_rows: int + Number of rows - x-axis size. + num_cols: int + Number of columns - y-axis size. + center_fraction: float + Amount of center fully-sampling. + acceleration: float + Acceleration factor. + seed: int + Seed to be used by cython function. Default: 0. + + Returns + ------- + mask: torch.Tensor + Sampling mask of shape (`num_rows`, `num_cols`). + """ + # pylint: disable=too-many-locals + x, y = np.mgrid[:num_rows, :num_cols] + + x = np.maximum(abs(x - num_rows / 2), 0) + x /= x.max() + y = np.maximum(abs(y - num_cols / 2), 0) + y /= y.max() + r = np.sqrt(x**2 + y**2) + + if self.slopes is not None: + slope_min, slope_max = self.slopes + else: + slope_min, slope_max = 0, max(num_rows, num_cols) + + while slope_min < slope_max: + slope = (slope_max + slope_min) / 2 + radius_x = np.clip((1 + r * slope) * num_rows / max(num_rows, num_cols), 1, None) + + radius_y = np.clip((1 + r * slope) * num_cols / max(num_rows, num_cols), 1, None) + + mask = np.zeros((num_rows, num_cols), dtype=int) + + _poisson(num_rows, num_cols, self.max_attempts, mask, radius_x, radius_y, seed) + + mask = mask | self.centered_disk_mask((num_rows, num_cols), center_fraction) + + if self.crop_corner: + mask *= r < 1 + + actual_acceleration = num_rows * num_cols / mask.sum() + + if abs(actual_acceleration - acceleration) < self.tol: + break + if actual_acceleration < acceleration: + slope_min = slope + else: + slope_max = slope + + if abs(actual_acceleration - acceleration) >= self.tol: + raise ValueError(f"Cannot generate mask to satisfy accel={acceleration}.") + + return mask + + @staticmethod + def centered_disk_mask(shape, center_scale): + center_x = shape[0] // 2 + center_y = shape[1] // 2 + + X, Y = np.indices(shape) + + # r = sqrt( center_scale * H * W / pi) + radius = int(np.sqrt(np.prod(shape) * center_scale / np.pi)) + + mask = ((X - center_x) ** 2 + (Y - center_y) ** 2) < radius**2 + + return mask.astype(int) + + class DictionaryMaskFunc(BaseMaskFunc): def __init__(self, data_dictionary, **kwargs): # noqa super().__init__(accelerations=None) diff --git a/direct/nn/crossdomain/crossdomain.py b/direct/nn/crossdomain/crossdomain.py index 3123cc9f..39fc4259 100644 --- a/direct/nn/crossdomain/crossdomain.py +++ b/direct/nn/crossdomain/crossdomain.py @@ -148,7 +148,7 @@ def _backward_operator( sampling_mask == 0, torch.tensor([0.0], dtype=kspace.dtype).to(kspace.device), kspace, - ), + ).contiguous(), self._spatial_dims, ), sensitivity_map, diff --git a/direct/nn/kikinet/kikinet.py b/direct/nn/kikinet/kikinet.py index 8a54bea6..ab985e40 100644 --- a/direct/nn/kikinet/kikinet.py +++ b/direct/nn/kikinet/kikinet.py @@ -160,7 +160,7 @@ def forward( sampling_mask == 0, torch.tensor([0.0], dtype=kspace.dtype).to(kspace.device), kspace, - ), + ).contiguous(), self._spatial_dims, ), sensitivity_map, diff --git a/direct/nn/lpd/lpd.py b/direct/nn/lpd/lpd.py index 9760be80..b41cb460 100644 --- a/direct/nn/lpd/lpd.py +++ b/direct/nn/lpd/lpd.py @@ -242,7 +242,7 @@ def _backward_operator( sampling_mask == 0, torch.tensor([0.0], dtype=kspace.dtype).to(kspace.device), kspace, - ), + ).contiguous(), self._spatial_dims, ), sensitivity_map, diff --git a/installation.rst b/installation.rst index f18eec7f..603b6e03 100644 --- a/installation.rst +++ b/installation.rst @@ -40,7 +40,7 @@ Install using ``conda`` .. code-block:: - pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113 + pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116 **otherwise**\ , install the latest PyTorch CPU version (not recommended): @@ -55,6 +55,12 @@ Install using ``conda`` python3 setup.py install + or + + .. code-block:: + + python3 -m pip install -e ".[dev]" + This will install ``direct`` as a python module. Common Installation Issues diff --git a/setup.py b/setup.py index 18b1275a..d36ab996 100644 --- a/setup.py +++ b/setup.py @@ -3,8 +3,25 @@ """The setup script.""" import ast +import pathlib + +from setuptools import Extension, find_packages, setup # type: ignore +from setuptools.command.build_ext import build_ext + + +class _build_ext(build_ext): + def run(self): + import numpy as np + + self.include_dirs.append(np.get_include()) + super().run() + + def finalize_options(self): + from Cython.Build import cythonize + + self.distribution.ext_modules = cythonize(self.distribution.ext_modules) + super().finalize_options() -from setuptools import find_packages, setup # type: ignore with open("direct/__init__.py") as f: for line in f: @@ -35,6 +52,7 @@ "direct=direct.cli:main", ], }, + setup_requires=["numpy", "cython"], install_requires=[ "numpy>=1.21.2", "h5py==3.3.0", @@ -71,4 +89,8 @@ url="https://github.com/NKI-AI/direct", version=version, zip_safe=False, + cmdclass={"build_ext": _build_ext}, + ext_modules=[ + Extension("direct.common._poisson", sources=[str(pathlib.Path(".") / "direct" / "common" / "_poisson.pyx")]) + ], ) diff --git a/tests/tests_common/test_subsample.py b/tests/tests_common/test_subsample.py index 7dfb1f48..8fae1036 100644 --- a/tests/tests_common/test_subsample.py +++ b/tests/tests_common/test_subsample.py @@ -11,9 +11,20 @@ import pytest import torch -from direct.common.subsample import FastMRIRandomMaskFunc, RadialMaskFunc, SpiralMaskFunc +from direct.common.subsample import ( + FastMRIEquispacedMaskFunc, + FastMRIMagicMaskFunc, + FastMRIRandomMaskFunc, + RadialMaskFunc, + SpiralMaskFunc, + VariableDensityPoissonMaskFunc, +) +@pytest.mark.parametrize( + "mask_func", + [FastMRIRandomMaskFunc, FastMRIEquispacedMaskFunc, FastMRIMagicMaskFunc], +) @pytest.mark.parametrize( "center_fracs, accelerations, batch_size, dim", [ @@ -21,8 +32,8 @@ ([0.2, 0.4], [4, 8], 2, 368), ], ) -def test_fastmri_random_mask_reuse(center_fracs, accelerations, batch_size, dim): - mask_func = FastMRIRandomMaskFunc(center_fracs, accelerations) +def test_fastmri_mask_reuse(mask_func, center_fracs, accelerations, batch_size, dim): + mask_func = mask_func(center_fractions=center_fracs, accelerations=accelerations) shape = (batch_size, dim, dim, 2) mask1 = mask_func(shape, seed=123) mask2 = mask_func(shape, seed=123) @@ -31,6 +42,10 @@ def test_fastmri_random_mask_reuse(center_fracs, accelerations, batch_size, dim) assert torch.all(mask2 == mask3) +@pytest.mark.parametrize( + "mask_func", + [FastMRIRandomMaskFunc, FastMRIEquispacedMaskFunc, FastMRIMagicMaskFunc], +) @pytest.mark.parametrize( "center_fracs, accelerations, batch_size, dim", [ @@ -38,8 +53,8 @@ def test_fastmri_random_mask_reuse(center_fracs, accelerations, batch_size, dim) ([0.2, 0.4], [4, 8], 2, 368), ], ) -def test_fastmri_random_mask_low_freqs(center_fracs, accelerations, batch_size, dim): - mask_func = FastMRIRandomMaskFunc(center_fracs, accelerations) +def test_fastmri_mask_low_freqs(mask_func, center_fracs, accelerations, batch_size, dim): + mask_func = mask_func(center_fractions=center_fracs, accelerations=accelerations) shape = (batch_size, dim, dim, 2) mask = mask_func(shape, seed=123) mask_shape = [1] * (len(shape) + 1) @@ -57,6 +72,10 @@ def test_fastmri_random_mask_low_freqs(center_fracs, accelerations, batch_size, assert num_low_freqs_matched +@pytest.mark.parametrize( + "mask_func", + [FastMRIRandomMaskFunc, FastMRIEquispacedMaskFunc, FastMRIMagicMaskFunc], +) @pytest.mark.parametrize( "shape, center_fractions, accelerations", [ @@ -64,12 +83,8 @@ def test_fastmri_random_mask_low_freqs(center_fracs, accelerations, batch_size, ([2, 64, 64, 2], [0.04, 0.08], [8, 4]), ], ) -def test_apply_mask_fastmri(shape, center_fractions, accelerations): - mask_func = FastMRIRandomMaskFunc( - center_fractions=center_fractions, - accelerations=accelerations, - uniform_range=False, - ) +def test_apply_mask_fastmri(mask_func, shape, center_fractions, accelerations): + mask_func = mask_func(center_fractions=center_fractions, accelerations=accelerations) mask = mask_func(shape[1:], seed=123) acs_mask = mask_func(shape[1:], seed=123, return_acs=True) expected_mask_shape = (1, shape[1], shape[2], 1) @@ -80,6 +95,10 @@ def test_apply_mask_fastmri(shape, center_fractions, accelerations): assert np.allclose(mask & acs_mask, acs_mask) +@pytest.mark.parametrize( + "mask_func", + [FastMRIRandomMaskFunc, FastMRIEquispacedMaskFunc, FastMRIMagicMaskFunc], +) @pytest.mark.parametrize( "shape, center_fractions, accelerations", [ @@ -87,12 +106,8 @@ def test_apply_mask_fastmri(shape, center_fractions, accelerations): ([2, 64, 64, 2], [0.04, 0.08], [8, 4]), ], ) -def test_same_across_volumes_mask_fastmri(shape, center_fractions, accelerations): - mask_func = FastMRIRandomMaskFunc( - center_fractions=center_fractions, - accelerations=accelerations, - uniform_range=False, - ) +def test_same_across_volumes_mask_fastmri(mask_func, shape, center_fractions, accelerations): + mask_func = mask_func(center_fractions=center_fractions, accelerations=accelerations) num_slices = shape[0] masks = [mask_func(shape[1:], seed=123) for _ in range(num_slices)] @@ -173,3 +188,55 @@ def test_same_across_volumes_mask_spiral(shape, accelerations): masks = [mask_func(shape[1:], seed=123) for _ in range(num_slices)] assert all(np.allclose(masks[_], masks[_ + 1]) for _ in range(num_slices - 1)) + + +@pytest.mark.parametrize( + "shape, accelerations, center_scales", + [ + ([4, 32, 32, 2], [4], [0.08]), + ([2, 64, 64, 2], [8, 4], [0.04, 0.08]), + ], +) +@pytest.mark.parametrize( + "seed", + [ + None, + 10, + 100, + 1000, + np.random.randint(0, 10000), + list(np.random.randint(0, 10000, 20)), + tuple(np.random.randint(100000, 1000000, 30)), + ], +) +def test_apply_mask_poisson(shape, accelerations, center_scales, seed): + mask_func = VariableDensityPoissonMaskFunc( + accelerations=accelerations, + center_scales=center_scales, + ) + mask = mask_func(shape[1:], seed=seed) + acs_mask = mask_func(shape[1:], seed=seed, return_acs=True) + expected_mask_shape = (1, shape[1], shape[2], 1) + assert mask.max() == 1 + assert mask.min() == 0 + assert mask.shape == expected_mask_shape + if seed is not None: + assert np.allclose(mask & acs_mask, acs_mask) + + +@pytest.mark.parametrize( + "shape, accelerations, center_scales", + [ + ([4, 32, 32, 2], [4], [0.08]), + ([2, 64, 64, 2], [8, 4], [0.04, 0.08]), + ], +) +def test_same_across_volumes_mask_spiral(shape, accelerations, center_scales): + mask_func = VariableDensityPoissonMaskFunc( + accelerations=accelerations, + center_scales=center_scales, + ) + num_slices = shape[0] + masks = [mask_func(shape[1:], seed=123) for _ in range(num_slices)] + + assert all(np.allclose(masks[_], masks[_ + 1]) for _ in range(num_slices - 1))