Skip to content

Commit

Permalink
SSL transforms
Browse files Browse the repository at this point in the history
  • Loading branch information
georgeyiasemis committed Sep 28, 2023
1 parent 6c87021 commit caa22ee
Show file tree
Hide file tree
Showing 11 changed files with 1,660 additions and 6 deletions.
1 change: 1 addition & 0 deletions direct/config/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ class InferenceConfig(BaseConfig):
@dataclass
class ModelConfig(BaseConfig):
model_name: str = MISSING
engine_name: Optional[str] = None


@dataclass
Expand Down
7 changes: 7 additions & 0 deletions direct/data/datasets_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,13 @@ class TransformsConfig(BaseConfig):
image_recon_type: str = "rss"
pad_coils: Optional[int] = None
use_seed: bool = True
# Next attrs are for SSL transforms
ssl_transforms: Optional[str] = None
mask_split_ratio: Tuple[float, ...] = (0.4,)
mask_split_acs_region: Tuple[int, int] = (0, 0)
mask_split_keep_acs: Optional[bool] = False
mask_split_type: str = "gaussian"
mask_split_gaussian_std: float = 3.0


@dataclass
Expand Down
240 changes: 238 additions & 2 deletions direct/data/mri_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@
from direct.algorithms.mri_algorithms import EspiritCalibration
from direct.data import transforms as T
from direct.exceptions import ItemNotFoundException
from direct.ssl.ssl import (
GaussianMaskSplitterModule,
HalfMaskSplitterModule,
HalfSplitType,
MaskSplitterType,
UniformMaskSplitterModule,
)
from direct.types import DirectEnum, IntegerListOrTupleString, KspaceKey, TransformKey
from direct.utils import DirectModule, DirectTransform
from direct.utils.asserts import assert_complex
Expand Down Expand Up @@ -1365,6 +1372,8 @@ def __call__(self, *args, **kwargs):
ComputeScalingFactor = ModuleWrapper(ComputeScalingFactorModule, toggle_dims=True)
Normalize = ModuleWrapper(NormalizeModule, toggle_dims=False)
WhitenData = ModuleWrapper(WhitenDataModule, toggle_dims=False)
GaussianMaskSplitter = ModuleWrapper(GaussianMaskSplitterModule, toggle_dims=True)
UniformMaskSplitter = ModuleWrapper(UniformMaskSplitterModule, toggle_dims=True)


