Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/docs minor issues #250

Merged
merged 11 commits into from
Jun 28, 2023
2 changes: 1 addition & 1 deletion direct/data/datasets_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

@dataclass
class CropTransformConfig(BaseConfig):
crop: Optional[Tuple[int, int]] = None
crop: Optional[str] = None
crop_type: Optional[str] = "uniform"
image_center_crop: bool = False

Expand Down
6 changes: 4 additions & 2 deletions direct/data/mri_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from direct.algorithms.mri_algorithms import EspiritCalibration
from direct.data import transforms as T
from direct.exceptions import ItemNotFoundException
from direct.types import DirectEnum, KspaceKey
from direct.types import DirectEnum, IntegerListOrTupleString, KspaceKey
from direct.utils import DirectModule, DirectTransform
from direct.utils.asserts import assert_complex

Expand Down Expand Up @@ -408,7 +408,9 @@ def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]:

backprojected_kspace = self.backward_operator(kspace, dim=(1, 2)) # shape (coil, height, width, complex=2)

if isinstance(self.crop, str):
if isinstance(self.crop, IntegerListOrTupleString):
crop_shape = IntegerListOrTupleString(self.crop)
elif isinstance(self.crop, str):
assert self.crop in sample, f"Not found {self.crop} key in sample."
crop_shape = sample[self.crop][:2]
else:
Expand Down
4 changes: 2 additions & 2 deletions direct/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from direct.data.mri_transforms import build_mri_transforms
from direct.environment import setup_inference_environment
from direct.types import FileOrUrl, PathOrString
from direct.utils import chunks, remove_keys
from direct.utils import chunks, dict_flatten, remove_keys
from direct.utils.io import read_list
from direct.utils.writers import write_output_to_h5

Expand Down Expand Up @@ -135,7 +135,7 @@ def build_inference_transforms(env, mask_func: Callable, dataset_cfg: DictConfig
backward_operator=env.engine.backward_operator,
mask_func=mask_func,
)
transforms = partial_build_mri_transforms(**remove_keys(dataset_cfg.transforms, "masking"))
transforms = partial_build_mri_transforms(**dict_flatten(remove_keys(dataset_cfg.transforms, "masking")))
return transforms


Expand Down
82 changes: 82 additions & 0 deletions direct/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,85 @@ def __hash__(self) -> int:
class KspaceKey(DirectEnum):
kspace = "kspace"
masked_kspace = "masked_kspace"


class IntegerListOrTupleStringMeta(type):
"""Metaclass for the :class:`IntegerListOrTupleString` class.

Returns
-------
bool
True if the instance is a valid representation of IntegerListOrTupleString, False otherwise.
"""

def __instancecheck__(cls, instance):
"""Check if the given instance is a valid representation of an IntegerListOrTupleString.

Parameters
----------
cls : type
The class being checked, i.e., IntegerListOrTupleStringMeta.
instance : object
The instance being checked.

Returns
-------
bool
True if the instance is a valid representation of IntegerListOrTupleString, False otherwise.
"""
if isinstance(instance, str):
try:
assert (instance.startswith("[") and instance.endswith("]")) or (
instance.startswith("(") and instance.endswith(")")
)
elements = instance.strip()[1:-1].split(",")
integers = [int(element) for element in elements]
return all(isinstance(num, int) for num in integers)
except (AssertionError, ValueError, AttributeError):
pass
return False


class IntegerListOrTupleString(metaclass=IntegerListOrTupleStringMeta):
"""IntegerListOrTupleString class represents a list or tuple of integers based on a string representation.

Examples
--------
s1 = "[1, 2, 45, -1, 0]"
print(isinstance(s1, IntegerListOrTupleString)) # True
print(IntegerListOrTupleString(s1)) # [1, 2, 45, -1, 0]
print(type(IntegerListOrTupleString(s1))) # <class 'list'>
print(type(IntegerListOrTupleString(s1)[0])) # <class 'int'>

s2 = "(10, -9, 20)"
print(isinstance(s2, IntegerListOrTupleString)) # True
print(IntegerListOrTupleString(s2)) # (10, -9, 20)
print(type(IntegerListOrTupleString(s2))) # <class 'tuple'>
print(type(IntegerListOrTupleString(s2)[0])) # <class 'int'>

s3 = "[a, we, 2]"
print(isinstance(s3, IntegerListOrTupleString)) # False

s4 = "(1, 2, 3]"
print(isinstance(s4 IntegerListOrTupleString)) # False
"""

