Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Crop kspace transforms #210

Merged
merged 9 commits into from
Jun 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
.. image:: https://joss.theoj.org/papers/10.21105/joss.04278/status.svg
:target: https://doi.org/10.21105/joss.04278

.. image:: https://github.com/NKI-AI/direct/actions/workflows/tox.yml/badge.svg
:target: https://github.com/NKI-AI/direct/actions/workflows/tox.yml
:alt: tox
Expand Down
2 changes: 1 addition & 1 deletion direct/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
# Copyright (c) DIRECT Contributors

__author__ = """direct contributors"""
__version__ = "1.0.2"
georgeyiasemis marked this conversation as resolved.
Show resolved Hide resolved
__version__ = "1.0.3"
2 changes: 0 additions & 2 deletions direct/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,9 @@ def is_file(path):
def file_or_url(path: PathOrString) -> FileOrUrl:
if check_is_valid_url(path):
return FileOrUrl(path)

path = pathlib.Path(path)
if path.is_file():
return FileOrUrl(path)

raise argparse.ArgumentTypeError(f"{path} is not a valid file or url.")


Expand Down
3 changes: 1 addition & 2 deletions direct/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,8 +625,7 @@ def _set_params(self, ellipsoids=None) -> None:

self.ellipsoids = ellipsoids

def sample_image(self, idx: int) -> np.ndarray:
# pylint: disable=too-many-locals
def sample_image(self, idx: int) -> np.ndarray: # pylint: disable=too-many-locals
# meshgrid does X, Y backwards
X, Y, Z = np.meshgrid(
np.linspace(-1, 1, self.ny),
Expand Down
1 change: 0 additions & 1 deletion direct/data/h5_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@ def __init__(
lists=filenames_lists, files_root=filenames_lists_root, data_root=root
)
self.logger.info("Attempting to load %s filenames from list(s).", len(filenames))

else:
self.logger.info("Parsing directory %s for h5 files.", self.root)
filenames = list(self.root.glob("*.h5"))
Expand Down
175 changes: 104 additions & 71 deletions direct/data/mri_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,30 +157,68 @@ def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]:
return sample


class CropAndMask(DirectModule):
class ApplyMask(DirectModule):
"""Data Transformer for training MRI reconstruction models.

Crops and masks k-space using a sampling mask.
Masks the k-space using a sampling mask.
"""

def __init__(self) -> None:
"""Inits :class:`ApplyMask`."""
super().__init__()
self.logger = logging.getLogger(type(self).__name__)

def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Calls :class:`ApplyMask`.

This assumes that a `sampling_mask` is present in the sample.

Parameters
----------
sample: Dict[str, Any]
Dict sample containing key `kspace`.

Returns
-------
Dict[str, Any]
Sample with new key `masked_kspace`.
"""
kspace = sample["kspace"]

assert "sampling_mask" in sample, "Key 'sampling_mask' not found in sample."
sampling_mask = sample["sampling_mask"]

masked_kspace, _ = T.apply_mask(kspace, sampling_mask)
sample["masked_kspace"] = masked_kspace

return sample


class CropKspace(DirectModule):
"""Data Transformer for training MRI reconstruction models.

Crops the k-space by:
* It first projects the k-space to the image-domain via the backward operator,
* It crops the back-projected k-space to specified shape or key,
* It transforms the cropped back-projected k-space to the k-space domain via the forward operator.
"""

def __init__(
self,
crop: Union[None, Tuple[int, ...]],
use_seed: bool = True,
crop: Union[str, Tuple[int, ...], List[int]],
forward_operator: Callable = T.fft2,
backward_operator: Callable = T.ifft2,
image_space_center_crop: bool = False,
random_crop_sampler_type: Optional[str] = "uniform",
random_crop_sampler_use_seed: Optional[bool] = True,
random_crop_sampler_gaussian_sigma: Optional[List[float]] = None,
) -> None:
"""Inits :class:`CropAndMask`.
"""Inits :class:`CropKspace`.

Parameters
----------
crop: tuple of ints or None
Size to crop input_image to.
use_seed: bool
If true, a pseudo-random number based on the filename is computed so that every slice of the volume get
the same mask every time. Default: True.
crop: tuple of ints or str
Shape to crop the input to or a string pointing to a crop key (e.g. `reconstruction_size`).
forward_operator: Callable
The forward operator, e.g. some form of FFT (centered or uncentered).
Default: :class:`direct.data.transforms.fft2`.
Expand All @@ -191,87 +229,71 @@ def __init__(
If set, the crop in the data will be taken in the center
random_crop_sampler_type: Optional[str]
If "uniform" the random cropping will be done by uniformly sampling `crop`, as opposed to `gaussian` which
will sample from a gaussian distribution. Default: "uniform".
will sample from a gaussian distribution. If `image_space_center_crop` is True, then this is ignored.
Default: "uniform".
random_crop_sampler_use_seed: bool
If true, a pseudo-random number based on the filename is computed so that every slice of the volume
is cropped the same way. Default: True.
random_crop_sampler_gaussian_sigma: Optional[List[float]]
Standard variance of the gaussian when `random_crop_sampler_type` is `gaussian`.
If `image_space_center_crop` is True, then this is ignored. Default: None.
"""
super().__init__()
self.logger = logging.getLogger(type(self).__name__)

