Skip to content

Commit

Permalink
Make tests work wth enums
Browse files Browse the repository at this point in the history
  • Loading branch information
georgeyiasemis committed Jun 5, 2024
1 parent c669be5 commit 1cbbb45
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 57 deletions.
36 changes: 18 additions & 18 deletions direct/common/subsample.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def __call__(self, shape: tuple[int, ...], *args, **kwargs) -> torch.Tensor:
"""
if len(shape) < 3:
raise ValueError("Shape should have 3 or more dimensions.")
if self.mode != MaskFuncMode.STATIC and len(shape) < 4:
if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE] and len(shape) < 4:
raise ValueError("Shape should have 4 or more dimensions for dynamic or multislice mode.")

mask = self.mask_func(shape, *args, **kwargs)
Expand Down Expand Up @@ -310,7 +310,7 @@ def _reshape_and_broadcast_mask(self, shape: tuple[int, ...], mask: np.ndarray)
# Reshape the mask
mask_shape = [1 for _ in shape]
mask_shape[-2] = num_cols
if self.mode != MaskFuncMode.STATIC:
if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE]:
mask_shape[-4] = shape[-4]
mask = mask.reshape(*mask_shape).astype(bool)
mask_shape[-3] = num_rows
Expand Down Expand Up @@ -419,7 +419,7 @@ def mask_func(
The sampling mask.
"""
num_cols = shape[-2]
num_slc_or_time = shape[-4] if self.mode != MaskFuncMode.STATIC else 1
num_slc_or_time = shape[-4] if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE] else 1

with temp_seed(self.rng, seed):

Expand All @@ -432,7 +432,7 @@ def mask_func(

mask = self.center_mask_func(num_cols, num_low_freqs)

if self.mode != MaskFuncMode.STATIC:
if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE]:
mask = mask[np.newaxis].repeat(num_slc_or_time, axis=0)

if return_acs:
Expand Down Expand Up @@ -681,7 +681,7 @@ def mask_func(
The sampling mask.
"""
num_cols = shape[-2]
num_slc_or_time = shape[-4] if self.mode != MaskFuncMode.STATIC else 1
num_slc_or_time = shape[-4] if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE] else 1

with temp_seed(self.rng, seed):

Expand All @@ -694,7 +694,7 @@ def mask_func(

mask = self.center_mask_func(num_cols, num_low_freqs)

if self.mode != MaskFuncMode.STATIC:
if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE]:
mask = mask[np.newaxis].repeat(num_slc_or_time, axis=0)

if return_acs:
Expand Down Expand Up @@ -940,7 +940,7 @@ def mask_func(
The sampling mask.
"""
num_cols = shape[-2]
num_slc_or_time = shape[-4] if self.mode != MaskFuncMode.STATIC else 1
num_slc_or_time = shape[-4] if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE] else 1

with temp_seed(self.rng, seed):

Expand All @@ -960,7 +960,7 @@ def mask_func(

acs_mask = self.center_mask_func(num_cols, num_low_freqs)

if self.mode != MaskFuncMode.STATIC:
if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE]:
acs_mask = acs_mask[np.newaxis].repeat(num_slc_or_time, axis=0)

if return_acs:
Expand Down Expand Up @@ -1590,7 +1590,7 @@ def mask_func(
num_rows = shape[-3]
num_cols = shape[-2]

num_slc_or_time = shape[-4] if self.mode != MaskFuncMode.STATIC else 1
num_slc_or_time = shape[-4] if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE] else 1

with temp_seed(self.rng, seed):
center_fraction, acceleration = self.choose_acceleration()
Expand All @@ -1601,7 +1601,7 @@ def mask_func(
num_low_freqs * acceleration - num_rows * num_cols
)

if self.mode != MaskFuncMode.STATIC:
if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE]:
acs_mask = acs_mask[np.newaxis].repeat(num_slc_or_time, axis=0)