def __new__(cls, string):
"""
Create a new instance of IntegerListOrTupleString based on the given string representation.

Parameters
----------
string : str
The string representation of the integer list or tuple.

Returns
-------
list or tuple
A new instance of IntegerListOrTupleString.
"""
list_or_tuple = list if string.startswith("[") else tuple
string = string.strip()[1:-1] # Remove outer brackets
elements = string.split(",")
integers = [int(element) for element in elements]
return list_or_tuple(integers)
27 changes: 21 additions & 6 deletions docs/config.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,18 @@
Configuration
=============

To perform experiments for training, validation or inference, a configuration file with an extension ``.yaml`` must be defined which includes all experiments parameters such as models, datasets, etc. The following is a template for the configuration file. Accepted arguments are the parameters as defined in the ``config<>.py`` file for each function/class. For instance, accepted arguments for training are the parameters as defined in ``TrainingConfig``. A list of our configuration files can be found in the `projects <../projects/>`_ folder.
To perform experiments for training, validation or inference, a configuration file
with an extension `.yaml` must be defined which includes all experiments parameters such as models,
datasets, etc. The following is a template for the configuration file.

.. code-block:: yaml

model:
model_name: <nn_model_path>
model_parameter_1: <nn_model_paramter_1>
model_parameter_2: <nn_model_paramter_2>
...

additional_models:
sensitivity_model:
model_name: <nn_sensitivity_model_path>
Expand All @@ -31,9 +33,13 @@ To perform experiments for training, validation or inference, a configuration fi
- <path_to_list_1_for_Dataset1>
- <path_to_list_2_for_Dataset1>
transforms:
estimate_sensitivity_maps: <true_or_false>
scaling_key: <scaling_key>
image_center_crop: <true_or_false>
cropping:
crop: <shape_or_str>
image_center_crop: <true_or_false>
sensitivity_map_estimation:
estimate_sensitivity_maps: <true_or_false>
normalization:
scaling_key: <stringg>
masking:
name: MaskingFunctionName
accelerations: [acceleration_1, accelaration_2, ...]
Expand Down Expand Up @@ -109,3 +115,12 @@ To perform experiments for training, validation or inference, a configuration fi
logging:
tensorboard:
num_images: <num_images>

The following configuration files are accepted for each field:

* physics, training, and validation configurations: ``direct/config/defaults.py``
* transforms configurations: ``direct/data/datasets_config.py``
* model configurations: ``direct/nn/<model_name>/config.py``

A list of our configuration files can be found in
the `projects <https://github.com/NKI-AI/direct/tree/main/projects>`_ folder.
9 changes: 7 additions & 2 deletions docs/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,14 @@ Follow the steps below:
self.filenames_filter = filenames_filter

self.text_description = text_description
self.ndim = # 2 or 3
self.volume_indices = self.set_volume_indices(...)
...

def set_volume_indices(self, ...):
...

def self.get_dataset_len(self):
def get_dataset_len(self):
...

def __len__(self):
Expand All @@ -87,11 +91,12 @@ should split three-dimensional data to slices of two-dimensional data.

.. code-block:: python

...

