Skip to content

Commit

Permalink
Update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
georgeyiasemis committed Jun 5, 2024
1 parent 1e69351 commit c669be5
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions tests/tests_data/test_mri_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,19 +52,23 @@ def create_sample(shape, **kwargs):


def _mask_func(shape, seed=None, return_acs=False):
num_rows, num_cols = shape[:2]
num_rows, num_cols = shape[-3:-1]
mask = torch.zeros(num_rows, num_cols).bool()
mask[
num_rows // 2 - num_rows // 4 : num_rows // 2 + num_rows // 4,
num_cols // 2 - num_cols // 4 : num_cols // 2 + num_cols // 4,
] = True
mask_shape = torch.ones(len(shape)).int().tolist()
mask_shape[-3] = num_rows
mask_shape[-2] = num_cols
if return_acs:
return mask.unsqueeze(0).unsqueeze(-1)
return mask.reshape(mask_shape).unsqueeze(0)
if seed:
rng = np.random.RandomState()
rng.seed(seed)
mask = mask | torch.from_numpy(np.random.rand(num_rows, num_cols)).round().bool()
return mask.unsqueeze(0).unsqueeze(-1)
mask = mask.reshape(mask_shape) | torch.from_numpy(np.random.rand(*mask_shape)).round().bool()

return mask.unsqueeze(0)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -164,9 +168,10 @@ def test_CreateSamplingMask(shape, return_acs, use_shape):
sample = create_sample(shape)

transform = CreateSamplingMask(
mask_func=_mask_func, shape=shape[-3:-1] if use_shape else None, return_acs=return_acs
mask_func=_mask_func, shape=shape[1:-1] if use_shape else None, return_acs=return_acs
)
sample = transform(sample)
print(sample["kspace"].shape, sample["sampling_mask"].shape)
assert "sampling_mask" in sample

mask_shape = torch.ones(len(shape))
Expand Down Expand Up @@ -713,4 +718,4 @@ def test_build_mri_transforms(shape, spatial_dims, estimate_body_coil_image, ima
mask_shape = torch.ones(len(shape) + 1).int().tolist()
mask_shape[-3] = shape[-2]
mask_shape[-2] = shape[-1]
assert list(sample["sampling_mask"].shape) == mask_shape
assert list(sample["sampling_mask"].shape) == mask_shape

0 comments on commit c669be5

Please sign in to comment.