Skip to content

Commit

Permalink
Add documentation (#242)
Browse files Browse the repository at this point in the history
* Add instructions to create your model
  • Loading branch information
georgeyiasemis authored Mar 14, 2023
1 parent 3a5e2aa commit b380e2e
Show file tree
Hide file tree
Showing 7 changed files with 190 additions and 47 deletions.
3 changes: 2 additions & 1 deletion direct/algorithms/mri_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ def __init__(
super().__init__()

def calculate_sensitivity_map(self, acs_mask: torch.Tensor, kspace: torch.Tensor) -> torch.Tensor:
"""Calculates sensitivity map given as input the "acs_mask" and the "k-space".
"""Calculates sensitivity map given as input the `acs_mask` and the `k-space`.
Parameters
----------
acs_mask : torch.Tensor
Expand Down
73 changes: 38 additions & 35 deletions direct/data/mri_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ def __repr__(self):
repr_string = self.__class__.__name__ + "("
for transform in self.transforms:
repr_string += "\n"
repr_string += f" {transform}"
repr_string += "\n)"
repr_string += f" {transform},"
repr_string = repr_string[:-1] + "\n)"
return repr_string


Expand Down Expand Up @@ -576,13 +576,13 @@ def forward(self, sample: Dict[str, Any]) -> Dict[str, Any]:
Parameters
----------
sample: Dict[str, Any]
Contains key kspace_key with value a torch.Tensor of shape (coil, *spatial_dims, complex=2).
Contains key kspace_key with value a torch.Tensor of shape (coil,\*spatial_dims, complex=2).
Returns
----------
-------
sample: dict
Contains key target_key with value a torch.Tensor of shape (*spatial_dims) if `type_reconstruction` is
"rss", "complex_mod" or "sense_mod", and of shape(*spatial_dims, complex_dim=2) otherwise.
Contains key target_key with value a torch.Tensor of shape (\*spatial_dims) if `type_reconstruction` is
"rss", "complex_mod" or "sense_mod", and of shape(\*spatial_dims, complex_dim=2) otherwise.
"""
kspace_data = sample[self.kspace_key]

Expand Down Expand Up @@ -1187,6 +1187,9 @@ def __call__(self, sample):

return sample

def __repr__(self):
return self._transform.__repr__()

def __init__(self, module: Callable, toggle_dims: bool):
self._module = module
self.toggle_dims = toggle_dims
Expand All @@ -1206,7 +1209,7 @@ def __call__(self, *args, **kwargs):
WhitenData = ModuleWrapper(WhitenDataModule, toggle_dims=False)


class ToTensor:
class ToTensor(DirectTransform):
"""Transforms all np.array-like values in sample to torch.tensors."""

def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]:
Expand Down Expand Up @@ -1385,13 +1388,13 @@ def build_post_mri_transforms(
) -> object:
"""Build transforms for MRI.
- Converts input to (complex-valued) tensor.
- Adds a sampling mask if `mask_func` is defined.
- Adds coil sensitivities and / or the body coil_image
- Crops the input data if needed and masks the fully sampled k-space.
- Add a target.
- Normalize input data.
- Pads the coil dimension.
* Converts input to (complex-valued) tensor.
* Adds a sampling mask if `mask_func` is defined.
* Adds coil sensitivities and / or the body coil_image
* Crops the input data if needed and masks the fully sampled k-space.
* Add a target.
* Normalize input data.
* Pads the coil dimension.
Parameters
----------
Expand All @@ -1401,17 +1404,17 @@ def build_post_mri_transforms(
Estimate sensitivity maps using the acs region. Default: True.
sensitivity_maps_type: sensitivity_maps_type
Can be SensitivityMapType.rss_estimate, SensitivityMapType.unit or SensitivityMapType.espirit.
Will be ignored if `estimate_sensitivity_maps`==False. Default: SensitivityMapType.rss_estimate.
Will be ignored if `estimate_sensitivity_maps` is equal to False. Default: SensitivityMapType.rss_estimate.
sensitivity_maps_gaussian : float
Optional sigma for gaussian weighting of sensitivity map.
sensitivity_maps_espirit_threshold: float, optional
Threshold for the calibration matrix when `type_of_map`=="espirit". Default: 0.05.
Threshold for the calibration matrix when `type_of_map` is equal to "espirit". Default: 0.05.
sensitivity_maps_espirit_kernel_size: int, optional
Kernel size for the calibration matrix when `type_of_map`=="espirit". Default: 6.
Kernel size for the calibration matrix when `type_of_map` is equal to "espirit". Default: 6.
sensitivity_maps_espirit_crop: float, optional
Output eigenvalue cropping threshold when `type_of_map`=="espirit". Default: 0.95.
Output eigenvalue cropping threshold when `type_of_map` is equal to "espirit". Default: 0.95.
sensitivity_maps_espirit_max_iters: int, optional
Power method iterations when `type_of_map`=="espirit". Default: 30.
Power method iterations when `type_of_map` is equal to "espirit". Default: 30.
delete_acs_mask : bool
If True will delete key `acs_mask`. Default: True.
delete_kspace : bool
Expand Down Expand Up @@ -1508,13 +1511,13 @@ def build_mri_transforms(
) -> object:
"""Build transforms for MRI.
- Converts input to (complex-valued) tensor.
- Adds a sampling mask if `mask_func` is defined.
- Adds coil sensitivities and / or the body coil_image
- Crops the input data if needed and masks the fully sampled k-space.
- Add a target.
- Normalize input data.
- Pads the coil dimension.
* Converts input to (complex-valued) tensor.
* Adds a sampling mask if `mask_func` is defined.
* Adds coil sensitivities and / or the body coil_image
* Crops the input data if needed and masks the fully sampled k-space.
* Add a target.
* Normalize input data.
* Pads the coil dimension.
Parameters
----------
Expand Down Expand Up @@ -1555,17 +1558,17 @@ def build_mri_transforms(
Estimate sensitivity maps using the acs region. Default: True.
sensitivity_maps_type: sensitivity_maps_type
Can be SensitivityMapType.rss_estimate, SensitivityMapType.unit or SensitivityMapType.espirit.
Will be ignored if `estimate_sensitivity_maps`==False. Default: SensitivityMapType.rss_estimate.
Will be ignored if `estimate_sensitivity_maps` is False. Default: SensitivityMapType.rss_estimate.
sensitivity_maps_gaussian : float
Optional sigma for gaussian weighting of sensitivity map.
sensitivity_maps_espirit_threshold: float, optional
Threshold for the calibration matrix when `type_of_map`=="espirit". Default: 0.05.
sensitivity_maps_espirit_kernel_size: int, optional
Kernel size for the calibration matrix when `type_of_map`=="espirit". Default: 6.
sensitivity_maps_espirit_crop: float, optional
Output eigenvalue cropping threshold when `type_of_map`=="espirit". Default: 0.95.
sensitivity_maps_espirit_max_iters: int, optional
Power method iterations when `type_of_map`=="espirit". Default: 30.
sensitivity_maps_espirit_threshold : float, optional
Threshold for the calibration matrix when `type_of_map` is equal to "espirit". Default: 0.05.
sensitivity_maps_espirit_kernel_size : int, optional
Kernel size for the calibration matrix when `type_of_map` is equal to "espirit". Default: 6.
sensitivity_maps_espirit_crop : float, optional
Output eigenvalue cropping threshold when `type_of_map` is equal to "espirit". Default: 0.95.
sensitivity_maps_espirit_max_iters : int, optional
Power method iterations when `type_of_map` is equal to "espirit". Default: 30.
delete_acs_mask : bool
If True will delete key `acs_mask`. Default: True.
delete_kspace : bool
Expand Down
4 changes: 2 additions & 2 deletions direct/data/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -845,12 +845,12 @@ def crop_to_acs(acs_mask: torch.Tensor, kspace: torch.Tensor) -> torch.Tensor:
acs_mask : torch.Tensor
Autocalibration mask of shape (height, width).
kspace : torch.Tensor
K-space of shape (coil, height, width, *).
K-space of shape (coil, height, width, \*).
Returns
-------
torch.Tensor
Cropped k-space of shape (coil, height', width', *), where height' and width' are the new dimensions derived
Cropped k-space of shape (coil, height', width', \*), where height' and width' are the new dimensions derived
from the acs_mask.
"""
nonzero_idxs = torch.nonzero(acs_mask)
Expand Down
2 changes: 1 addition & 1 deletion direct/nn/jointicnet/jointicnet_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __init__(
mixed_precision: bool = False,
**models: nn.Module,
):
"""Inits :class:`JointICNetEngine."""
"""Inits :class:`JointICNetEngine`."""
super().__init__(
cfg,
model,
Expand Down
16 changes: 8 additions & 8 deletions direct/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,9 +370,9 @@ def __repr__(self):
elif isinstance(v, (dict, OrderedDict)):
repr_string += f"{k}=dict(len={len(v)}), "
elif isinstance(v, list):
repr_string = f"{k}=list(len={len(v)}), "
repr_string += f"{k}=list(len={len(v)}), "
elif isinstance(v, tuple):
repr_string = f"{k}=tuple(len={len(v)}), "
repr_string += f"{k}=tuple(len={len(v)}), "
else:
repr_string += str(v) + ", "

Expand Down Expand Up @@ -494,12 +494,12 @@ def remove_keys(input_dict: Dict, keys: Union[str, List[str], Tuple[str]]) -> Di
def dict_flatten(in_dict: DictOrDictConfig, dict_out: Optional[DictOrDictConfig] = None) -> Dict[str, Any]:
"""Flattens a nested dictionary (or DictConfig) and returns a new flattened dictionary.
If a dict_out is provided, the flattened dictionary will be added to it.
If a `dict_out` is provided, the flattened dictionary will be added to it.
Parameters
----------
in_dict : DictOrDictConfig
The nested dictionary or DictConfig to flatten
The nested dictionary or DictConfig to flatten.
dict_out : Optional[DictOrDictConfig], optional
An existing dictionary to add the flattened dictionary to. Default: None.
Expand All @@ -508,12 +508,12 @@ def dict_flatten(in_dict: DictOrDictConfig, dict_out: Optional[DictOrDictConfig]
Dict[str, Any]
The flattened dictionary.
Note
----
Notes
-----
* This function only keeps the final keys, and discards the intermediate ones.
Example
-------
Examples
--------
>>> dictA = {"a": 1, "b": {"c": 2, "d": 3, "e": {"f": 4, 6: "a", 5: {"g": 6}, "l": [1, "two"]}}}
>>> dict_flatten(dictA)
{'a': 1, 'c': 2, 'd': 3, 'f': 4, 6: 'a', 'g': 6, 'l': [1, 'two']}
Expand Down
6 changes: 6 additions & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ images such as MRIs from partially observed or noisy input data.
datasets
samplers

.. toctree::
:maxdepth: 1
:caption: Add more Models

models

.. toctree::
:maxdepth: 1
:caption: Examples
Expand Down
133 changes: 133 additions & 0 deletions docs/models.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
.. highlight:: shell

=====================
Adding your own model
=====================

To add a new model follow the steps below:

- Implement your custom model under :code:`direct/nn/<model_name>/<model_name>.py`. For example:

.. code-block:: python
import torch
from torch import nn
from torch.nn import functional as F
class MyMRIModel(nn.Module):
"""My custom MRI model."""
def __init__(self, param1: param1_type, ...):
"""Inits :class:`MyMRIModel`.
Parameters
----------
param1 : param1_type
...
...
"""
super().__init__()
def my_method(self, ...) -> ...:
pass
@staticmethod
def my_static_method(...) -> ...:
pass
def forward(
self,
masked_kspace: torch.Tensor,
sampling_mask: torch.Tensor,
sensitivity_map: torch.Tensor,
...
) -> torch.Tensor:
"""Computes forward pass of :class:`MyMRIModel`.
Parameters
----------
masked_kspace: torch.Tensor
Masked k-space of shape (N, coil, height, width, complex=2).
sampling_mask: torch.Tensor
Sampling mask of shape (N, 1, height, width, 1).
sensitivity_map: torch.Tensor
Sensitivity map of shape (N, coil, height, width, complex=2).
...
Returns
-------
out_image: torch.Tensor
Output image of shape (N, height, width, complex=2).
...
"""
- Implement your custom model's engine under :code:`direct/nn/<model_name>/<model_name>_engine.py`. For example:

.. code-block:: python
from __future__ import annotations
from typing import Any, Callable, Dict, Optional, Tuple
import torch
from torch import nn
from direct.config import BaseConfig
from direct.nn.mri_models import MRIModelEngine
class MyMRIModelEngine(MRIModelEngine):
""":class:`MyMRIModel` Engine."""
def __init__(
self,
cfg: BaseConfig,
model: nn.Module,
device: str,
forward_operator: Optional[Callable] = None,
backward_operator: Optional[Callable] = None,
mixed_precision: bool = False,
**models: nn.Module,
):
"""Inits :class:`MyMRIModel`."""
super().__init__(
cfg,
model,
device,
forward_operator=forward_operator,
backward_operator=backward_operator,
mixed_precision=mixed_precision,
**models,
)
def forward_function(self, data: Dict[str, Any]) -> Tuple[torch.Tensor | None, torch.Tensor | None]:
output_image = self.model(
masked_kspace=data["masked_kspace"],
sampling_mask=data["sampling_mask"],
sensitivity_map=data["sensitivity_map"],
...=...
)
# ΟR
output_kspace = self.model(
masked_kspace=data["masked_kspace"],
sampling_mask=data["sampling_mask"],
sensitivity_map=data["sensitivity_map"],
...=...
)
...
return output_image, output_kspace
- Implement your custom model's config under :code:`direct/nn/<model_name>/config.py`. For example:

.. code-block:: python
from dataclasses import dataclass
from direct.config.defaults import ModelConfig
@dataclass
class MyMRIModelConfig(ModelConfig):
param1: param1_type = param1_default_value
...

0 comments on commit b380e2e

Please sign in to comment.