class ToTensor(DirectTransform):
Expand Down Expand Up @@ -1647,7 +1656,7 @@ def build_post_mri_transforms(
return Compose(mri_transforms)


def build_mri_transforms(
def build_supervised_mri_transforms(
forward_operator: Callable,
backward_operator: Callable,
mask_func: Optional[Callable],
Expand Down Expand Up @@ -1848,7 +1857,10 @@ def build_mri_transforms(
ComputeScalingFactor(
normalize_key=scaling_key, percentile=scale_percentile, scaling_factor_key=TransformKey.scaling_factor
),
Normalize(scaling_factor_key=TransformKey.scaling_factor),
Normalize(
scaling_factor_key=TransformKey.scaling_factor,
keys_to_normalize=[KspaceKey.kspace, KspaceKey.masked_kspace],
),
]

mri_transforms += [
Expand All @@ -1864,3 +1876,227 @@ def build_mri_transforms(
mri_transforms += [DeleteKeys(keys=[KspaceKey.kspace])]

return Compose(mri_transforms)


def build_mri_transforms(
forward_operator: Callable,
backward_operator: Callable,
mask_func: Optional[Callable],
crop: Optional[Union[Tuple[int, int], str]] = None,
crop_type: Optional[str] = "uniform",
image_center_crop: bool = True,
random_rotation: bool = False,
random_rotation_degrees: Optional[Sequence[int]] = (-90, 90),
random_rotation_probability: Optional[float] = 0.5,
random_flip: bool = False,
random_flip_type: Optional[RandomFlipType] = RandomFlipType.random,
random_flip_probability: Optional[float] = 0.5,
random_reverse: bool = False,
random_reverse_probability: float = 0.5,
padding_eps: float = 0.0001,
estimate_body_coil_image: bool = False,
estimate_sensitivity_maps: bool = True,
sensitivity_maps_type: SensitivityMapType = SensitivityMapType.rss_estimate,
sensitivity_maps_gaussian: Optional[float] = None,
sensitivity_maps_espirit_threshold: Optional[float] = 0.05,
sensitivity_maps_espirit_kernel_size: Optional[int] = 6,
sensitivity_maps_espirit_crop: Optional[float] = 0.95,
sensitivity_maps_espirit_max_iters: Optional[int] = 30,
delete_acs_mask: bool = True,
delete_kspace: bool = True,
image_recon_type: ReconstructionType = ReconstructionType.rss,
pad_coils: Optional[int] = None,
scaling_key: TransformKey = TransformKey.masked_kspace,
scale_percentile: Optional[float] = 0.99,
use_seed: bool = True,
ssl_transforms: Optional[str] = None,
mask_split_ratio: Union[float, List[float], Tuple[float, ...]] = 0.4,
mask_split_acs_region: Union[List[int], Tuple[int, int]] = (0, 0),
mask_split_keep_acs: Optional[bool] = False,
mask_split_type: MaskSplitterType = MaskSplitterType.gaussian,
mask_split_gaussian_std: float = 3.0,
mask_split_half_direction: HalfSplitType = HalfSplitType.vertical,
) -> object:
"""Build transforms for MRI.
- Converts input to (complex-valued) tensor.
- Adds a sampling mask if `mask_func` is defined.
- Adds coil sensitivities and / or the body coil_image
- Crops the input data if needed and masks the fully sampled k-space.
- Add a target.
- Normalize input data.
- Pads the coil dimension.
Parameters
----------
forward_operator : Callable
The forward operator, e.g. some form of FFT (centered or uncentered).
backward_operator : Callable
The backward operator, e.g. some form of inverse FFT (centered or uncentered).
mask_func : Callable or None
A function which creates a sampling mask of the appropriate shape.
crop : Tuple[int, int] or str, Optional
If not None, this will transform the "kspace" to an image domain, crop it, and transform it back.
If a tuple of integers is given then it will crop the backprojected kspace to that size. If
"reconstruction_size" is given, then it will crop the backprojected kspace according to it, but
a key "reconstruction_size" must be present in the sample. Default: None.
crop_type : Optional[str]
Type of cropping, either "gaussian" or "uniform". This will be ignored if `crop` is None. Default: "uniform".
image_center_crop : bool
If True the backprojected kspace will be cropped around the center, otherwise randomly.
This will be ignored if `crop` is None. Default: True.
random_rotation : bool
If True, random rotations will be applied of `random_rotation_degrees` degrees, with probability
`random_rotation_probability`. Default: False.
random_rotation_degrees : Sequence[int], optional
Default: (-90, 90).
random_rotation_probability : float, optional
Default: 0.5.
random_flip : bool
If True, random rotation of `random_flip_type` type, with probability `random_flip_probability`. Default: False.
random_flip_type : RandomFlipType, optional
Default: RandomFlipType.random.
random_flip_probability : float, optional
Default: 0.5.
random_reverse : bool
If True will perform random reversion along the time or slice dimension (2). Default: False.
random_reverse_probability : float
Default: 0.5.
padding_eps: float
Padding epsilon. Default: 0.0001.
estimate_body_coil_image : bool
Estimate body coil image. Default: False.
estimate_sensitivity_maps : bool
Estimate sensitivity maps using the acs region. Default: True.
sensitivity_maps_type: sensitivity_maps_type
Can be SensitivityMapType.rss_estimate, SensitivityMapType.unit or SensitivityMapType.espirit.
Will be ignored if `estimate_sensitivity_maps`==False. Default: SensitivityMapType.rss_estimate.
sensitivity_maps_gaussian : float
Optional sigma for gaussian weighting of sensitivity map.
sensitivity_maps_espirit_threshold: float, optional
Threshold for the calibration matrix when `type_of_map`=="espirit". Default: 0.05.
sensitivity_maps_espirit_kernel_size: int, optional
Kernel size for the calibration matrix when `type_of_map`=="espirit". Default: 6.
sensitivity_maps_espirit_crop: float, optional
Output eigenvalue cropping threshold when `type_of_map`=="espirit". Default: 0.95.
sensitivity_maps_espirit_max_iters: int, optional
Power method iterations when `type_of_map`=="espirit". Default: 30.
delete_acs_mask : bool
If True will delete key `acs_mask`. Default: True.
delete_kspace : bool
If True will delete key `kspace` (fully sampled k-space). Default: True.
image_recon_type : ReconstructionType
Type to reconstruct target image. Default: ReconstructionType.rss.
pad_coils : int
Number of coils to pad data to.
scaling_key : KspaceKey
Key in sample to scale scalable items in sample. Default: KspaceKey.masked_kspace.
scale_percentile : float, optional
Data will be rescaled with the given percentile. If None, the division is done by the maximum. Default: 0.99
use_seed : bool
If true, a pseudo-random number based on the filename is computed so that every slice of the volume get
the same mask every time. Default: True.
ssl_transforms : str, optional
If not None, "ssdu" or "dualssl" can be used. Default: None.
mask_split_ratio : Union[float, List[float], Tuple[float, ...]]
The ratio(s) of the sampling mask splitting. If `ssl_transforms` is None, this is ignored.
mask_split_acs_region : Union[List[int], Tuple[int, int]]
A rectangle for the acs region that will be used in the :math:`\Theta` mask (if `ssl_transforms` = "dualssl")
or input mask (if `ssl_transforms` = "ssdu"). Default: (0, 0).
mask_split_keep_acs : Optional[bool]
If True, acs region according to the "acs_mask" of the sample will be used in both mask splits. Default: False.
mask_split_type : MaskSplitterType
How the sampling mask will be split. Can be "uniform" or "gaussian". Default: "gaussian".
mask_split_gaussian_std : float
Standard deviation of gaussian mask splitting. Ignored if `mask_split_type` is not "gaussian". Default: 3.0.
mask_split_half_direction : HalfSplitType
Split type if `mask_split_type` is "vertical. Can be "vertical", "horizontal", "diagonal_left",
or "diagonal_right". Ignored if `mask_split_type` is not "vertical". Default: HalfSplitType.vertical.
Returns
-------
object: Callable
An MRI transformation object.
"""
mri_transforms = build_supervised_mri_transforms(
forward_operator=forward_operator,
backward_operator=backward_operator,
mask_func=mask_func,
crop=crop,
crop_type=crop_type,
image_center_crop=image_center_crop,
random_rotation=random_rotation,
random_rotation_degrees=random_rotation_degrees,
random_rotation_probability=random_rotation_probability,
random_flip=random_flip,
random_flip_type=random_flip_type,
random_flip_probability=random_flip_probability,
random_reverse=random_reverse,
random_reverse_probability=random_reverse_probability,
padding_eps=padding_eps,
estimate_sensitivity_maps=estimate_sensitivity_maps,
sensitivity_maps_type=sensitivity_maps_type,
estimate_body_coil_image=estimate_body_coil_image,
sensitivity_maps_gaussian=sensitivity_maps_gaussian,
sensitivity_maps_espirit_threshold=sensitivity_maps_espirit_threshold,
sensitivity_maps_espirit_kernel_size=sensitivity_maps_espirit_kernel_size,
sensitivity_maps_espirit_crop=sensitivity_maps_espirit_crop,
sensitivity_maps_espirit_max_iters=sensitivity_maps_espirit_max_iters,
delete_acs_mask=delete_acs_mask if ssl_transforms is None else False,
delete_kspace=delete_kspace if ssl_transforms is None else False,
image_recon_type=image_recon_type,
pad_coils=pad_coils,
scaling_key=scaling_key,
scale_percentile=scale_percentile,
use_seed=use_seed,
).transforms

mri_transforms += [AddBooleanKeysModule(["is_ssl_training"], [False if ssl_transforms is None else True])]

if ssl_transforms is None:
return Compose(mri_transforms)

assert ssl_transforms in ["dualssl", "ssdu", "noisier2noise"]
mask_splitter_kwargs = {
"ratio": mask_split_ratio,
"acs_region": mask_split_acs_region,
"keep_acs": mask_split_keep_acs,
"use_seed": use_seed,
"kspace_key": "masked_kspace",
}
mri_transforms += [
GaussianMaskSplitter(**mask_splitter_kwargs, std_scale=mask_split_gaussian_std)
if mask_split_type == "gaussian"
else UniformMaskSplitter(**mask_splitter_kwargs)
if mask_split_type == "uniform"
else HalfMaskSplitterModule(
**{k: v for k, v in mask_splitter_kwargs.items() if k != "ratio"}, direction=mask_split_half_direction
),
DeleteKeys(["acs_mask"]),
]
if ssl_transforms == "ssdu":
mri_transforms += [
RenameKeys(
["lambda_sampling_mask", "theta_sampling_mask", "theta_masked_kspace", "lambda_masked_kspace"],
["target_sampling_mask", "input_sampling_mask", "input_kspace", "kspace"],
),
DeleteKeys(["masked_kspace", "sampling_mask"]),
]
elif ssl_transforms == "noisier2noise":
mri_transforms += [
DeleteKeys(["lambda_sampling_mask", "lambda_masked_kspace"]), # Do not need 2nd mask for Noisier2Noise
RenameKeys(
["theta_sampling_mask", "theta_masked_kspace", "masked_kspace"],
["noisier_sampling_mask", "noisier_kspace", "kspace"],
),
]
mri_transforms += [
ComputeImage(
kspace_key="kspace",
target_key="target",
backward_operator=backward_operator,
type_reconstruction=image_recon_type,
)
]

return Compose(mri_transforms)
7 changes: 3 additions & 4 deletions direct/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,15 +214,14 @@ def initialize_models_from_config(
# TODO(jt): Model name is not used here.
additional_models = {}
for k, v in cfg.additional_models.items():
# Remove model_name key
curr_model = models[k]
curr_model_cfg = {kk: vv for kk, vv in v.items() if kk != "model_name"}
curr_model_cfg = {kk: vv for kk, vv in v.items() if kk not in ["engine_name", "model_name"]}
additional_models[k] = curr_model(**curr_model_cfg)

model = models["model"](
forward_operator=forward_operator,
backward_operator=backward_operator,
**{k: v for (k, v) in cfg.model.items()},
**{k: v for (k, v) in cfg.model.items() if k != "engine_name"},
).to(device)

# Log total number of parameters
Expand Down Expand Up @@ -267,7 +266,7 @@ def setup_engine(
# There is a bit of repetition here, but the warning provided is more descriptive
# TODO(jt): Try to find a way to combine this with the setup above.
model_name_short = cfg.model.model_name.split(".")[0]
engine_name = cfg.model.model_name.split(".")[-1] + "Engine"
engine_name = cfg.model.engine_name if cfg.model.engine_name else cfg.model.model_name.split(".")[-1] + "Engine"

try:
engine_class = str_to_class(
Expand Down
2 changes: 2 additions & 0 deletions direct/nn/ssl/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# coding=utf-8
# Copyright (c) DIRECT Contributors
Loading

0 comments on commit caa22ee

Please sign in to comment.