@dataclass
class MyDatasetConfig(BaseConfig):
...
name: str = "MyNew"
lists: List[str] = field(default_factory=lambda: [])
transforms: BaseConfig = TransformsConfig()
text_description: Optional[str] = None
...
Expand Down
54 changes: 27 additions & 27 deletions projects/calgary_campinas/configs/base_cirim.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,31 @@ physics:
backward_operator: ifft2(centered=False)
training:
datasets:
# Two datasets, only difference is the shape, so the data can be collated for larger batches
# Two datasets, only difference is the shape, so the data can be collated for larger batches. R=5
- name: CalgaryCampinas
lists:
filenames_lists:
- ../lists/train/12x218x170_train.lst
transforms:
crop: null
estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS
scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum
image_center_crop: false
cropping:
crop: null
sensitivity_map_estimation:
estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS
normalization:
scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum
masking:
name: CalgaryCampinas
accelerations: [5, 10]
crop_outer_slices: true
- name: CalgaryCampinas
lists:
filenames_lists:
- ../lists/train/12x218x180_train.lst
transforms:
crop: null
estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS
scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum
image_center_crop: false
cropping:
crop: null
sensitivity_map_estimation:
estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS
normalization:
scaling_key: masked_kspace
masking:
name: CalgaryCampinas
accelerations: [5, 10]
Expand Down Expand Up @@ -57,19 +61,25 @@ validation:
# Twice the same dataset but a different acceleration factor
- name: CalgaryCampinas
transforms:
crop: null
estimate_sensitivity_maps: true
scaling_key: masked_kspace
cropping:
crop: null
sensitivity_map_estimation:
estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS
normalization:
scaling_key: masked_kspace
masking:
name: CalgaryCampinas
accelerations: [5]
crop_outer_slices: true
text_description: 5x # Description for logging
- name: CalgaryCampinas
transforms:
crop: null
estimate_sensitivity_maps: true
scaling_key: masked_kspace
cropping:
crop: null
sensitivity_map_estimation:
estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS
normalization:
scaling_key: masked_kspace
masking:
name: CalgaryCampinas
accelerations: [10]
Expand Down Expand Up @@ -98,13 +108,3 @@ additional_models:
logging:
tensorboard:
num_images: 4
inference:
batch_size: 8
dataset:
name: CalgaryCampinas
crop_outer_slices: true
text_description: inference
transforms:
crop: null
estimate_sensitivity_maps: true
scaling_key: masked_kspace
52 changes: 26 additions & 26 deletions projects/calgary_campinas/configs/base_conjgradnet.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,30 @@ training:
filenames_lists:
- ../lists/train/12x218x170_train.lst
transforms:
crop: null
estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS
scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum
image_center_crop: false
cropping:
crop: null
sensitivity_map_estimation:
estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS
normalization:
scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum
masking:
name: CalgaryCampinas
accelerations: [5, 10]
crop_outer_slices: false
crop_outer_slices: true
- name: CalgaryCampinas
filenames_lists:
- ../lists/train/12x218x180_train.lst
transforms:
crop: null
estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS
scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum
image_center_crop: false
cropping:
crop: null
sensitivity_map_estimation:
estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS
normalization:
scaling_key: masked_kspace
masking:
name: CalgaryCampinas
accelerations: [5, 10]
crop_outer_slices: false
crop_outer_slices: true
batch_size: 2 # This is the batch size per GPU!
optimizer: Adam
lr: 0.001
Expand All @@ -54,19 +58,25 @@ validation:
# Twice the same dataset but a different acceleration factor
- name: CalgaryCampinas
transforms:
crop: null
estimate_sensitivity_maps: true
scaling_key: masked_kspace
cropping:
crop: null
sensitivity_map_estimation:
estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS
normalization:
scaling_key: masked_kspace
masking:
name: CalgaryCampinas
accelerations: [5]
crop_outer_slices: true
text_description: 5x # Description for logging
- name: CalgaryCampinas
transforms:
crop: null
estimate_sensitivity_maps: true
scaling_key: masked_kspace
cropping:
crop: null
sensitivity_map_estimation:
estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS
normalization:
scaling_key: masked_kspace
masking:
name: CalgaryCampinas
accelerations: [10]
Expand Down Expand Up @@ -100,13 +110,3 @@ additional_models:
logging:
tensorboard:
num_images: 4
inference:
batch_size: 8
dataset:
name: CalgaryCampinas
crop_outer_slices: true
text_description: inference
transforms:
crop: null
estimate_sensitivity_maps: true
scaling_key: masked_kspace
Loading