diff --git a/direct/common/subsample.py b/direct/common/subsample.py index 632c664c..e6e2e439 100644 --- a/direct/common/subsample.py +++ b/direct/common/subsample.py @@ -203,7 +203,7 @@ def __call__(self, shape: tuple[int, ...], *args, **kwargs) -> torch.Tensor: """ if len(shape) < 3: raise ValueError("Shape should have 3 or more dimensions.") - if self.mode != MaskFuncMode.STATIC and len(shape) < 4: + if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE] and len(shape) < 4: raise ValueError("Shape should have 4 or more dimensions for dynamic or multislice mode.") mask = self.mask_func(shape, *args, **kwargs) @@ -310,7 +310,7 @@ def _reshape_and_broadcast_mask(self, shape: tuple[int, ...], mask: np.ndarray) # Reshape the mask mask_shape = [1 for _ in shape] mask_shape[-2] = num_cols - if self.mode != MaskFuncMode.STATIC: + if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE]: mask_shape[-4] = shape[-4] mask = mask.reshape(*mask_shape).astype(bool) mask_shape[-3] = num_rows @@ -419,7 +419,7 @@ def mask_func( The sampling mask. """ num_cols = shape[-2] - num_slc_or_time = shape[-4] if self.mode != MaskFuncMode.STATIC else 1 + num_slc_or_time = shape[-4] if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE] else 1 with temp_seed(self.rng, seed): @@ -432,7 +432,7 @@ def mask_func( mask = self.center_mask_func(num_cols, num_low_freqs) - if self.mode != MaskFuncMode.STATIC: + if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE]: mask = mask[np.newaxis].repeat(num_slc_or_time, axis=0) if return_acs: @@ -681,7 +681,7 @@ def mask_func( The sampling mask. """ num_cols = shape[-2] - num_slc_or_time = shape[-4] if self.mode != MaskFuncMode.STATIC else 1 + num_slc_or_time = shape[-4] if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE] else 1 with temp_seed(self.rng, seed): @@ -694,7 +694,7 @@ def mask_func( mask = self.center_mask_func(num_cols, num_low_freqs) - if self.mode != MaskFuncMode.STATIC: + if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE]: mask = mask[np.newaxis].repeat(num_slc_or_time, axis=0) if return_acs: @@ -940,7 +940,7 @@ def mask_func( The sampling mask. """ num_cols = shape[-2] - num_slc_or_time = shape[-4] if self.mode != MaskFuncMode.STATIC else 1 + num_slc_or_time = shape[-4] if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE] else 1 with temp_seed(self.rng, seed): @@ -960,7 +960,7 @@ def mask_func( acs_mask = self.center_mask_func(num_cols, num_low_freqs) - if self.mode != MaskFuncMode.STATIC: + if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE]: acs_mask = acs_mask[np.newaxis].repeat(num_slc_or_time, axis=0) if return_acs: @@ -1590,7 +1590,7 @@ def mask_func( num_rows = shape[-3] num_cols = shape[-2] - num_slc_or_time = shape[-4] if self.mode != MaskFuncMode.STATIC else 1 + num_slc_or_time = shape[-4] if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE] else 1 with temp_seed(self.rng, seed): center_fraction, acceleration = self.choose_acceleration() @@ -1601,7 +1601,7 @@ def mask_func( num_low_freqs * acceleration - num_rows * num_cols ) - if self.mode != MaskFuncMode.STATIC: + if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE]: acs_mask = acs_mask[np.newaxis].repeat(num_slc_or_time, axis=0) acs_mask = torch.from_numpy(reshape_array_to_shape(acs_mask, shape)[np.newaxis]).bool() @@ -1862,7 +1862,7 @@ def mask_func( The sampling mask of shape (1, shape[0], shape[1], 1). """ num_rows, num_cols = shape[-3:-1] - num_slc_or_time = shape[-4] if self.mode != MaskFuncMode.STATIC else 1 + num_slc_or_time = shape[-4] if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE] else 1 with temp_seed(self.rng, seed): self.rng.seed(integerize_seed(seed)) @@ -1871,7 +1871,7 @@ def mask_func( if return_acs: acs_mask = centered_disk_mask((num_rows, num_cols), center_fraction) - if self.mode != MaskFuncMode.STATIC: + if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE]: acs_mask = acs_mask[np.newaxis].repeat(num_slc_or_time, axis=0) return torch.from_numpy(reshape_array_to_shape(acs_mask, shape)[np.newaxis]).bool() @@ -2036,7 +2036,7 @@ def mask_func( """ num_cols = shape[-2] - num_slc_or_time = shape[-4] if self.mode != MaskFuncMode.STATIC else 1 + num_slc_or_time = shape[-4] if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE] else 1 with temp_seed(self.rng, seed): self.rng.seed(integerize_seed(seed)) @@ -2046,7 +2046,7 @@ def mask_func( mask = self.center_mask_func(num_cols, num_low_freqs).astype(int) - if self.mode != MaskFuncMode.STATIC: + if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE]: mask = mask[np.newaxis].repeat(num_slc_or_time, axis=0) if return_acs: @@ -2055,7 +2055,7 @@ def mask_func( # Calls cython function nonzero_count = int(np.round(num_cols / acceleration - num_low_freqs - 1)) - if self.mode != MaskFuncMode.STATIC: + if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE]: for i in range(num_slc_or_time): gaussian_mask_1d( nonzero_count, @@ -2154,7 +2154,7 @@ def mask_func( The sampling mask. """ num_rows, num_cols = shape[-3:-1] - num_slc_or_time = shape[-4] if self.mode != MaskFuncMode.STATIC else 1 + num_slc_or_time = shape[-4] if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE] else 1 with temp_seed(self.rng, seed): self.rng.seed(integerize_seed(seed)) @@ -2163,7 +2163,7 @@ def mask_func( mask = centered_disk_mask((num_rows, num_cols), center_fraction) - if self.mode != MaskFuncMode.STATIC: + if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE]: mask = mask[np.newaxis].repeat(num_slc_or_time, axis=0) if return_acs: @@ -2171,7 +2171,7 @@ def mask_func( std = 6 * np.array([np.sqrt(num_rows // 2), np.sqrt(num_cols // 2)], dtype=float) - if self.mode != MaskFuncMode.STATIC: + if self.mode in [MaskFuncMode.DYNAMIC, MaskFuncMode.MULTISLICE]: for i in range(num_slc_or_time): # Calls cython function gaussian_mask_2d( diff --git a/direct/common/subsample_config.py b/direct/common/subsample_config.py index ca0a1622..89cccbb3 100644 --- a/direct/common/subsample_config.py +++ b/direct/common/subsample_config.py @@ -8,16 +8,16 @@ from omegaconf import MISSING from direct.config.defaults import BaseConfig -from direct.types import MaskFuncMode, Number +from direct.types import MaskFuncMode @dataclass class MaskingConfig(BaseConfig): name: str = MISSING - accelerations: tuple[Number, ...] = (5.0,) - center_fractions: Optional[tuple[Number, ...]] = (0.1,) + accelerations: tuple[float, ...] = (5.0,) + center_fractions: Optional[tuple[float, ...]] = (0.1,) uniform_range: bool = False mode: MaskFuncMode = MaskFuncMode.STATIC - val_accelerations: tuple[Number, ...] = (5.0, 10.0) - val_center_fractions: Optional[tuple[Number, ...]] = (0.1, 0.05) + val_accelerations: tuple[float, ...] = (5.0, 10.0) + val_center_fractions: Optional[tuple[float, ...]] = (0.1, 0.05) diff --git a/tests/tests_common/test_subsample.py b/tests/tests_common/test_subsample.py index 0c0ef860..552b21d4 100644 --- a/tests/tests_common/test_subsample.py +++ b/tests/tests_common/test_subsample.py @@ -50,14 +50,14 @@ def test_mask_reuse(mask_func, center_fracs, accelerations, batch_size, dim): ], ) @pytest.mark.parametrize( - "accelerations, batch_size, dim", + "center_fracs, accelerations, batch_size, dim", [ - ([4], 4, 320), - ([4, 8], 2, 368), + ([0.2], [4], 4, 320), + ([0.2, 0.4], [4, 8], 2, 368), ], ) -def test_mask_reuse_circus(mask_func, accelerations, batch_size, dim): - mask_func = mask_func(accelerations=accelerations) +def test_mask_reuse_circus(mask_func, center_fracs, accelerations, batch_size, dim): + mask_func = mask_func(accelerations=accelerations, center_fractions=center_fracs) shape = (batch_size, dim, dim, 2) mask1 = mask_func(shape, seed=123) mask2 = mask_func(shape, seed=123) @@ -232,16 +232,14 @@ def test_same_across_volumes_mask_calgary_campinas(shape, accelerations): @pytest.mark.parametrize( - "shape, accelerations", + "shape, accelerations, center_fractions", [ - ([4, 32, 32, 2], [4]), - ([2, 64, 64, 2], [8, 4]), + ([4, 32, 32, 2], [4], [0.08]), + ([2, 64, 64, 2], [8, 4], [0.04, 0.08]), ], ) -def test_apply_mask_radial(shape, accelerations): - mask_func = RadialMaskFunc( - accelerations=accelerations, - ) +def test_apply_mask_radial(shape, accelerations, center_fractions): + mask_func = RadialMaskFunc(accelerations=accelerations, center_fractions=center_fractions) mask = mask_func(shape[1:], seed=123) acs_mask = mask_func(shape[1:], seed=123, return_acs=True) expected_mask_shape = (1, shape[1], shape[2], 1) @@ -253,16 +251,14 @@ def test_apply_mask_radial(shape, accelerations): @pytest.mark.parametrize( - "shape, accelerations", + "shape, accelerations, center_fractions", [ - ([4, 32, 32, 2], [4]), - ([2, 64, 64, 2], [8, 4]), + ([4, 32, 32, 2], [4], [0.08]), + ([2, 64, 64, 2], [8, 4], [0.04, 0.08]), ], ) -def test_same_across_volumes_mask_radial(shape, accelerations): - mask_func = RadialMaskFunc( - accelerations=accelerations, - ) +def test_same_across_volumes_mask_radial(shape, accelerations, center_fractions): + mask_func = RadialMaskFunc(accelerations=accelerations, center_fractions=center_fractions) num_slices = shape[0] masks = [mask_func(shape[1:], seed=123) for _ in range(num_slices)] @@ -270,16 +266,14 @@ def test_same_across_volumes_mask_radial(shape, accelerations): @pytest.mark.parametrize( - "shape, accelerations", + "shape, accelerations, center_fractions", [ - ([4, 32, 32, 2], [4]), - ([2, 64, 64, 2], [8, 4]), + ([4, 32, 32, 2], [4], [0.08]), + ([2, 64, 64, 2], [8, 4], [0.04, 0.08]), ], ) -def test_apply_mask_spiral(shape, accelerations): - mask_func = SpiralMaskFunc( - accelerations=accelerations, - ) +def test_apply_mask_spiral(shape, accelerations, center_fractions): + mask_func = SpiralMaskFunc(accelerations=accelerations, center_fractions=center_fractions) mask = mask_func(shape[1:], seed=123) acs_mask = mask_func(shape[1:], seed=123, return_acs=True) expected_mask_shape = (1, shape[1], shape[2], 1) @@ -291,16 +285,14 @@ def test_apply_mask_spiral(shape, accelerations): @pytest.mark.parametrize( - "shape, accelerations", + "shape, accelerations, center_fractions", [ - ([4, 32, 32, 2], [4]), - ([2, 64, 64, 2], [8, 4]), + ([4, 32, 32, 2], [4], [0.08]), + ([2, 64, 64, 2], [8, 4], [0.04, 0.08]), ], ) -def test_same_across_volumes_mask_spiral(shape, accelerations): - mask_func = SpiralMaskFunc( - accelerations=accelerations, - ) +def test_same_across_volumes_mask_spiral(shape, accelerations, center_fractions): + mask_func = SpiralMaskFunc(accelerations=accelerations, center_fractions=center_fractions) num_slices = shape[0] masks = [mask_func(shape[1:], seed=123) for _ in range(num_slices)] diff --git a/tests/tests_data/test_mri_transforms.py b/tests/tests_data/test_mri_transforms.py index d92e89f5..ae4b8ba4 100644 --- a/tests/tests_data/test_mri_transforms.py +++ b/tests/tests_data/test_mri_transforms.py @@ -718,4 +718,4 @@ def test_build_mri_transforms(shape, spatial_dims, estimate_body_coil_image, ima mask_shape = torch.ones(len(shape) + 1).int().tolist() mask_shape[-3] = shape[-2] mask_shape[-2] = shape[-1] - assert list(sample["sampling_mask"].shape) == mask_shape \ No newline at end of file + assert list(sample["sampling_mask"].shape) == mask_shape