acs_mask = torch.from_numpy(reshape_array_to_shape(acs_mask, shape)[np.newaxis]).bool()
Expand Down Expand Up @@ -1862,7 +1862,7 @@ def mask_func(
The sampling mask of shape (1, shape[0], shape[1], 1).
"""
num_rows, num_cols = shape[-3:-1]
num_slc_or_time = shape[-4] if self.mode != MaskFuncMode.STATIC else 1
num_slc_or_time = shape[-4] if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE] else 1

with temp_seed(self.rng, seed):
self.rng.seed(integerize_seed(seed))
Expand All @@ -1871,7 +1871,7 @@ def mask_func(

if return_acs:
acs_mask = centered_disk_mask((num_rows, num_cols), center_fraction)
if self.mode != MaskFuncMode.STATIC:
if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE]:
acs_mask = acs_mask[np.newaxis].repeat(num_slc_or_time, axis=0)
return torch.from_numpy(reshape_array_to_shape(acs_mask, shape)[np.newaxis]).bool()

Expand Down Expand Up @@ -2036,7 +2036,7 @@ def mask_func(
"""

num_cols = shape[-2]
num_slc_or_time = shape[-4] if self.mode != MaskFuncMode.STATIC else 1
num_slc_or_time = shape[-4] if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE] else 1

with temp_seed(self.rng, seed):
self.rng.seed(integerize_seed(seed))
Expand All @@ -2046,7 +2046,7 @@ def mask_func(

mask = self.center_mask_func(num_cols, num_low_freqs).astype(int)

if self.mode != MaskFuncMode.STATIC:
if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE]:
mask = mask[np.newaxis].repeat(num_slc_or_time, axis=0)

