Minor Code and Documentation Fixes (NKI-AI#250, closes NKI-AI#243, NKI-AI#244, NKI-AI#245, NKI-AI#246, NKI-AI#247, NKI-AI#249) (#3)

Minor code fixes:
* Fix transforms creation for inference using `dict_flatten`

Config fixes:
* Update `projects/<configs>` to the new transforms config:
```yaml
cropping:
    crop: null
sensitivity_map_estimation:
    estimate_sensitivity_maps: true  # Estimate the sensitivity map on the ACS
normalization:
    scaling_key: masked_kspace
```

Documentation fixes:
* Update some documentation files
georgeyiasemis authored Aug 31, 2023
1 parent b380e2e commit e3cca85
Showing 71 changed files with 1,903 additions and 1,234 deletions.
2 changes: 1 addition & 1 deletion direct/data/datasets_config.py
@@ -14,7 +14,7 @@

@dataclass
class CropTransformConfig(BaseConfig):
-    crop: Optional[Tuple[int, int]] = None
+    crop: Optional[str] = None
    crop_type: Optional[str] = "uniform"
    image_center_crop: bool = False
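For reference, the retyped `crop` field now accepts a string: either a bracketed list/tuple of integers or the name of a key to look up in the sample (both branches appear in the `mri_transforms.py` hunk below). A small illustrative sketch; the concrete values, including the `reconstruction_size` key, are assumptions rather than part of this commit:

```python
# Illustrative values only, not taken from this commit.
fixed_crop = "[320, 320]"           # parsed into integers via IntegerListOrTupleString
keyed_crop = "reconstruction_size"  # assumed sample key whose first two entries give the crop shape
no_crop = None                      # cropping disabled, as in the project configs below
```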

6 changes: 4 additions & 2 deletions direct/data/mri_transforms.py
@@ -17,7 +17,7 @@
from direct.algorithms.mri_algorithms import EspiritCalibration
from direct.data import transforms as T
from direct.exceptions import ItemNotFoundException
-from direct.types import DirectEnum, KspaceKey
+from direct.types import DirectEnum, IntegerListOrTupleString, KspaceKey
from direct.utils import DirectModule, DirectTransform
from direct.utils.asserts import assert_complex

@@ -408,7 +408,9 @@ def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]:

        backprojected_kspace = self.backward_operator(kspace, dim=(1, 2))  # shape (coil, height, width, complex=2)

-        if isinstance(self.crop, str):
+        if isinstance(self.crop, IntegerListOrTupleString):
+            crop_shape = IntegerListOrTupleString(self.crop)
+        elif isinstance(self.crop, str):
            assert self.crop in sample, f"Not found {self.crop} key in sample."
            crop_shape = sample[self.crop][:2]
        else:
4 changes: 2 additions & 2 deletions direct/inference.py
@@ -13,7 +13,7 @@
from direct.data.mri_transforms import build_mri_transforms
from direct.environment import setup_inference_environment
from direct.types import FileOrUrl, PathOrString
-from direct.utils import chunks, remove_keys
+from direct.utils import chunks, dict_flatten, remove_keys
from direct.utils.io import read_list
from direct.utils.writers import write_output_to_h5

@@ -135,7 +135,7 @@ def build_inference_transforms(env, mask_func: Callable, dataset_cfg: DictConfig
        backward_operator=env.engine.backward_operator,
        mask_func=mask_func,
    )
-    transforms = partial_build_mri_transforms(**remove_keys(dataset_cfg.transforms, "masking"))
+    transforms = partial_build_mri_transforms(**dict_flatten(remove_keys(dataset_cfg.transforms, "masking")))
    return transforms
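The `dict_flatten` call is needed because the transforms section of a dataset config is now grouped into `cropping`, `sensitivity_map_estimation`, and `normalization`, while `build_mri_transforms` still takes flat keyword arguments. A minimal sketch of the idea using plain dictionaries; `direct.utils.dict_flatten` itself may differ in signature and edge-case handling:

```python
# Minimal sketch of flattening a nested transforms config into flat kwargs.
# Not direct's implementation: direct.utils.dict_flatten may differ in details.
from typing import Any, Dict


def flatten(nested: Dict[str, Any]) -> Dict[str, Any]:
    """Merge the leaf entries of a nested dict into a single flat dict."""
    flat: Dict[str, Any] = {}
    for key, value in nested.items():
        if isinstance(value, dict):
            flat.update(flatten(value))  # descend into groups such as "cropping"
        else:
            flat[key] = value
    return flat


transforms_cfg = {
    "cropping": {"crop": None, "image_center_crop": False},
    "sensitivity_map_estimation": {"estimate_sensitivity_maps": True},
    "normalization": {"scaling_key": "masked_kspace"},
}
print(flatten(transforms_cfg))
# {'crop': None, 'image_center_crop': False,
#  'estimate_sensitivity_maps': True, 'scaling_key': 'masked_kspace'}
```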


82 changes: 82 additions & 0 deletions direct/types.py
@@ -46,3 +46,85 @@ def __hash__(self) -> int:
class KspaceKey(DirectEnum):
    kspace = "kspace"
    masked_kspace = "masked_kspace"


class IntegerListOrTupleStringMeta(type):
    """Metaclass for the :class:`IntegerListOrTupleString` class.

    Returns
    -------
    bool
        True if the instance is a valid representation of IntegerListOrTupleString, False otherwise.
    """

    def __instancecheck__(cls, instance):
        """Check if the given instance is a valid representation of an IntegerListOrTupleString.

        Parameters
        ----------
        cls : type
            The class being checked, i.e., IntegerListOrTupleStringMeta.
        instance : object
            The instance being checked.

        Returns
        -------
        bool
            True if the instance is a valid representation of IntegerListOrTupleString, False otherwise.
        """
        if isinstance(instance, str):
            try:
                assert (instance.startswith("[") and instance.endswith("]")) or (
                    instance.startswith("(") and instance.endswith(")")
                )
                elements = instance.strip()[1:-1].split(",")
                integers = [int(element) for element in elements]
                return all(isinstance(num, int) for num in integers)
            except (AssertionError, ValueError, AttributeError):
                pass
        return False


class IntegerListOrTupleString(metaclass=IntegerListOrTupleStringMeta):
    """IntegerListOrTupleString class represents a list or tuple of integers based on a string representation.

    Examples
    --------
    s1 = "[1, 2, 45, -1, 0]"
    print(isinstance(s1, IntegerListOrTupleString))  # True
    print(IntegerListOrTupleString(s1))  # [1, 2, 45, -1, 0]
    print(type(IntegerListOrTupleString(s1)))  # <class 'list'>
    print(type(IntegerListOrTupleString(s1)[0]))  # <class 'int'>

    s2 = "(10, -9, 20)"
    print(isinstance(s2, IntegerListOrTupleString))  # True
    print(IntegerListOrTupleString(s2))  # (10, -9, 20)
    print(type(IntegerListOrTupleString(s2)))  # <class 'tuple'>
    print(type(IntegerListOrTupleString(s2)[0]))  # <class 'int'>

    s3 = "[a, we, 2]"
    print(isinstance(s3, IntegerListOrTupleString))  # False

    s4 = "(1, 2, 3]"
    print(isinstance(s4, IntegerListOrTupleString))  # False
    """

    def __new__(cls, string):
        """
        Create a new instance of IntegerListOrTupleString based on the given string representation.

        Parameters
        ----------
        string : str
            The string representation of the integer list or tuple.

        Returns
        -------
        list or tuple
            A new instance of IntegerListOrTupleString.
        """
        list_or_tuple = list if string.startswith("[") else tuple
        string = string.strip()[1:-1]  # Remove outer brackets
        elements = string.split(",")
        integers = [int(element) for element in elements]
        return list_or_tuple(integers)
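Together with the `CropTransformConfig` change above, this lets a fixed crop shape be written as a string in a config file. A quick check of the new behaviour, assuming a `direct` checkout that includes this commit; the example value is illustrative:

```python
from direct.types import IntegerListOrTupleString

crop = "[320, 320]"  # illustrative value, as it might appear in a YAML config
if isinstance(crop, IntegerListOrTupleString):  # the metaclass validates the string
    crop_shape = IntegerListOrTupleString(crop)  # -> [320, 320], a list of ints
    print(crop_shape)
```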
27 changes: 21 additions & 6 deletions docs/config.rst
@@ -4,16 +4,18 @@
Configuration
=============

To perform experiments for training, validation or inference, a configuration file with an extension ``.yaml`` must be defined which includes all experiments parameters such as models, datasets, etc. The following is a template for the configuration file. Accepted arguments are the parameters as defined in the ``config<>.py`` file for each function/class. For instance, accepted arguments for training are the parameters as defined in ``TrainingConfig``. A list of our configuration files can be found in the `projects <../projects/>`_ folder.
To perform experiments for training, validation or inference, a configuration file
with the extension ``.yaml`` must be defined which includes all experiment parameters such as models,
datasets, etc. The following is a template for the configuration file.

.. code-block:: yaml
model:
model_name: <nn_model_path>
model_parameter_1: <nn_model_parameter_1>
model_parameter_2: <nn_model_parameter_2>
...
additional_models:
sensitivity_model:
model_name: <nn_sensitivity_model_path>
@@ -31,9 +33,13 @@ To perform experiments for training, validation or inference, a configuration fi
- <path_to_list_1_for_Dataset1>
- <path_to_list_2_for_Dataset1>
transforms:
estimate_sensitivity_maps: <true_or_false>
scaling_key: <scaling_key>
image_center_crop: <true_or_false>
cropping:
crop: <shape_or_str>
image_center_crop: <true_or_false>
sensitivity_map_estimation:
estimate_sensitivity_maps: <true_or_false>
normalization:
scaling_key: <string>
masking:
name: MaskingFunctionName
accelerations: [acceleration_1, acceleration_2, ...]
@@ -109,3 +115,12 @@
logging:
tensorboard:
num_images: <num_images>
The accepted parameters for each field are defined in the following configuration files:

* physics, training, and validation configurations: ``direct/config/defaults.py``
* transforms configurations: ``direct/data/datasets_config.py``
* model configurations: ``direct/nn/<model_name>/config.py``

A list of our configuration files can be found in
the `projects <https://github.com/NKI-AI/direct/tree/main/projects>`_ folder.
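To make the nested layout concrete, here is a small sketch of parsing such a transforms block with OmegaConf (`DictConfig` in the `inference.py` hunk above is an OmegaConf type); the snippet is illustrative and not part of the repository:

```python
# Illustrative only: parse a transforms block written in the new nested layout.
from omegaconf import OmegaConf

yaml_block = """\
transforms:
  cropping:
    crop: null
  sensitivity_map_estimation:
    estimate_sensitivity_maps: true
  normalization:
    scaling_key: masked_kspace
"""
cfg = OmegaConf.create(yaml_block)
print(cfg.transforms.cropping.crop)              # None
print(cfg.transforms.normalization.scaling_key)  # masked_kspace
```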
9 changes: 7 additions & 2 deletions docs/datasets.rst
@@ -60,10 +60,14 @@ Follow the steps below:
        self.filenames_filter = filenames_filter
        self.text_description = text_description
        self.ndim =  # 2 or 3
        self.volume_indices = self.set_volume_indices(...)
        ...

    def set_volume_indices(self, ...):
        ...

    def self.get_dataset_len(self):
    def get_dataset_len(self):
        ...

    def __len__(self):
@@ -87,11 +91,12 @@ should split three-dimensional data to slices of two-dimensional data.

.. code-block:: python
...

@dataclass
class MyDatasetConfig(BaseConfig):
    ...
    name: str = "MyNew"
    lists: List[str] = field(default_factory=lambda: [])
    transforms: BaseConfig = TransformsConfig()
    text_description: Optional[str] = None
    ...
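The `volume_indices` / `__len__` pattern added in the snippet above is easier to see in a self-contained toy example (not `direct`'s actual implementation): each 3D volume contributes a contiguous range of global 2D slice indices, and the dataset length is the total number of slices.

```python
# Toy sketch of slice bookkeeping for a 2D-slice dataset built from 3D volumes.
from typing import Dict


class SliceIndexer:
    """Each 3D volume owns a contiguous range of global 2D slice indices."""

    def __init__(self, num_slices_per_volume: Dict[str, int]) -> None:
        self.volume_indices: Dict[str, range] = {}
        start = 0
        for filename, num_slices in num_slices_per_volume.items():
            self.volume_indices[filename] = range(start, start + num_slices)
            start += num_slices
        self._num_samples = start

    def __len__(self) -> int:
        return self._num_samples  # one dataset item per 2D slice


indexer = SliceIndexer({"vol_a.h5": 12, "vol_b.h5": 10})
print(len(indexer))                        # 22
print(indexer.volume_indices["vol_b.h5"])  # range(12, 22)
```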
54 changes: 27 additions & 27 deletions projects/calgary_campinas/configs/base_cirim.yaml
@@ -6,27 +6,31 @@ physics:
backward_operator: ifft2(centered=False)
training:
datasets:
# Two datasets, only difference is the shape, so the data can be collated for larger batches
# Two datasets, only difference is the shape, so the data can be collated for larger batches. R=5
- name: CalgaryCampinas
lists:
filenames_lists:
- ../lists/train/12x218x170_train.lst
transforms:
crop: null
estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS
scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum
image_center_crop: false
cropping:
crop: null
sensitivity_map_estimation:
estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS
normalization:
scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum
masking:
name: CalgaryCampinas
accelerations: [5, 10]
crop_outer_slices: true
- name: CalgaryCampinas
lists:
filenames_lists:
- ../lists/train/12x218x180_train.lst
transforms:
crop: null
estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS
scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum
image_center_crop: false
cropping:
crop: null
sensitivity_map_estimation:
estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS
normalization:
scaling_key: masked_kspace
masking:
name: CalgaryCampinas
accelerations: [5, 10]
@@ -57,19 +61,25 @@ validation:
# Twice the same dataset but a different acceleration factor
- name: CalgaryCampinas
transforms:
crop: null
estimate_sensitivity_maps: true
scaling_key: masked_kspace
cropping:
crop: null
sensitivity_map_estimation:
estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS
normalization:
scaling_key: masked_kspace
masking:
name: CalgaryCampinas
accelerations: [5]
crop_outer_slices: true
text_description: 5x # Description for logging
- name: CalgaryCampinas
transforms:
crop: null
estimate_sensitivity_maps: true
scaling_key: masked_kspace
cropping:
crop: null
sensitivity_map_estimation:
estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS
normalization:
scaling_key: masked_kspace
masking:
name: CalgaryCampinas
accelerations: [10]
@@ -98,13 +108,3 @@ additional_models:
logging:
tensorboard:
num_images: 4
inference:
batch_size: 8
dataset:
name: CalgaryCampinas
crop_outer_slices: true
text_description: inference
transforms:
crop: null
estimate_sensitivity_maps: true
scaling_key: masked_kspace
52 changes: 26 additions & 26 deletions projects/calgary_campinas/configs/base_conjgradnet.yaml
@@ -8,26 +8,30 @@ training:
filenames_lists:
- ../lists/train/12x218x170_train.lst
transforms:
crop: null
estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS
scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum
image_center_crop: false
cropping:
crop: null
sensitivity_map_estimation:
estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS
normalization:
scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum
masking:
name: CalgaryCampinas
accelerations: [5, 10]
crop_outer_slices: false
crop_outer_slices: true
- name: CalgaryCampinas
filenames_lists:
- ../lists/train/12x218x180_train.lst
transforms:
crop: null
estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS
scaling_key: masked_kspace # Compute the image normalization based on the masked_kspace maximum
image_center_crop: false
cropping:
crop: null
sensitivity_map_estimation:
estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS
normalization:
scaling_key: masked_kspace
masking:
name: CalgaryCampinas
accelerations: [5, 10]
crop_outer_slices: false
crop_outer_slices: true
batch_size: 2 # This is the batch size per GPU!
optimizer: Adam
lr: 0.001
@@ -54,19 +58,25 @@ validation:
# Twice the same dataset but a different acceleration factor
- name: CalgaryCampinas
transforms:
crop: null
estimate_sensitivity_maps: true
scaling_key: masked_kspace
cropping:
crop: null
sensitivity_map_estimation:
estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS
normalization:
scaling_key: masked_kspace
masking:
name: CalgaryCampinas
accelerations: [5]
crop_outer_slices: true
text_description: 5x # Description for logging
- name: CalgaryCampinas
transforms:
crop: null
estimate_sensitivity_maps: true
scaling_key: masked_kspace
cropping:
crop: null
sensitivity_map_estimation:
estimate_sensitivity_maps: true # Estimate the sensitivity map on the ACS
normalization:
scaling_key: masked_kspace
masking:
name: CalgaryCampinas
accelerations: [10]
@@ -100,13 +110,3 @@ additional_models:
logging:
tensorboard:
num_images: 4
inference:
batch_size: 8
dataset:
name: CalgaryCampinas
crop_outer_slices: true
text_description: inference
transforms:
crop: null
estimate_sensitivity_maps: true
scaling_key: masked_kspace