Skip to content

Commit

Permalink
Add new transforms
Browse files Browse the repository at this point in the history
  • Loading branch information
georgeyiasemis committed Jun 5, 2024
1 parent 575b33a commit c686711
Show file tree
Hide file tree
Showing 3 changed files with 549 additions and 23 deletions.
42 changes: 28 additions & 14 deletions direct/data/datasets_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

"""Classes holding the typed configurations for the datasets."""

from __future__ import annotations

from dataclasses import dataclass
from typing import List, Optional, Tuple
from typing import Optional

from omegaconf import MISSING

Expand All @@ -14,6 +16,7 @@
MaskSplitterType,
RandomFlipType,
ReconstructionType,
RescaleMode,
SensitivityMapType,
TransformsType,
)
Expand All @@ -37,9 +40,17 @@ class SensitivityMapEstimationTransformConfig(BaseConfig):
sensitivity_maps_gaussian: Optional[float] = 0.7


@dataclass
class AugmentationTransformConfig(BaseConfig):
rescale: Optional[tuple[int, ...]] = None
rescale_mode: Optional[RescaleMode] = RescaleMode.NEAREST
rescale_2d_if_3d: Optional[bool] = False
pad: Optional[tuple[int, ...]] = None


@dataclass
class RandomAugmentationTransformsConfig(BaseConfig):
random_rotation_degrees: Tuple[int, ...] = (-90, 90)
random_rotation_degrees: tuple[int, ...] = (-90, 90)
random_rotation_probability: float = 0.0
random_flip_type: Optional[RandomFlipType] = RandomFlipType.RANDOM
random_flip_probability: float = 0.0
Expand All @@ -62,8 +73,10 @@ class TransformsConfig(BaseConfig):
Configuration for the masking.
cropping : CropTransformConfig
Configuration for the cropping.
augmentation : AugmentationTransformConfig
Configuration for the augmentation. Currently only rescale and pad are supported.
random_augmentations : RandomAugmentationTransformsConfig
Configuration for the random augmentations.
Configuration for the random augmentations. Currently only random rotation, flip and reverse are supported.
padding_eps : float
Padding epsilon. Default is 0.001.
estimate_body_coil_image : bool
Expand All @@ -89,10 +102,10 @@ class TransformsConfig(BaseConfig):
To use SSL transforms, set transforms_type to `SSL_SSDU`. This will require additional parameters to be set:
mask_split_ratio, mask_split_acs_region, mask_split_keep_acs, mask_split_type, mask_split_gaussian_std.
Default is `TransformsType.SUPERVISED`.
mask_split_ratio : Tuple[float, ...]
mask_split_ratio : tuple[float, ...]
Ratio of the mask to split into input and target mask. Ignored if transforms_type is not `SSL_SSDU`.
Default is (0.4,).
mask_split_acs_region : Tuple[int, int]
mask_split_acs_region : tuple[int, int]
Region of the ACS k-space to keep in the input mask. Ignored if transforms_type is not `SSL_SSDU`.
Default is (0, 0).
mask_split_keep_acs : bool, optional
Expand All @@ -111,6 +124,7 @@ class TransformsConfig(BaseConfig):

masking: Optional[MaskingConfig] = MaskingConfig()
cropping: CropTransformConfig = CropTransformConfig()
augmentation: AugmentationTransformConfig = AugmentationTransformConfig()
random_augmentations: RandomAugmentationTransformsConfig = RandomAugmentationTransformsConfig()
padding_eps: float = 0.001
estimate_body_coil_image: bool = False
Expand All @@ -123,8 +137,8 @@ class TransformsConfig(BaseConfig):
use_seed: bool = True
transforms_type: TransformsType = TransformsType.SUPERVISED
# Next attributes are for the mask splitter in case of transforms_type is set to SSL_SSDU
mask_split_ratio: Tuple[float, ...] = (0.4,)
mask_split_acs_region: Tuple[int, int] = (0, 0)
mask_split_ratio: tuple[float, ...] = (0.4,)
mask_split_acs_region: tuple[int, int] = (0, 0)
mask_split_keep_acs: Optional[bool] = False
mask_split_type: MaskSplitterType = MaskSplitterType.GAUSSIAN
mask_split_gaussian_std: float = 3.0
Expand All @@ -146,21 +160,21 @@ class H5SliceConfig(DatasetConfig):
kspace_context: int = 0
pass_mask: bool = False
data_root: Optional[str] = None
filenames_filter: Optional[List[str]] = None
filenames_lists: Optional[List[str]] = None
filenames_filter: Optional[list[str]] = None
filenames_lists: Optional[list[str]] = None
filenames_lists_root: Optional[str] = None


@dataclass
class CMRxReconConfig(DatasetConfig):
regex_filter: Optional[str] = None
data_root: Optional[str] = None
filenames_filter: Optional[List[str]] = None
filenames_lists: Optional[List[str]] = None
filenames_filter: Optional[list[str]] = None
filenames_lists: Optional[list[str]] = None
filenames_lists_root: Optional[str] = None
kspace_key: str = "kspace_full"
compute_mask: bool = False
extra_keys: Optional[List[str]] = None
extra_keys: Optional[list[str]] = None
kspace_context: Optional[str] = None


Expand All @@ -181,11 +195,11 @@ class FakeMRIBlobsConfig(DatasetConfig):

@dataclass
class SheppLoganDatasetConfig(DatasetConfig):
shape: Tuple[int, int, int] = (100, 100, 30)
shape: tuple[int, int, int] = (100, 100, 30)
num_coils: int = 12
seed: Optional[int] = None
B0: float = 3.0
zlimits: Tuple[float, float] = (-0.929, 0.929)
zlimits: tuple[float, float] = (-0.929, 0.929)


@dataclass
Expand Down
Loading

0 comments on commit c686711

Please sign in to comment.