self.use_seed = use_seed
self.image_space_center_crop = image_space_center_crop

if not (isinstance(crop, (Iterable, str))):
raise ValueError(
f"Invalid input for `crop`. Received {crop}. Can be a list of tuple of integers or a string."
)
self.crop = crop
self.random_crop_sampler_type = random_crop_sampler_type
if self.crop:
if self.image_space_center_crop:
self.crop_func = T.complex_center_crop
else:
self.crop_func = functools.partial(T.complex_random_crop, sampler=self.random_crop_sampler_type)

if image_space_center_crop:
self.crop_func = T.complex_center_crop
else:
self.crop_func = functools.partial(
T.complex_random_crop, sampler=random_crop_sampler_type, sigma=random_crop_sampler_gaussian_sigma
)
self.random_crop_sampler_use_seed = random_crop_sampler_use_seed

self.forward_operator = forward_operator
self.backward_operator = backward_operator

self.image_space_center_crop = image_space_center_crop

def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Calls :class:`CropAndMask`.
"""Calls :class:`CropKspace`.

Parameters
----------
sample: Dict[str, Any]
Dict sample.
Dict sample containing key `kspace`.

Returns
-------
Dict[str, Any]
Cropped and masked sample.
"""
# Shape (coil, [slice], height, width, complex=2)
kspace = sample["kspace"]

# Image-space croppable objects
croppable_images = ["sensitivity_map", "input_image"]
kspace = sample["kspace"] # shape (coil, height, width, complex=2)

# Shape (coil, [slice], height, width, complex=2) if not None
sensitivity_map = sample.get("sensitivity_map", None)
# Shape (1, [slice], height, width, 1)
sampling_mask = sample["sampling_mask"]
# Shape (coil, [slice], height, width, complex=2)
backprojected_kspace = self.backward_operator(kspace)