if return_acs:
Expand All @@ -2055,7 +2055,7 @@ def mask_func(
# Calls cython function
nonzero_count = int(np.round(num_cols / acceleration - num_low_freqs - 1))

if self.mode != MaskFuncMode.STATIC:
if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE]:
for i in range(num_slc_or_time):
gaussian_mask_1d(
nonzero_count,
Expand Down Expand Up @@ -2154,7 +2154,7 @@ def mask_func(
The sampling mask.
"""
num_rows, num_cols = shape[-3:-1]
num_slc_or_time = shape[-4] if self.mode != MaskFuncMode.STATIC else 1
num_slc_or_time = shape[-4] if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE] else 1

with temp_seed(self.rng, seed):
self.rng.seed(integerize_seed(seed))
Expand All @@ -2163,15 +2163,15 @@ def mask_func(

mask = centered_disk_mask((num_rows, num_cols), center_fraction)

if self.mode != MaskFuncMode.STATIC:
if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE]:
mask = mask[np.newaxis].repeat(num_slc_or_time, axis=0)

if return_acs:
return torch.from_numpy(reshape_array_to_shape(mask, shape)[np.newaxis]).bool()

std = 6 * np.array([np.sqrt(num_rows // 2), np.sqrt(num_cols // 2)], dtype=float)

if self.mode != MaskFuncMode.STATIC:
if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE]:
for i in range(num_slc_or_time):
# Calls cython function
gaussian_mask_2d(
Expand Down
10 changes: 5 additions & 5 deletions direct/common/subsample_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,16 @@
from omegaconf import MISSING

from direct.config.defaults import BaseConfig
from direct.types import MaskFuncMode, Number
from direct.types import MaskFuncMode


@dataclass
class MaskingConfig(BaseConfig):
name: str = MISSING
accelerations: tuple[Number, ...] = (5.0,)
center_fractions: Optional[tuple[Number, ...]] = (0.1,)
accelerations: tuple[float, ...] = (5.0,)
center_fractions: Optional[tuple[float, ...]] = (0.1,)
uniform_range: bool = False
mode: MaskFuncMode = MaskFuncMode.STATIC

val_accelerations: tuple[Number, ...] = (5.0, 10.0)
val_center_fractions: Optional[tuple[Number, ...]] = (0.1, 0.05)
val_accelerations: tuple[float, ...] = (5.0, 10.0)
val_center_fractions: Optional[tuple[float, ...]] = (0.1, 0.05)
58 changes: 25 additions & 33 deletions tests/tests_common/test_subsample.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,14 @@ def test_mask_reuse(mask_func, center_fracs, accelerations, batch_size, dim):
],
)
@pytest.mark.parametrize(
"accelerations, batch_size, dim",
"center_fracs, accelerations, batch_size, dim",
[
([4], 4, 320),
([4, 8], 2, 368),
([0.2], [4], 4, 320),
([0.2, 0.4], [4, 8], 2, 368),
],
)
def test_mask_reuse_circus(mask_func, accelerations, batch_size, dim):
mask_func = mask_func(accelerations=accelerations)
def test_mask_reuse_circus(mask_func, center_fracs, accelerations, batch_size, dim):
mask_func = mask_func(accelerations=accelerations, center_fractions=center_fracs)
shape = (batch_size, dim, dim, 2)
mask1 = mask_func(shape, seed=123)
mask2 = mask_func(shape, seed=123)
Expand Down Expand Up @@ -232,16 +232,14 @@ def test_same_across_volumes_mask_calgary_campinas(shape, accelerations):


@pytest.mark.parametrize(
"shape, accelerations",
"shape, accelerations, center_fractions",
[
([4, 32, 32, 2], [4]),
([2, 64, 64, 2], [8, 4]),
([4, 32, 32, 2], [4], [0.08]),
([2, 64, 64, 2], [8, 4], [0.04, 0.08]),
],
)
def test_apply_mask_radial(shape, accelerations):
mask_func = RadialMaskFunc(
accelerations=accelerations,
)
def test_apply_mask_radial(shape, accelerations, center_fractions):
mask_func = RadialMaskFunc(accelerations=accelerations, center_fractions=center_fractions)
mask = mask_func(shape[1:], seed=123)
acs_mask = mask_func(shape[1:], seed=123, return_acs=True)
expected_mask_shape = (1, shape[1], shape[2], 1)
Expand All @@ -253,33 +251,29 @@ def test_apply_mask_radial(shape, accelerations):


@pytest.mark.parametrize(
"shape, accelerations",
"shape, accelerations, center_fractions",
[
([4, 32, 32, 2], [4]),
([2, 64, 64, 2], [8, 4]),
([4, 32, 32, 2], [4], [0.08]),
([2, 64, 64, 2], [8, 4], [0.04, 0.08]),
],
)
def test_same_across_volumes_mask_radial(shape, accelerations):
mask_func = RadialMaskFunc(
accelerations=accelerations,
)
def test_same_across_volumes_mask_radial(shape, accelerations, center_fractions):
mask_func = RadialMaskFunc(accelerations=accelerations, center_fractions=center_fractions)
num_slices = shape[0]
masks = [mask_func(shape[1:], seed=123) for _ in range(num_slices)]

assert all(np.allclose(masks[_], masks[_ + 1]) for _ in range(num_slices - 1))


@pytest.mark.parametrize(
"shape, accelerations",
"shape, accelerations, center_fractions",
[
([4, 32, 32, 2], [4]),
([2, 64, 64, 2], [8, 4]),
([4, 32, 32, 2], [4], [0.08]),
([2, 64, 64, 2], [8, 4], [0.04, 0.08]),
],
)
def test_apply_mask_spiral(shape, accelerations):
mask_func = SpiralMaskFunc(
accelerations=accelerations,
)
def test_apply_mask_spiral(shape, accelerations, center_fractions):
mask_func = SpiralMaskFunc(accelerations=accelerations, center_fractions=center_fractions)
mask = mask_func(shape[1:], seed=123)
acs_mask = mask_func(shape[1:], seed=123, return_acs=True)
expected_mask_shape = (1, shape[1], shape[2], 1)
Expand All @@ -291,16 +285,14 @@ def test_apply_mask_spiral(shape, accelerations):


@pytest.mark.parametrize(
"shape, accelerations",
"shape, accelerations, center_fractions",
[
([4, 32, 32, 2], [4]),
([2, 64, 64, 2], [8, 4]),
([4, 32, 32, 2], [4], [0.08]),
([2, 64, 64, 2], [8, 4], [0.04, 0.08]),
],
)
def test_same_across_volumes_mask_spiral(shape, accelerations):
mask_func = SpiralMaskFunc(
accelerations=accelerations,
)
def test_same_across_volumes_mask_spiral(shape, accelerations, center_fractions):
mask_func = SpiralMaskFunc(accelerations=accelerations, center_fractions=center_fractions)
num_slices = shape[0]
masks = [mask_func(shape[1:], seed=123) for _ in range(num_slices)]

Expand Down
2 changes: 1 addition & 1 deletion tests/tests_data/test_mri_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,4 +718,4 @@ def test_build_mri_transforms(shape, spatial_dims, estimate_body_coil_image, ima
mask_shape = torch.ones(len(shape) + 1).int().tolist()
mask_shape[-3] = shape[-2]
mask_shape[-2] = shape[-1]
assert list(sample["sampling_mask"].shape) == mask_shape
assert list(sample["sampling_mask"].shape) == mask_shape

0 comments on commit 1cbbb45

Please sign in to comment.