Skip to content

Commit

Permalink
Adding crop (kspace) MRI transforms, reformatting MRI transforms (#210)
Browse files Browse the repository at this point in the history
  • Loading branch information
georgeyiasemis authored Jun 19, 2022
1 parent 7c9d593 commit 988d46a
Show file tree
Hide file tree
Showing 7 changed files with 185 additions and 93 deletions.
2 changes: 0 additions & 2 deletions direct/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,9 @@ def is_file(path):
def file_or_url(path: PathOrString) -> FileOrUrl:
if check_is_valid_url(path):
return FileOrUrl(path)

path = pathlib.Path(path)
if path.is_file():
return FileOrUrl(path)

raise argparse.ArgumentTypeError(f"{path} is not a valid file or url.")


Expand Down
3 changes: 1 addition & 2 deletions direct/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,8 +625,7 @@ def _set_params(self, ellipsoids=None) -> None:

self.ellipsoids = ellipsoids

def sample_image(self, idx: int) -> np.ndarray:
# pylint: disable=too-many-locals
def sample_image(self, idx: int) -> np.ndarray: # pylint: disable=too-many-locals
# meshgrid does X, Y backwards
X, Y, Z = np.meshgrid(
np.linspace(-1, 1, self.ny),
Expand Down
1 change: 0 additions & 1 deletion direct/data/h5_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@ def __init__(
lists=filenames_lists, files_root=filenames_lists_root, data_root=root
)
self.logger.info("Attempting to load %s filenames from list(s).", len(filenames))

else:
self.logger.info("Parsing directory %s for h5 files.", self.root)
filenames = list(self.root.glob("*.h5"))
Expand Down
175 changes: 104 additions & 71 deletions direct/data/mri_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,30 +157,68 @@ def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]:
return sample


class CropAndMask(DirectModule):
class ApplyMask(DirectModule):
"""Data Transformer for training MRI reconstruction models.
Crops and masks k-space using a sampling mask.
Masks the k-space using a sampling mask.
"""

def __init__(self) -> None:
"""Inits :class:`ApplyMask`."""
super().__init__()
self.logger = logging.getLogger(type(self).__name__)

def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Calls :class:`ApplyMask`.
This assumes that a `sampling_mask` is present in the sample.
Parameters
----------
sample: Dict[str, Any]
Dict sample containing key `kspace`.
Returns
-------
Dict[str, Any]
Sample with new key `masked_kspace`.
"""
kspace = sample["kspace"]

assert "sampling_mask" in sample, "Key 'sampling_mask' not found in sample."
sampling_mask = sample["sampling_mask"]

masked_kspace, _ = T.apply_mask(kspace, sampling_mask)
sample["masked_kspace"] = masked_kspace

return sample


class CropKspace(DirectModule):
"""Data Transformer for training MRI reconstruction models.
Crops the k-space by:
* It first projects the k-space to the image-domain via the backward operator,
* It crops the back-projected k-space to specified shape or key,
* It transforms the cropped back-projected k-space to the k-space domain via the forward operator.
"""

def __init__(
self,
crop: Union[None, Tuple[int, ...]],
use_seed: bool = True,
crop: Union[str, Tuple[int, ...], List[int]],
forward_operator: Callable = T.fft2,
backward_operator: Callable = T.ifft2,
image_space_center_crop: bool = False,
random_crop_sampler_type: Optional[str] = "uniform",
random_crop_sampler_use_seed: Optional[bool] = True,
random_crop_sampler_gaussian_sigma: Optional[List[float]] = None,
) -> None:
"""Inits :class:`CropAndMask`.
"""Inits :class:`CropKspace`.
Parameters
----------
crop: tuple of ints or None
Size to crop input_image to.
use_seed: bool
If true, a pseudo-random number based on the filename is computed so that every slice of the volume get
the same mask every time. Default: True.
crop: tuple of ints or str
Shape to crop the input to or a string pointing to a crop key (e.g. `reconstruction_size`).
forward_operator: Callable
The forward operator, e.g. some form of FFT (centered or uncentered).
Default: :class:`direct.data.transforms.fft2`.
Expand All @@ -191,87 +229,71 @@ def __init__(
If set, the crop in the data will be taken in the center
random_crop_sampler_type: Optional[str]
If "uniform" the random cropping will be done by uniformly sampling `crop`, as opposed to `gaussian` which
will sample from a gaussian distribution. Default: "uniform".
will sample from a gaussian distribution. If `image_space_center_crop` is True, then this is ignored.
Default: "uniform".
random_crop_sampler_use_seed: bool
If true, a pseudo-random number based on the filename is computed so that every slice of the volume
is cropped the same way. Default: True.
random_crop_sampler_gaussian_sigma: Optional[List[float]]
Standard variance of the gaussian when `random_crop_sampler_type` is `gaussian`.
If `image_space_center_crop` is True, then this is ignored. Default: None.
"""
super().__init__()
self.logger = logging.getLogger(type(self).__name__)

self.use_seed = use_seed
self.image_space_center_crop = image_space_center_crop

if not (isinstance(crop, (Iterable, str))):
raise ValueError(
f"Invalid input for `crop`. Received {crop}. Can be a list of tuple of integers or a string."
)
self.crop = crop
self.random_crop_sampler_type = random_crop_sampler_type
if self.crop:
if self.image_space_center_crop:
self.crop_func = T.complex_center_crop
else:
self.crop_func = functools.partial(T.complex_random_crop, sampler=self.random_crop_sampler_type)

if image_space_center_crop:
self.crop_func = T.complex_center_crop
else:
self.crop_func = functools.partial(
T.complex_random_crop, sampler=random_crop_sampler_type, sigma=random_crop_sampler_gaussian_sigma
)
self.random_crop_sampler_use_seed = random_crop_sampler_use_seed

self.forward_operator = forward_operator
self.backward_operator = backward_operator

self.image_space_center_crop = image_space_center_crop

def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Calls :class:`CropAndMask`.
"""Calls :class:`CropKspace`.
Parameters
----------
sample: Dict[str, Any]
Dict sample.
Dict sample containing key `kspace`.
Returns
-------
Dict[str, Any]
Cropped and masked sample.
"""
# Shape (coil, [slice], height, width, complex=2)
kspace = sample["kspace"]

# Image-space croppable objects
croppable_images = ["sensitivity_map", "input_image"]
kspace = sample["kspace"] # shape (coil, height, width, complex=2)

# Shape (coil, [slice], height, width, complex=2) if not None
sensitivity_map = sample.get("sensitivity_map", None)
# Shape (1, [slice], height, width, 1)
sampling_mask = sample["sampling_mask"]
# Shape (coil, [slice], height, width, complex=2)
backprojected_kspace = self.backward_operator(kspace)

# TODO: Also create a kspace-like crop function
if self.crop:
backprojected_kspace = self.crop_func(
[backprojected_kspace],
self.crop,
contiguous=True,
)
# Compute new k-space for the cropped input_image
kspace = self.forward_operator(backprojected_kspace)
for key in croppable_images:
if key in sample:
sample[key] = self.crop_func(
[sample[key]],
self.crop,
contiguous=True,
)
# TODO(gy): This is not correct, since cropping is done in the image space.
sampling_mask = self.crop_func(
[sampling_mask],
self.crop,
contiguous=True,
backprojected_kspace = self.backward_operator(kspace) # shape (coil, height, width, complex=2)

if isinstance(self.crop, str):
assert self.crop in sample, f"Not found {self.crop} key in sample."
crop_shape = sample[self.crop][:2]
else:
crop_shape = self.crop

cropper_args = {"data_list": [backprojected_kspace], "crop_shape": crop_shape, "contiguous": False}
if not self.image_space_center_crop:
cropper_args["seed"] = (
None if not self.random_crop_sampler_use_seed else tuple(map(ord, str(sample["filename"])))
)
masked_kspace, sampling_mask = T.apply_mask(kspace, sampling_mask)
# Shape ([slice], height, width)
sample["target"] = T.root_sum_of_squares(backprojected_kspace, dim=0)
# Shape (coil, [slice], height, width, complex=2)
sample["masked_kspace"] = masked_kspace
# Shape (1, [slice], height, width, 1)
sample["sampling_mask"] = sampling_mask
# Shape (coil, [slice], height, width, complex=2)
sample["kspace"] = kspace # The cropped kspace
cropped_backprojected_kspace = self.crop_func(**cropper_args)

if sensitivity_map is not None:
sample["sensitivity_map"] = sensitivity_map
# Compute new k-space for the cropped_backprojected_kspace
# shape (coil, new_height, new_width, complex=2)
sample["kspace"] = self.forward_operator(cropped_backprojected_kspace) # The cropped kspace

return sample

Expand Down Expand Up @@ -858,11 +880,22 @@ def build_mri_transforms(
# TODO: Use seed

mri_transforms: List[Callable] = [ToTensor()]
if crop:
mri_transforms += [
CropKspace(
crop=crop,
forward_operator=forward_operator,
backward_operator=backward_operator,
image_space_center_crop=image_center_crop,
random_crop_sampler_type=crop_type,
random_crop_sampler_use_seed=use_seed,
)
]
if mask_func:
mri_transforms += [
CreateSamplingMask(
mask_func,
shape=crop,
shape=(None if (isinstance(crop, str)) else crop),
use_seed=use_seed,
return_acs=estimate_sensitivity_maps,
)
Expand All @@ -876,13 +909,13 @@ def build_mri_transforms(
gaussian_sigma=sensitivity_maps_gaussian,
),
DeleteKeys(keys=["acs_mask"]),
CropAndMask(
crop,
forward_operator=forward_operator,
ComputeImage(
kspace_key="kspace",
target_key="target",
backward_operator=backward_operator,
image_space_center_crop=image_center_crop,
random_crop_sampler_type=crop_type,
type_reconstruction="rss",
),
ApplyMask(),
]
if estimate_body_coil_image and mask_func is not None:
mri_transforms.append(EstimateBodyCoilImage(mask_func, backward_operator=backward_operator, use_seed=use_seed))
Expand Down
12 changes: 8 additions & 4 deletions direct/data/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import numpy as np
import torch
import torch.fft
from numpy.typing import ArrayLike

from direct.data.bbox import crop_to_bbox
from direct.utils import ensure_list, is_complex_data, is_power_of_two
Expand Down Expand Up @@ -602,7 +603,7 @@ def center_crop(data: torch.Tensor, shape: Union[List[int], Tuple[int, ...]]) ->

def complex_center_crop(
data_list: Union[List[torch.Tensor], torch.Tensor],
shape: Union[List[int], Tuple[int, ...]],
crop_shape: Union[List[int], Tuple[int, ...]],
offset: int = 1,
contiguous: bool = False,
) -> Union[List[torch.Tensor], torch.Tensor]:
Expand All @@ -613,7 +614,7 @@ def complex_center_crop(
data_list: Union[List[torch.Tensor], torch.Tensor]
The complex input tensor to be center cropped. It should have at least 3 dimensions
and the cropping is applied along dimensions didx and didx+1 and the last dimensions should have a size of 2.
shape: List[int] or Tuple[int, ...]
crop_shape: List[int] or Tuple[int, ...]
The output shape. The shape should be smaller than the corresponding dimensions of data.
If one value is None, this is filled in by the image shape.
offset: int
Expand All @@ -634,7 +635,7 @@ def complex_center_crop(
bbox = [0] * ndim + image_shape

# Allow for False in crop directions
shape = [_ if _ else image_shape[idx + offset] for idx, _ in enumerate(shape)]
shape = [_ if _ else image_shape[idx + offset] for idx, _ in enumerate(crop_shape)]
for idx, _ in enumerate(shape):
bbox[idx + offset] = (image_shape[idx + offset] - shape[idx]) // 2
bbox[len(image_shape) + idx + offset] = shape[idx]
Expand Down Expand Up @@ -663,6 +664,7 @@ def complex_random_crop(
contiguous: bool = False,
sampler: str = "uniform",
sigma: Union[float, List[float], None] = None,
seed: Union[None, int, ArrayLike] = None,
) -> Union[List[torch.Tensor], torch.Tensor]:
"""Apply a random crop to the input data tensor or a list of complex.
Expand All @@ -681,6 +683,7 @@ def complex_random_crop(
Select the random indices from either a `uniform` or `gaussian` distribution (around the center)
sigma: float or list of float or None
Standard variance of the gaussian when sampler is `gaussian`. If not set will take 1/3th of image shape
seed: None, int or ArrayLike
Returns
-------
Expand Down Expand Up @@ -710,7 +713,8 @@ def complex_random_crop(
f"Bounding box limits have negative values, "
f"this is likely to data size being smaller than the crop size. Got {limits}"
)

if seed is not None:
np.random.seed(seed)
if sampler == "uniform":
lower_point = np.random.randint(0, limits + 1).tolist()
elif sampler == "gaussian":
Expand Down
1 change: 0 additions & 1 deletion direct/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ def get_filenames_for_datasets_from_config(cfg, files_root: PathOrString, data_r
"""
if "filenames_lists" not in cfg:
return None

lists = cfg.filenames_lists
return get_filenames_for_datasets(lists, files_root, data_root)

Expand Down
Loading

0 comments on commit 988d46a

Please sign in to comment.