# TODO: Also create a kspace-like crop function
if self.crop:
backprojected_kspace = self.crop_func(
[backprojected_kspace],
self.crop,
contiguous=True,
)
# Compute new k-space for the cropped input_image
kspace = self.forward_operator(backprojected_kspace)
for key in croppable_images:
if key in sample:
sample[key] = self.crop_func(
[sample[key]],
self.crop,
contiguous=True,
)
# TODO(gy): This is not correct, since cropping is done in the image space.
sampling_mask = self.crop_func(
[sampling_mask],
self.crop,
contiguous=True,
backprojected_kspace = self.backward_operator(kspace) # shape (coil, height, width, complex=2)

if isinstance(self.crop, str):
assert self.crop in sample, f"Not found {self.crop} key in sample."
crop_shape = sample[self.crop][:2]
else:
crop_shape = self.crop

cropper_args = {"data_list": [backprojected_kspace], "crop_shape": crop_shape, "contiguous": False}
if not self.image_space_center_crop:
cropper_args["seed"] = (
None if not self.random_crop_sampler_use_seed else tuple(map(ord, str(sample["filename"])))
)
masked_kspace, sampling_mask = T.apply_mask(kspace, sampling_mask)
# Shape ([slice], height, width)
sample["target"] = T.root_sum_of_squares(backprojected_kspace, dim=0)
# Shape (coil, [slice], height, width, complex=2)
sample["masked_kspace"] = masked_kspace
# Shape (1, [slice], height, width, 1)
sample["sampling_mask"] = sampling_mask
# Shape (coil, [slice], height, width, complex=2)
sample["kspace"] = kspace # The cropped kspace
cropped_backprojected_kspace = self.crop_func(**cropper_args)

if sensitivity_map is not None:
sample["sensitivity_map"] = sensitivity_map
# Compute new k-space for the cropped_backprojected_kspace
# shape (coil, new_height, new_width, complex=2)
sample["kspace"] = self.forward_operator(cropped_backprojected_kspace) # The cropped kspace

return sample

Expand Down Expand Up @@ -858,11 +880,22 @@ def build_mri_transforms(
# TODO: Use seed

mri_transforms: List[Callable] = [ToTensor()]
if crop:
mri_transforms += [
CropKspace(
crop=crop,
forward_operator=forward_operator,
backward_operator=backward_operator,
image_space_center_crop=image_center_crop,
random_crop_sampler_type=crop_type,
random_crop_sampler_use_seed=use_seed,
)
]
if mask_func:
mri_transforms += [
CreateSamplingMask(
mask_func,
shape=crop,
shape=(None if (isinstance(crop, str)) else crop),
use_seed=use_seed,
return_acs=estimate_sensitivity_maps,
)
Expand All @@ -876,13 +909,13 @@ def build_mri_transforms(
gaussian_sigma=sensitivity_maps_gaussian,
),
DeleteKeys(keys=["acs_mask"]),
CropAndMask(
crop,
forward_operator=forward_operator,
ComputeImage(
kspace_key="kspace",
target_key="target",
backward_operator=backward_operator,
image_space_center_crop=image_center_crop,
random_crop_sampler_type=crop_type,
type_reconstruction="rss",
),
ApplyMask(),
]
if estimate_body_coil_image and mask_func is not None:
mri_transforms.append(EstimateBodyCoilImage(mask_func, backward_operator=backward_operator, use_seed=use_seed))
Expand Down
30 changes: 26 additions & 4 deletions direct/data/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import numpy as np
import torch
import torch.fft
from numpy.typing import ArrayLike

from direct.data.bbox import crop_to_bbox
from direct.utils import ensure_list, is_complex_data, is_power_of_two
Expand Down Expand Up @@ -247,6 +248,24 @@ def modulus(data: torch.Tensor, complex_axis: int = -1) -> torch.Tensor:
return (data**2).sum(complex_axis).sqrt() # noqa


def modulus_if_complex(data: torch.Tensor, complex_axis=-1) -> torch.Tensor:
"""Compute modulus if complex tensor (has complex axis).

Parameters
----------
data: torch.Tensor
complex_axis: int
Complex dimension along which the modulus will be calculated if that dimension is complex. Default: -1.

Returns
-------
torch.Tensor
"""
if is_complex_data(data, complex_axis=complex_axis):
return modulus(data=data, complex_axis=complex_axis)
return data


def roll_one_dim(data: torch.Tensor, shift: int, dim: int) -> torch.Tensor:
"""Similar to roll but only for one dim

Expand Down Expand Up @@ -584,7 +603,7 @@ def center_crop(data: torch.Tensor, shape: Union[List[int], Tuple[int, ...]]) ->

def complex_center_crop(
data_list: Union[List[torch.Tensor], torch.Tensor],
shape: Union[List[int], Tuple[int, ...]],
crop_shape: Union[List[int], Tuple[int, ...]],
offset: int = 1,
contiguous: bool = False,
) -> Union[List[torch.Tensor], torch.Tensor]:
Expand All @@ -595,7 +614,7 @@ def complex_center_crop(
data_list: Union[List[torch.Tensor], torch.Tensor]
The complex input tensor to be center cropped. It should have at least 3 dimensions
and the cropping is applied along dimensions didx and didx+1 and the last dimensions should have a size of 2.
shape: List[int] or Tuple[int, ...]
crop_shape: List[int] or Tuple[int, ...]
The output shape. The shape should be smaller than the corresponding dimensions of data.
If one value is None, this is filled in by the image shape.
offset: int
Expand All @@ -616,7 +635,7 @@ def complex_center_crop(
bbox = [0] * ndim + image_shape

# Allow for False in crop directions
shape = [_ if _ else image_shape[idx + offset] for idx, _ in enumerate(shape)]
shape = [_ if _ else image_shape[idx + offset] for idx, _ in enumerate(crop_shape)]
for idx, _ in enumerate(shape):
bbox[idx + offset] = (image_shape[idx + offset] - shape[idx]) // 2
bbox[len(image_shape) + idx + offset] = shape[idx]
Expand Down Expand Up @@ -645,6 +664,7 @@ def complex_random_crop(
contiguous: bool = False,
sampler: str = "uniform",
sigma: Union[float, List[float], None] = None,
seed: Union[None, int, ArrayLike] = None,
) -> Union[List[torch.Tensor], torch.Tensor]:
"""Apply a random crop to the input data tensor or a list of complex.

Expand All @@ -663,6 +683,7 @@ def complex_random_crop(
Select the random indices from either a `uniform` or `gaussian` distribution (around the center)
sigma: float or list of float or None
Standard variance of the gaussian when sampler is `gaussian`. If not set will take 1/3th of image shape
seed: None, int or ArrayLike

Returns
-------
Expand Down Expand Up @@ -692,7 +713,8 @@ def complex_random_crop(
f"Bounding box limits have negative values, "
f"this is likely to data size being smaller than the crop size. Got {limits}"
)

if seed is not None:
np.random.seed(seed)
if sampler == "uniform":
lower_point = np.random.randint(0, limits + 1).tolist()
elif sampler == "gaussian":
Expand Down
2 changes: 1 addition & 1 deletion direct/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ def training_loop(

metrics_dict = evaluate_dict(
metric_fns,
output.detach(),
T.modulus_if_complex(output.detach()),
data["target"].detach().to(self.device),
reduction="mean",
)
Expand Down
Loading