From d472bdd5459c7e6adcfb263c84caed90dcfaf36c Mon Sep 17 00:00:00 2001 From: Jonas Teuwen <2347927+jonasteuwen@users.noreply.github.com> Date: Fri, 7 Jan 2022 15:24:26 +0100 Subject: [PATCH] Release of version 1.0.0 (#163, Closes #152) * Update documentation * Update packages * Removed unused imports * Convert `sewar` to optional dependency * Add logo * Add warning this software is not meant for medical use. * Bump version -> v1.0.0 --- .github/workflows/pylint.yml | 2 +- CODE_OF_CONDUCT.md => CODE_OF_CONDUCT.rst | 34 +- README.md | 49 - README.rst | 63 + authors.md | 9 - authors.rst | 13 + contributing.md => contributing.rst | 88 +- direct/__init__.py | 5 +- direct/checkpointer.py | 21 +- direct/cli/predict.py | 14 +- direct/common/subsample.py | 115 +- direct/common/subsample_config.py | 2 +- direct/config/defaults.py | 2 +- direct/data/bbox.py | 10 +- direct/data/datasets.py | 32 +- direct/data/fake.py | 14 +- direct/data/h5_data.py | 21 +- direct/data/lr_scheduler.py | 14 +- direct/data/mri_transforms.py | 63 +- direct/data/samplers.py | 6 +- direct/data/transforms.py | 230 +- direct/engine.py | 4 +- direct/environment.py | 28 +- direct/exceptions.py | 2 +- direct/functionals/psnr.py | 6 +- direct/functionals/ssim.py | 12 +- direct/inference.py | 30 +- direct/launch.py | 24 +- direct/nn/conv/conv.py | 12 +- direct/nn/crossdomain/crossdomain.py | 28 +- direct/nn/crossdomain/multicoil.py | 10 +- direct/nn/didn/didn.py | 34 +- direct/nn/jointicnet/jointicnet.py | 23 +- direct/nn/jointicnet/jointicnet_engine.py | 50 +- direct/nn/kikinet/kikinet.py | 28 +- direct/nn/kikinet/kikinet_engine.py | 50 +- direct/nn/lpd/lpd.py | 27 +- direct/nn/lpd/lpd_engine.py | 50 +- direct/nn/mobilenet/mobilenet.py | 16 +- direct/nn/multidomainnet/config.py | 1 - direct/nn/multidomainnet/multidomain.py | 34 +- direct/nn/multidomainnet/multidomainnet.py | 29 +- .../multidomainnet/multidomainnet_engine.py | 50 +- direct/nn/mwcnn/mwcnn.py | 31 +- direct/nn/recurrent/recurrent.py | 22 +- direct/nn/recurrentvarnet/recurrentvarnet.py | 68 +- .../recurrentvarnet/recurrentvarnet_engine.py | 50 +- direct/nn/rim/rim.py | 47 +- direct/nn/rim/rim_engine.py | 119 +- direct/nn/unet/unet_2d.py | 64 +- direct/nn/unet/unet_engine.py | 42 +- direct/nn/varnet/varnet.py | 45 +- direct/nn/varnet/varnet_engine.py | 32 +- direct/nn/xpdnet/xpdnet.py | 23 +- direct/nn/xpdnet/xpdnet_engine.py | 50 +- direct/predict.py | 1 - direct/train.py | 2 +- direct/utils/__init__.py | 66 +- direct/utils/asserts.py | 10 +- direct/utils/bbox.py | 10 +- direct/utils/communication.py | 16 +- direct/utils/dataset.py | 6 +- direct/utils/events.py | 61 +- direct/utils/io.py | 62 +- direct/utils/logging.py | 8 +- direct/utils/models.py | 2 +- direct/utils/writers.py | 10 +- docs/authors.rst | 1 + docs/conf.py | 4 +- docs/config.rst | 111 + docs/datasets.rst | 117 + docs/getting_started.rst | 30 + docs/history.rst | 2 + docs/index.rst | 48 +- docs/inference.rst | 37 + docs/installation.rst | 52 +- docs/model_zoo.rst | 175 + docs/samplers.rst | 91 + docs/training.rst | 30 + getting_started.md | 37 - install.md | 42 - installation.rst | 62 + logo/direct_logo_horizontal.svg | 4589 +++++++++++++++++ logo/direct_logo_square.svg | 1260 +++++ model_zoo.md | 44 - .../spie_radial_subsampling/plot_zoomed.py | 66 +- setup.cfg | 6 +- setup.py | 21 +- 88 files changed, 7721 insertions(+), 1306 deletions(-) rename CODE_OF_CONDUCT.md => CODE_OF_CONDUCT.rst (87%) delete mode 100644 README.md create mode 100644 README.rst delete mode 100644 authors.md create mode 100644 
authors.rst rename contributing.md => contributing.rst (65%) create mode 100644 docs/authors.rst create mode 100644 docs/config.rst create mode 100644 docs/datasets.rst create mode 100644 docs/getting_started.rst create mode 100644 docs/history.rst create mode 100644 docs/inference.rst create mode 100644 docs/model_zoo.rst create mode 100644 docs/samplers.rst create mode 100644 docs/training.rst delete mode 100644 getting_started.md delete mode 100644 install.md create mode 100644 installation.rst create mode 100644 logo/direct_logo_horizontal.svg create mode 100644 logo/direct_logo_square.svg delete mode 100644 model_zoo.md diff --git a/.github/workflows/pylint.yml b/.github/workflows/pylint.yml index 8c900f99..bc053f3c 100644 --- a/.github/workflows/pylint.yml +++ b/.github/workflows/pylint.yml @@ -14,7 +14,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install pylint + pip install pylint sewar pip install -e ".[dev]" - name: Analysing the code with pylint run: | diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.rst similarity index 87% rename from CODE_OF_CONDUCT.md rename to CODE_OF_CONDUCT.rst index a8f797f2..5b34f5a3 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.rst @@ -1,6 +1,9 @@ -# Contributor Covenant Code of Conduct -## Our Pledge +Contributor Covenant Code of Conduct +==================================== + +Our Pledge +---------- In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to making participation in our project and @@ -9,11 +12,13 @@ size, disability, ethnicity, sex characteristics, gender identity and expression level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation. -## Our Standards +Our Standards +------------- Examples of behavior that contributes to creating a positive environment include: + * Using welcoming and inclusive language * Being respectful of differing viewpoints and experiences * Gracefully accepting constructive criticism @@ -22,16 +27,18 @@ include: Examples of unacceptable behavior by participants include: + * The use of sexualized language or imagery and unwelcome sexual attention or - advances + advances * Trolling, insulting/derogatory comments, and personal or political attacks * Public or private harassment * Publishing others' private information, such as a physical or electronic - address, without explicit permission + address, without explicit permission * Other conduct which could reasonably be considered inappropriate in a - professional setting + professional setting -## Our Responsibilities +Our Responsibilities +-------------------- Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in @@ -43,7 +50,8 @@ that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful. -## Scope +Scope +----- This Code of Conduct applies both within project spaces and in public spaces when an individual is representing the project or its community. Examples of @@ -52,7 +60,8 @@ address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers. 
-## Enforcement +Enforcement +----------- Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at j.teuwen@nki.nl. All @@ -65,12 +74,11 @@ Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project's leadership. -## Attribution +Attribution +----------- -This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, +This Code of Conduct is adapted from the `Contributor Covenant `_\ , version 1.4, available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html -[homepage]: https://www.contributor-covenant.org - For answers to common questions about this code of conduct, see https://www.contributor-covenant.org/faq diff --git a/README.md b/README.md deleted file mode 100644 index fd948ba3..00000000 --- a/README.md +++ /dev/null @@ -1,49 +0,0 @@ -[![tox](https://github.com/NKI-AI/direct/actions/workflows/tox.yml/badge.svg)](https://github.com/NKI-AI/direct/actions/workflows/tox.yml) -[![pylint](https://github.com/NKI-AI/direct/actions/workflows/pylint.yml/badge.svg)](https://github.com/NKI-AI/direct/actions/workflows/pylint.yml) -[![black](https://github.com/NKI-AI/direct/actions/workflows/black.yml/badge.svg)](https://github.com/NKI-AI/direct/actions/workflows/black.yml) -[![codecov](https://codecov.io/gh/NKI-AI/direct/branch/main/graph/badge.svg?token=STYAUFCKJY)](https://codecov.io/gh/NKI-AI/direct) - -# DIRECT: Deep Image REConstruction Toolkit -`DIRECT` is a Python, end-to-end pipeline for solving Inverse Problems emerging in Imaging Processing. It is built with PyTorch and stores state-of-the-art Deep Learning imaging inverse problem solvers such as denoising, dealiasing and reconstruction. By defining a base forward linear or non-linear operator, `DIRECT` can be used for training models for recovering images such as MRIs from partially observed or noisy input data. - -`DIRECT` stores inverse problem solvers such as the Learned Primal Dual algorithm, Recurrent Inference Machine and Recurrent Variational Network, which were part of the winning solution in Facebook & NYUs FastMRI challenge in 2019 and the Calgary-Campinas MRI reconstruction challenge at MIDL 2020. For a full list of the baselines currently implemented in DIRECT see [here](#baselines-and-trained-models). - -
- -
- -## Installation -See [install.md](install.md). - -## Quick Start -See [getting_started.md](getting_started.md), check out the [documentation](https://docs.aiforoncology.nl/direct). -In the [projects](projects) folder examples are given on how to train models on public datasets. - -## Baselines and trained models -- [Recurrent Variational Network (RecurrentVarNet)](https://arxiv.org/abs/2111.09639) -- [Recurrent Inference Machine (RIM)](https://www.sciencedirect.com/science/article/abs/pii/S1361841518306078) -- [End-to-end Variational Network (VarNet)](https://arxiv.org/pdf/2004.06688.pdf) -- [Learned Primal Dual Network (LDPNet)](https://arxiv.org/abs/1707.06474) -- [X-Primal Dual Network (XPDNet)](https://arxiv.org/abs/2010.07290) -- [KIKI-Net](https://pubmed.ncbi.nlm.nih.gov/29624729/) -- [U-Net](https://arxiv.org/abs/1811.08839) -- [Joint-ICNet](https://openaccess.thecvf.com/content/CVPR2021/papers/Jun_Joint_Deep_Model-Based_MR_Image_and_Coil_Sensitivity_Reconstruction_Network_CVPR_2021_paper.pdf) -- [AIRS Medical fastmri model (MultiDomainNet)](https://arxiv.org/pdf/2012.06318.pdf) - -We provide a set of baseline results and trained models in the [DIRECT Model Zoo](model_zoo.md). - -## License -DIRECT is released under the [Apache 2.0 License](LICENSE). - -## Citing DIRECT -If you use DIRECT in your own research, or want to refer to baseline results published in the - [DIRECT Model Zoo](model_zoo.md), please use the following BiBTeX entry: - -```BibTeX -@misc{DIRECTTOOLKIT, - author = {Yiasemis, George and Moriakov, Nikita and Karkalousos, Dimitrios and Caan, Matthan and Teuwen, Jonas}, - title = {DIRECT: Deep Image REConstruction Toolkit}, - howpublished = {\url{https://github.com/NKI-AI/direct}}, - year = {2021} -} -``` diff --git a/README.rst b/README.rst new file mode 100644 index 00000000..81e6ac97 --- /dev/null +++ b/README.rst @@ -0,0 +1,63 @@ +.. image:: https://github.com/NKI-AI/direct/actions/workflows/tox.yml/badge.svg + :target: https://github.com/NKI-AI/direct/actions/workflows/tox.yml + :alt: tox + +.. image:: https://github.com/NKI-AI/direct/actions/workflows/pylint.yml/badge.svg + :target: https://github.com/NKI-AI/direct/actions/workflows/pylint.yml + :alt: pylint + +.. image:: https://github.com/NKI-AI/direct/actions/workflows/black.yml/badge.svg + :target: https://github.com/NKI-AI/direct/actions/workflows/black.yml + :alt: black + +.. image:: https://codecov.io/gh/NKI-AI/direct/branch/main/graph/badge.svg?token=STYAUFCKJY + :target: https://codecov.io/gh/NKI-AI/direct + :alt: codecov + + +DIRECT: Deep Image REConstruction Toolkit +========================================= + +``DIRECT`` is a Python, end-to-end pipeline for solving Inverse Problems emerging in Imaging Processing. It is built with PyTorch and stores state-of-the-art Deep Learning imaging inverse problem solvers such as denoising, dealiasing and reconstruction. By defining a base forward linear or non-linear operator, ``DIRECT`` can be used for training models for recovering images such as MRIs from partially observed or noisy input data. +``DIRECT`` stores inverse problem solvers such as the Learned Primal Dual algorithm, Recurrent Inference Machine and Recurrent Variational Network, which were part of the winning solution in Facebook & NYUs FastMRI challenge in 2019 and the Calgary-Campinas MRI reconstruction challenge at MIDL 2020. For a full list of the baselines currently implemented in DIRECT see `here <#baselines-and-trained-models>`_. 
+ +.. raw:: html + +
+ +
+ + + +Installation and Quick Start +---------------------------- + +Check out the `documentation `_ for installation and a quick start. + +Projects +-------- +In the `projects `_ folder baseline model configurations are provided for each project. + +Baselines and trained models +---------------------------- + +We provide a set of baseline results and trained models in the `DIRECT Model Zoo `_. Baselines and trained models include the `Recurrent Variational Network (RecurrentVarNet) `_, the `Recurrent Inference Machine (RIM) `_, the `End-to-end Variational Network (VarNet) `_, the `Learned Primal Dual Network (LDPNet) `_, the `X-Primal Dual Network (XPDNet) `_, the `KIKI-Net `_, the `U-Net `_, the `Joint-ICNet `_, and the `AIRS Medical fastmri model (MultiDomainNet) `_. + +License and usage +----------------- + +DIRECT is not intended for clinical use. DIRECT is released under the `Apache 2.0 License `_. + +Citing DIRECT +------------- + +If you use DIRECT in your own research, or want to refer to baseline results published in the `DIRECT Model Zoo `_\ , please use the following BiBTeX entry: + +.. code-block:: BibTeX + + @misc{DIRECTTOOLKIT, + author = {Yiasemis, George and Moriakov, Nikita and Karkalousos, Dimitrios and Caan, Matthan and Teuwen, Jonas}, + title = {DIRECT: Deep Image REConstruction Toolkit}, + howpublished = {\url{https://github.com/NKI-AI/direct}}, + year = {2021} + } diff --git a/authors.md b/authors.md deleted file mode 100644 index a27ee4b3..00000000 --- a/authors.md +++ /dev/null @@ -1,9 +0,0 @@ -# Credits -## Development Lead -* Jonas Teuwen -* George Yiasemis - -## Contributors -* Nikita Moriakov -* Dimitrios Karkalousos -* Matthan W.A. Caan diff --git a/authors.rst b/authors.rst new file mode 100644 index 00000000..cde98004 --- /dev/null +++ b/authors.rst @@ -0,0 +1,13 @@ +Credits +======= + +Development Lead +---------------- +* Jonas Teuwen j.teuwen@nki.nl +* George Yiasemis g.yiasemis@nki.nl + +Contributors +------------ +* Nikita Moriakov n.moriakov@nki.nl +* Dimitrios Karkalousos d.karkalousos@amsterdamumc.nl +* Matthan W.A. Caan m.w.a.caan@amsterdamumc.nl diff --git a/contributing.md b/contributing.rst similarity index 65% rename from contributing.md rename to contributing.rst index 16d35491..c44ad74b 100644 --- a/contributing.md +++ b/contributing.rst @@ -1,68 +1,88 @@ -# Contributing + +Contributing +============ + Contributions are welcome, and they are greatly appreciated! Every little bit helps, and credit will always be given. You can contribute in many ways: -## Types of Contributions -### Report Bugs +Types of Contributions +---------------------- + +Report Bugs +^^^^^^^^^^^ + Report bugs at https://github.com/NKI-AI/direct/issues. If you are reporting a bug, please include: + * Your operating system name and version. * Any details about your local setup that might be helpful in troubleshooting. * Detailed steps to reproduce the bug. +Fix Bugs +^^^^^^^^ -### Fix Bugs Look through the GitHub issues for bugs. Anything tagged with "bug" and "help wanted" is open to whoever wants to implement it. +Implement Features +^^^^^^^^^^^^^^^^^^ -### Implement Features Look through the GitHub issues for features. Anything tagged with "enhancement" and "help wanted" is open to whoever wants to implement it. +Write Documentation +^^^^^^^^^^^^^^^^^^^ -### Write Documentation DIRECT could always use more documentation, whether as part of the official DIRECT docs, in docstrings, or even on the web in blog posts, articles, and such. 
-### Submit Feedback +Submit Feedback +^^^^^^^^^^^^^^^ + The best way to send feedback is to file an issue at https://github.com/NKI-AI/direct/issues. If you are proposing a feature: + * Explain in detail how it would work. * Keep the scope as narrow as possible, to make it easier to implement. * Remember that this is a volunteer-driven project, and that contributions are welcome :) -### Get Started! +Get Started! +^^^^^^^^^^^^ + +Ready to contribute? Here's how to set up ``direct`` for local development. -Ready to contribute? Here's how to set up `direct` for local development. -1. Fork the `direct` repo on GitHub. -2. Clone your fork locally:: +#. Fork the ``direct`` repo on GitHub. +#. + Clone your fork locally: $ git clone git@github.com:your_name_here/direct.git -3. Install your local copy into a virtualenv. Assuming you have virtualenvwrapper installed, this is how you set up your fork for local development:: +#. + Install your local copy into a virtualenv. Assuming you have virtualenvwrapper installed, this is how you set up your fork for local development: $ mkvirtualenv direct $ cd direct/ $ python setup.py develop -4. Create a branch for local development:: +#. + Create a branch for local development: $ git checkout -b name-of-your-bugfix-or-feature Now you can make your changes locally. -5. When you're done making changes, check that your changes pass flake8 and the - tests, including testing other Python versions with tox:: +#. + When you're done making changes, check that your changes pass flake8 and the + tests, including testing other Python versions with tox: $ flake8 direct tests $ python setup.py test or pytest @@ -70,37 +90,45 @@ Ready to contribute? Here's how to set up `direct` for local development. To get flake8 and tox, just pip install them into your virtualenv. -6. Commit your changes and push your branch to GitHub:: +#. + Commit your changes and push your branch to GitHub: $ git add . $ git commit -m "Your detailed description of your changes." $ git push origin name-of-your-bugfix-or-feature -7. Submit a pull request through the GitHub website. +#. + Submit a pull request through the GitHub website. + +Pull Request Guidelines +^^^^^^^^^^^^^^^^^^^^^^^ -### Pull Request Guidelines Before you submit a pull request, check that it meets these guidelines: -1. The pull request should include tests. -2. If the pull request adds functionality, the docs should be updated. Put + +#. The pull request should include tests. +#. If the pull request adds functionality, the docs should be updated. Put your new functionality into a function with a docstring, and add the feature to the list in README.rst. -3. The pull request should work for Python 3.8 and 3.9 and for PyPy. Check Github actions and see that all tests pass. +#. The pull request should work for Python 3.8 and 3.9 and for PyPy. Check Github actions and see that all tests pass. + +Tests +^^^^^ -### Tests To run tests: -`pytest` +``pytest`` +Deploying +^^^^^^^^^ -### Deploying A reminder for the maintainers on how to deploy. Make sure all your changes are committed (including an entry in HISTORY.md). -Then run:: +Then run: + +.. code-block:: -``` -bump2version patch # possible: major / minor / patch -git push -git push --tags -``` + bump2version patch # possible: major / minor / patch + git push + git push --tags Travis will then deploy to PyPI if tests pass. 
diff --git a/direct/__init__.py b/direct/__init__.py index c39afd1b..7c46d240 100644 --- a/direct/__init__.py +++ b/direct/__init__.py @@ -1,6 +1,5 @@ # coding=utf-8 # Copyright (c) DIRECT Contributors -__author__ = """Jonas Teuwen""" -__email__ = "j.teuwen@nki.nl" -__version__ = "1.0.0-dev0" +__author__ = """direct contributors""" +__version__ = "1.0.0" diff --git a/direct/checkpointer.py b/direct/checkpointer.py index 44d1cfd4..760dc17a 100644 --- a/direct/checkpointer.py +++ b/direct/checkpointer.py @@ -83,15 +83,16 @@ def load( last_model_text_path = self.save_directory / "last_model.txt" self.logger.info("Attempting to load latest model.") if last_model_text_path.exists(): - with open(pathlib.Path(last_model_text_path), "r") as f: + with open(pathlib.Path(last_model_text_path), "r", encoding="utf-8") as f: iteration = int(f.readline()) - self.logger.info(f"Loading last saved iteration: {iteration}.") + self.logger.info("Loading last saved iteration: %s", iteration) else: self.logger.info( - f"Latest model not found. Perhaps `last_model.txt` (path = {last_model_text_path}) " - f"is missing? You can try to set an explicit iteration number, or create this file if " - f"you believe this is an error. Will not load any model." + "Latest model not found. Perhaps `last_model.txt` (path = %s) " + "is missing? You can try to set an explicit iteration number, or create this file if " + "you believe this is an error. Will not load any model.", + last_model_text_path, ) return {} @@ -114,11 +115,11 @@ def load_from_path( Parameters ---------- - checkpoint_path : Path or str + checkpoint_path: Path or str Path to checkpoint, either a path to a file or a path to a URL where the file can be downloaded - checkpointable_objects : dict + checkpointable_objects: dict Dictionary mapping names to nn.Module's - only_models : bool + only_models: bool If true will only load the models and no other objects in the checkpoint Returns @@ -194,7 +195,7 @@ def save(self, iteration: int, **kwargs: Dict[str, str]) -> None: torch.save(data, f) # noinspection PyTypeChecker - with open(self.save_directory / "last_model.txt", "w") as f: # type: ignore + with open(self.save_directory / "last_model.txt", "w", encoding="utf-8") as f: # type: ignore f.write(str(iteration)) # type: ignore def _load_checkpoint(self, checkpoint_path: PathOrString) -> Dict: @@ -203,7 +204,7 @@ def _load_checkpoint(self, checkpoint_path: PathOrString) -> Dict: Parameters ---------- - checkpoint_path : Path or str + checkpoint_path: Path or str Path to checkpoint, either a path to a file or a path to a URL where the file can be downloaded Returns ------- diff --git a/direct/cli/predict.py b/direct/cli/predict.py index 137f855c..a055e819 100644 --- a/direct/cli/predict.py +++ b/direct/cli/predict.py @@ -12,18 +12,18 @@ def register_parser(parser: argparse._SubParsersAction): """Register wsi commands to a root parser.""" - epilog = f""" + epilog = """ Examples: --------- Run on single machine: - $ direct predict --cfg --checkpoint \ - --num-gpus [ --cfg .yaml --other-flags ] + $ direct predict --cfg --checkpoint + --num-gpus [--other-flag-args ] Run on multiple machines: - (machine0)$ direct predict --cfg --checkpoint \ - --machine-rank 0 --num-machines 2 --dist-url [--other-flags] - (machine1)$ direct predict --cfg --checkpoint \ - --machine-rank 1 --num-machines 2 --dist-url [--other-flags] + (machine0)$ direct predict --cfg --checkpoint + --machine-rank 0 --num-machines 2 --dist-url [--other-flag-args ] + (machine1)$ direct predict --cfg 
--checkpoint + --machine-rank 1 --num-machines 2 --dist-url [--other-flag-args ] """ common_parser = Args(add_help=False) predict_parser = parser.add_parser( diff --git a/direct/common/subsample.py b/direct/common/subsample.py index 91b90acb..f9b8afe2 100644 --- a/direct/common/subsample.py +++ b/direct/common/subsample.py @@ -48,14 +48,14 @@ def __init__( """ Parameters ---------- - center_fractions : List([float]) + center_fractions: List([float]) Fraction of low-frequency columns to be retained. If multiple values are provided, then one of these numbers is chosen uniformly each time. If uniform_range is True, then two values should be given. - accelerations : List([int]) + accelerations: List([int]) Amount of under-sampling_mask. An acceleration of 4 retains 25% of the k-space, the method is given by mask_type. Has to be the same length as center_fractions if uniform_range is True. - uniform_range : bool + uniform_range: bool If True then an acceleration will be uniformly sampled between the two values. """ if center_fractions is not None: @@ -94,11 +94,11 @@ def __call__(self, data, seed=None, return_acs=False): """ Parameters ---------- - data : object - seed : int (optional) + data: object + seed: int (optional) Seed for the random number generator. Setting the seed ensures the same mask is generated each time for the same shape. - return_acs : bool + return_acs: bool Return the autocalibration signal region as a mask. Returns @@ -125,16 +125,13 @@ def __init__( def mask_func(self, shape, return_acs=False, seed=None): """ - Create vertical line mask. - Code from: https://github.com/facebookresearch/fastMRI/blob/master/common/subsample.py + Creates vertical line mask. The mask selects a subset of columns from the input k-space data. If the k-space data has N columns, the mask picks out: - 1. N_low_freqs = (N * center_fraction) columns in the center corresponding to - low-frequencies - 2. The other columns are selected uniformly at random with a probability equal to: - prob = (N / acceleration - N_low_freqs) / (N - N_low_freqs). - This ensures that the expected number of columns selected is equal to (N / acceleration) + + #. N_low_freqs = (N * center_fraction) columns in the center corresponding to low-frequencies. + #. The other columns are selected uniformly at random with a probability equal to: prob = (N / acceleration - N_low_freqs) / (N - N_low_freqs). This ensures that the expected number of columns selected is equal to (N / acceleration) It is possible to use multiple center_fractions and accelerations, in which case one possible (center_fraction, acceleration) is chosen uniformly at random each time the MaskFunc object is @@ -146,19 +143,19 @@ def mask_func(self, shape, return_acs=False, seed=None): Parameters ---------- - - shape : iterable[int] + shape: iterable[int] The shape of the mask to be created. The shape should at least 3 dimensions. Samples are drawn along the second last dimension. - seed : int (optional) + seed: int (optional) Seed for the random number generator. Setting the seed ensures the same mask is generated each time for the same shape. - return_acs : bool + return_acs: bool Return the autocalibration signal region as a mask. Returns ------- - torch.Tensor : the sampling mask + mask: torch.Tensor + The sampling mask. """ if len(shape) < 3: @@ -208,42 +205,36 @@ def __init__( def mask_func(self, shape, return_acs=False, seed=None): """ - Create equispaced vertical line mask. 
- Code from: https://github.com/facebookresearch/fastMRI/blob/master/common/subsample.py - - FastMRIEquispacedMaskFunc creates a sub-sampling mask of a given shape. - The mask selects a subset of columns from the input k-space data. If the - k-space data has N columns, the mask picks out: - 1. N_low_freqs = (N * center_fraction) columns in the center - corresponding tovlow-frequencies. - 2. The other columns are selected with equal spacing at a proportion - that reaches the desired acceleration rate taking into consideration - the number of low frequencies. This ensures that the expected number - of columns selected is equal to (N / acceleration) - It is possible to use multiple center_fractions and accelerations, in which - case one possible (center_fraction, acceleration) is chosen uniformly at - random each time the EquispacedMaskFunc object is called. - - Note that this function may not give equispaced samples (documented in - https://github.com/facebookresearch/fastMRI/issues/54), which will require - modifications to standard GRAPPA approaches. Nonetheless, this aspect of - the function has been preserved to match the public multicoil data. + Creates equispaced vertical line mask. + + FastMRIEquispacedMaskFunc creates a sub-sampling mask of a given shape. The mask selects a subset of columns + from the input k-space data. If the k-space data has N columns, the mask picks out: + + #. N_low_freqs = (N * center_fraction) columns in the center corresponding to low-frequencies. + #. The other columns are selected with equal spacing at a proportion that reaches the desired acceleration rate taking into consideration the number of low frequencies. This ensures that the expected number of columns selected is equal to (N / acceleration). + + It is possible to use multiple center_fractions and accelerations, in which case one possible + (center_fraction, acceleration) is chosen uniformly at random each time the EquispacedMaskFunc object is called. + + Note that this function may not give equispaced samples (documented in https://github.com/facebookresearch/fastMRI/issues/54), + which will require modifications to standard GRAPPA approaches. Nonetheless, this aspect of the function has + been preserved to match the public multicoil data. Parameters ---------- - - shape : iterable[int] + shape: iterable[int] The shape of the mask to be created. The shape should at least 3 dimensions. Samples are drawn along the second last dimension. - seed : int (optional) + seed: int (optional) Seed for the random number generator. Setting the seed ensures the same mask is generated each time for the same shape. - return_acs : bool + return_acs: bool Return the autocalibration signal region as a mask. Returns ------- - torch.Tensor : the sampling mask + mask: torch.Tensor + The sampling mask. """ if len(shape) < 3: @@ -324,18 +315,19 @@ def mask_func(self, shape, return_acs=False, seed=None): Parameters ---------- - shape : iterable[int] + shape: iterable[int] The shape of the mask to be created. The shape should at least 3 dimensions. Samples are drawn along the second last dimension. - seed : int (optional) + seed: int (optional) Seed for the random number generator. Setting the seed ensures the same mask is generated each time for the same shape. - return_acs : bool + return_acs: bool Return the autocalibration signal region as a mask. Returns ------- - torch.Tensor : the sampling mask + mask: torch.Tensor + The sampling mask. 
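The two column-selection rules described above (random columns kept with probability ``prob`` and equispaced columns around a centred ACS block) can be written down compactly. A minimal NumPy sketch under illustrative conventions; the function names, the centre padding, and the zero offset in the equispaced case are assumptions, not DIRECT's exact implementation:

.. code-block:: python

    import numpy as np

    def random_column_mask(num_cols, center_fraction, acceleration, seed=None):
        """Centred ACS block plus uniformly random columns kept with probability `prob`."""
        rng = np.random.RandomState(seed)
        num_low_freqs = int(round(num_cols * center_fraction))
        # Chosen so that the expected number of sampled columns is num_cols / acceleration.
        prob = (num_cols / acceleration - num_low_freqs) / (num_cols - num_low_freqs)
        mask = rng.uniform(size=num_cols) < prob
        pad = (num_cols - num_low_freqs + 1) // 2
        mask[pad : pad + num_low_freqs] = True
        return mask

    def equispaced_column_mask(num_cols, center_fraction, acceleration):
        """Centred ACS block plus (approximately) equispaced columns."""
        num_low_freqs = int(round(num_cols * center_fraction))
        mask = np.zeros(num_cols, dtype=bool)
        pad = (num_cols - num_low_freqs + 1) // 2
        mask[pad : pad + num_low_freqs] = True
        # Spacing compensates for the columns already covered by the ACS block.
        spacing = (num_cols - num_low_freqs) * acceleration / (num_cols - num_low_freqs * acceleration)
        mask[np.around(np.arange(0, num_cols - 1, spacing)).astype(int)] = True
        return mask

    mask = random_column_mask(320, center_fraction=0.08, acceleration=4, seed=0)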
""" shape = tuple(shape)[:-1] @@ -391,10 +383,7 @@ class CIRCUSMaskFunc(BaseMaskFunc): References ---------- - .. [1] Liu J, Saloner D. Accelerated MRI with CIRcular Cartesian UnderSampling (CIRCUS): - a variable density Cartesian sampling strategy for compressed sensing and parallel imaging. - Quant Imaging Med Surg. 2014 Feb;4(1):57-67. doi: 10.3978/j.issn.2223-4292.2014.02.01. - PMID: 24649436; PMCID: PMC3947985. + .. [1] Liu J, Saloner D. Accelerated MRI with CIRcular Cartesian UnderSampling (CIRCUS): a variable density Cartesian sampling strategy for compressed sensing and parallel imaging. Quant Imaging Med Surg. 2014 Feb;4(1):57-67. doi: 10.3978/j.issn.2223-4292.2014.02.01. PMID: 24649436; PMCID: PMC3947985. """ @@ -422,18 +411,18 @@ def get_square_ordered_idxs(square_side_size: int, square_id: int) -> Tuple[Tupl """ Returns ordered (clockwise) indices of a sub-square of a square matrix. - Parameters: - ----------- - square_side_size: int - Square side size. Dim of array. - square_id: int - Number of sub-square. Can be 0, ..., square_side_size // 2. - - Returns: - -------- - ordered_idxs: List of tuples. - Indices of each point that belongs to the square_id-th sub-square - starting from top-left point clockwise. + Parameters + ---------- + square_side_size: int + Square side size. Dim of array. + square_id: int + Number of sub-square. Can be 0, ..., square_side_size // 2. + + Returns + ------- + ordered_idxs: List of tuples. + Indices of each point that belongs to the square_id-th sub-square + starting from top-left point clockwise. """ assert square_id in range(square_side_size // 2) diff --git a/direct/common/subsample_config.py b/direct/common/subsample_config.py index 4585cacc..ed49ad85 100644 --- a/direct/common/subsample_config.py +++ b/direct/common/subsample_config.py @@ -1,7 +1,7 @@ # coding=utf-8 # Copyright (c) DIRECT Contributors from dataclasses import dataclass -from typing import Any, Optional, Tuple, Union +from typing import Optional, Tuple from omegaconf import MISSING diff --git a/direct/config/defaults.py b/direct/config/defaults.py index 45d9a323..a807c5eb 100644 --- a/direct/config/defaults.py +++ b/direct/config/defaults.py @@ -2,7 +2,7 @@ # Copyright (c) DIRECT Contributors from dataclasses import dataclass, field -from typing import Any, List, Optional, Union +from typing import Any, List, Optional from omegaconf import MISSING diff --git a/direct/data/bbox.py b/direct/data/bbox.py index ad259eba..e57f0057 100644 --- a/direct/data/bbox.py +++ b/direct/data/bbox.py @@ -14,12 +14,12 @@ def crop_to_bbox( Parameters ---------- - data : np.ndarray or torch.tensor + data: np.ndarray or torch.tensor nD array or torch tensor. - bbox : list or tuple + bbox: list or tuple bbox of the form (coordinates, size), for instance (4, 4, 2, 1) is a patch starting at row 4, col 4 with height 2 and width 1. - pad_value : number + pad_value: number if bounding box would be out of the image, this is value the patch will be padded with. Returns @@ -75,8 +75,8 @@ def crop_to_largest( Parameters ---------- - data : List[Union[np.ndarray, torch.Tensor]] - pad_value : int + data: List[Union[np.ndarray, torch.Tensor]] + pad_value: int Returns ------- diff --git a/direct/data/datasets.py b/direct/data/datasets.py index 431c514d..24fd1700 100644 --- a/direct/data/datasets.py +++ b/direct/data/datasets.py @@ -414,9 +414,9 @@ class ConcatDataset(Dataset): This class is useful to assemble different existing datasets. 
- Arguments - --------- - datasets : sequence + Parameters + ---------- + datasets: sequence List of datasets to be concatenated From pytorch 1.5.1: torch.utils.data.ConcatDataset @@ -470,20 +470,20 @@ def build_dataset( Parameters ---------- - name : str + name: str Name of dataset class (without `Dataset`) in direct.data.datasets. - root : pathlib.Path + root: pathlib.Path Root path to the data for the dataset class. - filenames_filter : List + filenames_filter: List List of filenames to include in the dataset, should be the same as the ones that can be derived from a glob on the root. If set, will skip searching for files in the root. - sensitivity_maps : pathlib.Path + sensitivity_maps: pathlib.Path Path to sensitivity maps. - transforms : object + transforms: object Transformation object - text_description : str + text_description: str Description of dataset, can be used for logging. - kspace_context : int + kspace_context: int If set, output will be of shape -kspace_context:kspace_context. Returns @@ -492,9 +492,9 @@ def build_dataset( """ # TODO: Maybe only **kwargs are fine. - logger.info(f"Building dataset for: {name}.") + logger.info("Building dataset for: %s", name) dataset_class: Callable = str_to_class("direct.data.datasets", name + "Dataset") - logger.debug(f"Dataset class: {dataset_class}.") + logger.debug("Dataset class: %s", dataset_class) dataset = dataset_class( root=root, filenames_filter=filenames_filter, @@ -505,7 +505,7 @@ def build_dataset( **kwargs, ) - logger.debug(f"Dataset:\n{dataset}") + logger.debug("Dataset: %s", str(dataset)) return dataset @@ -522,17 +522,17 @@ def build_dataset_from_input( """ Parameters ---------- - transforms : object, Callable + transforms: object, Callable Transformation object. dataset_config: Dataset configuration file initial_images: pathlib.Path Path to initial_images. initial_kspaces: pathlib.Path Path to initial kspace images. - filenames_filter : List + filenames_filter: List List of filenames to include in the dataset, should be the same as the ones that can be derived from a glob on the root. If set, will skip searching for files in the root. - data_root : pathlib.Path + data_root: pathlib.Path Root path to the data for the dataset class. pass_dictionaries: diff --git a/direct/data/fake.py b/direct/data/fake.py index 0a5231d2..160bcdd6 100644 --- a/direct/data/fake.py +++ b/direct/data/fake.py @@ -1,11 +1,9 @@ # coding=utf-8 # Copyright (c) DIRECT Contributors import logging -import os import pathlib from typing import Dict, List, Optional, Tuple, Union -import h5py import numpy as np from sklearn.datasets import make_blobs @@ -25,8 +23,8 @@ def __init__( ) -> None: """ - Parameters: - ----------- + Parameters + ---------- ndim: int blobs_n_samples: Optional[int], default is None. blobs_cluster_std: Optional[float], default is None. @@ -47,8 +45,8 @@ def get_kspace( num_coils: int, ) -> np.array: """ - Parameters: - ----------- + Parameters + ---------- spatial_shape: List of ints or tuple of ints. num_coils: int """ @@ -165,8 +163,8 @@ def __call__( """ Returns (and saves if save_as_h5 is True) fake mri samples in the form of gaussian blobs. - Parameters: - ----------- + Parameters + ---------- sample_size: int Size of the samples. 
num_coils: int diff --git a/direct/data/h5_data.py b/direct/data/h5_data.py index 40255d82..7a2de5c1 100644 --- a/direct/data/h5_data.py +++ b/direct/data/h5_data.py @@ -3,7 +3,6 @@ import logging import pathlib import re -import sys from typing import Any, Dict, List, Optional, Tuple, Union import h5py @@ -42,27 +41,27 @@ def __init__( Parameters ---------- - root : pathlib.Path + root: pathlib.Path Root directory to data. - filenames_filter : List + filenames_filter: List List of filenames to include in the dataset, should be the same as the ones that can be derived from a glob on the root. If set, will skip searching for files in the root. - regex_filter : str + regex_filter: str Regular expression filter on the absolute filename. Will be applied after any filenames filter. - metadata : dict + metadata: dict If given, this dictionary will be passed to the output transform. - sensitivity_maps : [pathlib.Path, None] + sensitivity_maps: [pathlib.Path, None] Path to sensitivity maps, or None. - extra_keys : Tuple + extra_keys: Tuple Add extra keys in h5 file to output. - pass_attrs : bool + pass_attrs: bool Pass the attributes saved in the h5 file. - text_description : str + text_description: str Description of dataset, can be useful for logging. - pass_dictionaries : dict + pass_dictionaries: dict Pass a dictionary of dictionaries, e.g. if {"name": {"filename_0": val}}, then to `filename_0`s sample dict, a key with name `name` and value `val` will be added. - pass_h5s : dict + pass_h5s: dict Pass a dictionary of paths. If {"name": path} is given then to the sample of `filename` the same slice of path / filename will be added to the sample dictionary and will be asigned key `name`. This can first instance be convenient when you want to pass sensitivity maps as well. So for instance: diff --git a/direct/data/lr_scheduler.py b/direct/data/lr_scheduler.py index 5bd22f60..964e251a 100644 --- a/direct/data/lr_scheduler.py +++ b/direct/data/lr_scheduler.py @@ -25,7 +25,7 @@ # MultiStepLR with WarmupLR but the current LRScheduler design doesn't allow it. -class LRScheduler(torch.optim.lr_scheduler._LRScheduler): # noqa +class LRScheduler(torch.optim.lr_scheduler._LRScheduler): # pylint: disable=protected-access def __init__(self, optimizer, last_epoch=-1, verbose=False): super().__init__(optimizer, last_epoch, verbose) self.logger = logging.getLogger(type(self).__name__) @@ -40,7 +40,7 @@ def state_dict(self): return state_dict -class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): +class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): # pylint: disable=protected-access def __init__( self, optimizer: torch.optim.Optimizer, @@ -79,7 +79,7 @@ def _compute_values(self) -> List[float]: return self.get_lr() -class WarmupCosineLR(torch.optim.lr_scheduler._LRScheduler): +class WarmupCosineLR(torch.optim.lr_scheduler._LRScheduler): # pylint: disable=protected-access def __init__( self, optimizer: torch.optim.Optimizer, @@ -124,13 +124,13 @@ def _get_warmup_factor_at_iter(method: str, iter: int, warmup_iters: int, warmup Parameters ---------- - method : str + method: str Warmup method; either "constant" or "linear". - iter : int + iter: int Iteration at which to calculate the warmup factor. - warmup_iters : int + warmup_iters: int The length of the warmup phases. - warmup_factor : float + warmup_factor: float The base warmup factor (the meaning changes according to the method used). 
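The warmup behaviour these schedulers document can be summarised in a few lines. A sketch consistent with the parameter descriptions, not a copy of the implementation:

.. code-block:: python

    def warmup_factor_at_iter(method: str, iteration: int, warmup_iters: int, warmup_factor: float) -> float:
        """'constant' keeps the base factor during warmup, 'linear' ramps it towards 1.0,
        and both return 1.0 once the warmup phase is over."""
        if iteration >= warmup_iters:
            return 1.0
        if method == "constant":
            return warmup_factor
        if method == "linear":
            alpha = iteration / warmup_iters
            return warmup_factor * (1 - alpha) + alpha
        raise ValueError(f"Unknown warmup method: {method}")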
Returns diff --git a/direct/data/mri_transforms.py b/direct/data/mri_transforms.py index ea1e537c..c5a98cf3 100644 --- a/direct/data/mri_transforms.py +++ b/direct/data/mri_transforms.py @@ -120,20 +120,20 @@ def __init__( """ Parameters ---------- - crop : tuple or None + crop: tuple or None Size to crop input_image to. - mask_func : direct.common.subsample.MaskFunc + mask_func: direct.common.subsample.MaskFunc A function which creates a mask of the appropriate shape. - use_seed : bool + use_seed: bool If true, a pseudo-random number based on the filename is computed so that every slice of the volume get the same mask every time. - forward_operator : callable + forward_operator: callable The __call__ operator, e.g. some form of FFT (centered or uncentered). - backward_operator : callable + backward_operator: callable The backward operator, e.g. some form of inverse FFT (centered or uncentered). - image_space_center_crop : bool + image_space_center_crop: bool If set, the crop in the data will be taken in the center - random_crop_sampler_type : str + random_crop_sampler_type: str If "uniform" the random cropping will be done by uniformly sampling `crop`, as opposed to `gaussian` which will sample from a gaussian distribution. """ @@ -242,16 +242,16 @@ def __call__(self, sample, coil_dim=0, spatial_dims=(1, 2)): """ Parameters ---------- - sample : dict + sample: dict Contains key kspace_key with value a torch.Tensor of shape (coil, *spatial_dims, complex=2). - coil_dim : int + coil_dim: int Coil dimension. Default: 0. - spatial_dims : (int, int) + spatial_dims: (int, int) Spatial dimensions corresponding to (height, width). Default: (1, 2). Returns ---------- - sample : dict + sample: dict Contains key target_key with value a torch.Tensor of shape (*spatial_dims) or (*spatial_dims) if type_reconstruction is 'rss'. """ @@ -359,14 +359,14 @@ def __call__(self, sample, coil_dim=0): Parameters ---------- - sample : dict + sample: dict Must contain key matching kspace_key with value a (complex) torch.Tensor of shape (coil, height, width, complex=2). - coil_dim : int + coil_dim: int Coil dimension. Default: 0. Returns ---------- - sample : dict + sample: dict """ if self.type_of_map == "unit": @@ -420,7 +420,7 @@ def __init__(self, pad_coils: Optional[int] = None, key: str = "masked_kspace"): """ Parameters ---------- - pad_coils : int + pad_coils: int Number of coils to pad to. key: tuple Key to pad in sample @@ -467,10 +467,10 @@ def __init__(self, normalize_key="masked_kspace", percentile=0.99): Parameters ---------- - normalize_key : str + normalize_key: str Key name to compute the data for. If the maximum has to be computed on the ACS, ensure the reconstruction on the ACS is available (typically `body_coil_image`). - percentile : float or None + percentile: float or None Rescale data with the given percentile. If None, the division is done by the maximum. 
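A sketch of the percentile-based scaling this transform describes, using ``torch.quantile`` as one way to obtain the scaling factor; the transform itself operates on the sample dictionary, and the key and shapes below are examples:

.. code-block:: python

    import torch

    def scaling_factor(data: torch.Tensor, percentile: float = 0.99) -> torch.Tensor:
        """A high percentile of the magnitudes, or the maximum when percentile is None."""
        magnitudes = data.abs().flatten()
        if percentile is None:
            return magnitudes.max()
        return torch.quantile(magnitudes, percentile)

    masked_kspace = torch.randn(4, 64, 64, 2)   # example (coil, height, width, complex) tensor
    normalized = masked_kspace / scaling_factor(masked_kspace, percentile=0.99)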
""" super().__init__() @@ -609,7 +609,7 @@ def __call__(self, sample): if "scaling_factor" in sample: sample["scaling_factor"] = torch.tensor(sample["scaling_factor"]).float() if "loglikelihood_scaling" in sample: - # Shape : (coil, ) + # Shape: (coil, ) sample["loglikelihood_scaling"] = torch.from_numpy(np.asarray(sample["loglikelihood_scaling"])).float() return sample @@ -641,26 +641,27 @@ def build_mri_transforms( Parameters ---------- - backward_operator : callable - forward_operator : callable - mask_func : callable or none - crop : int or none - crop_type : str or None + backward_operator: callable + forward_operator: callable + mask_func: callable or none + crop: int or none + crop_type: str or None Type of cropping, either "gaussian" or "uniform". - image_center_crop : bool - estimate_sensitivity_maps : bool - estimate_body_coil_image : bool - sensitivity_maps_gaussian : float + image_center_crop: bool + estimate_sensitivity_maps: bool + estimate_body_coil_image: bool + sensitivity_maps_gaussian: float Optional sigma for gaussian weighting of sensitivity map. - pad_coils : int + pad_coils: int Number of coils to pad data to. - scaling_key : str + scaling_key: str Key to use to compute scaling factor for. - use_seed : bool + use_seed: bool Returns ------- - object : a transformation object. + object: Callable + A transformation object. """ # TODO: Use seed diff --git a/direct/data/samplers.py b/direct/data/samplers.py index a80a377a..b00e3ceb 100644 --- a/direct/data/samplers.py +++ b/direct/data/samplers.py @@ -38,11 +38,11 @@ def __init__( """ Parameters ---------- - size : int + size: int Size of underlying dataset. - shuffle : bool + shuffle: bool If true, the indices will be shuffled. - seed : int + seed: int Initial seed of the shuffle, must be the same across all workers! """ self._size = size diff --git a/direct/data/transforms.py b/direct/data/transforms.py index 2623b753..9efd7593 100644 --- a/direct/data/transforms.py +++ b/direct/data/transforms.py @@ -22,11 +22,12 @@ def to_tensor(data: np.ndarray) -> torch.Tensor: Parameters ---------- - data : np.ndarray + data: np.ndarray Returns ------- torch.Tensor + """ if np.iscomplexobj(data): data = np.stack((data.real, data.imag), axis=-1) @@ -38,17 +39,18 @@ def to_tensor(data: np.ndarray) -> torch.Tensor: def verify_fft_dtype_possible(data: torch.Tensor, dims: Tuple[int, ...]) -> bool: """ - Fft and ifft can only be performed on GPU in float16 if the shapes are powers of 2. + fft and ifft can only be performed on GPU in float16 if the shapes are powers of 2. This function verifies if this is the case. Parameters ---------- - data : torch.Tensor - dims : tuple + data: torch.Tensor + dims: tuple Returns ------- bool + """ is_complex64 = data.dtype == torch.complex64 is_complex32_and_power_of_two = (data.dtype == torch.float32) and all( @@ -67,8 +69,14 @@ def view_as_complex(data): Parameters ---------- - data : torch.Tensor - with torch.dtype torch.float64 and torch.float32 + data: torch.Tensor + Input data with torch.dtype torch.float64 and torch.float32 with complex axis (last) of dimension 2 + and of shape (N, \*, 2). + + Returns + ------- + complex_valued_data: torch.Tensor + Output complex-valued data of shape (N, \*) with complex torch.dtype. """ return torch.view_as_complex(data) @@ -83,8 +91,14 @@ def view_as_real(data): Parameters ---------- - data : torch.Tensor - with complex torch.dtype + data: torch.Tensor + Input data with complex torch.dtype of shape (N, \*). 
+ + Returns + ------- + real_valued_data: torch.Tensor + Output real-valued data of shape (N, \*, 2). + """ return torch.view_as_real(data) @@ -104,19 +118,22 @@ def fft2( Parameters ---------- - data : torch.Tensor - Complex-valued input tensor. Should be of shape (*, 2) and dim is in *. - dim : tuple, list or int + data: torch.Tensor + Complex-valued input tensor. Should be of shape (\*, 2) and dim is in \*. + dim: tuple, list or int Dimensions over which to compute. Should be positive. Negative indexing not supported Default is (1, 2), corresponding to ('height', 'width'). - centered : bool + centered: bool Whether to apply a centered fft (center of kspace is in the center versus in the corners). For FastMRI dataset this has to be true and for the Calgary-Campinas dataset false. - normalized : bool - Whether to normalize the ifft. For the FastMRI this has to be true and for the Calgary-Campinas dataset false. + normalized: bool + Whether to normalize the fft. For the FastMRI this has to be true and for the Calgary-Campinas dataset false. + Returns ------- - torch.Tensor: the fft of the data. + output_data: torch.Tensor + The Fast Fourier transform of the data. + """ if not all((_ >= 0 and isinstance(_, int)) for _ in dim): raise TypeError( @@ -160,19 +177,22 @@ def ifft2( Parameters ---------- - data : torch.Tensor - Complex-valued input tensor. Should be of shape (*, 2) and dim is in *. - dim : tuple, list or int + data: torch.Tensor + Complex-valued input tensor. Should be of shape (\*, 2) and dim is in \*. + dim: tuple, list or int Dimensions over which to compute. Should be positive. Negative indexing not supported - Default is (1, 2), corresponding to ('height', 'width'). - centered : bool + Default is (1, 2), corresponding to ( 'height', 'width'). + centered: bool Whether to apply a centered ifft (center of kspace is in the center versus in the corners). For FastMRI dataset this has to be true and for the Calgary-Campinas dataset false. - normalized : bool + normalized: bool Whether to normalize the ifft. For the FastMRI this has to be true and for the Calgary-Campinas dataset false. + Returns ------- - torch.Tensor: the ifft of the data. + output_data: torch.Tensor + The Inverse Fast Fourier transform of the data. + """ if not all((_ >= 0 and isinstance(_, int)) for _ in dim): raise TypeError( @@ -207,8 +227,8 @@ def safe_divide(input_tensor: torch.Tensor, other_tensor: torch.Tensor) -> torch Parameters ---------- - input_tensor : torch.Tensor - other_tensor : torch.Tensor + input_tensor: torch.Tensor + other_tensor: torch.Tensor Returns ------- @@ -230,11 +250,13 @@ def modulus(data: torch.Tensor) -> torch.Tensor: Parameters ---------- - data : torch.Tensor + data: torch.Tensor Returns ------- - torch.Tensor: modulus of data. + output_data: torch.Tensor + Modulus of data. + """ # TODO: fix to specify dim of complex axis or make it work with complex_last=True. @@ -246,15 +268,16 @@ def modulus(data: torch.Tensor) -> torch.Tensor: def modulus_if_complex(data: torch.Tensor) -> torch.Tensor: """ - Compute modulus if complex-valued. + Compute modulus if complex tensor (has complex axis). Parameters ---------- - data : torch.Tensor + data: torch.Tensor Returns ------- torch.Tensor + """ if is_complex_data(data, complex_last=False): return modulus(data) @@ -270,13 +293,14 @@ def roll( Similar to numpy roll but applies to pytorch tensors. 
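Taken together, ``to_tensor``, ``fft2`` and ``ifft2`` give the usual image/k-space round trip on complex-last tensors. An example with made-up shapes and FastMRI-style settings (``centered=True``, ``normalized=True``):

.. code-block:: python

    import numpy as np
    from direct.data.transforms import fft2, ifft2, to_tensor

    image = (np.random.randn(4, 64, 64) + 1j * np.random.randn(4, 64, 64)).astype(np.complex64)
    image_t = to_tensor(image)                                               # (coil, height, width, 2)
    kspace = fft2(image_t, dim=(1, 2), centered=True, normalized=True)       # forward operator
    image_back = ifft2(kspace, dim=(1, 2), centered=True, normalized=True)   # approximately image_t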
Parameters ---------- - data : torch.Tensor + data: torch.Tensor shift: tuple, int - dims : tuple, list or int + dims: tuple, list or int Returns ------- torch.Tensor + """ if isinstance(shift, (tuple, list)) and isinstance(dims, (tuple, list)): if len(shift) != len(dims): @@ -300,8 +324,8 @@ def fftshift(data: torch.Tensor, dim: Tuple[int, ...] = None) -> torch.Tensor: Parameters ---------- - data : torch.Tensor - dim : tuple, list or int + data: torch.Tensor + dim: tuple, list or int Returns ------- @@ -324,12 +348,13 @@ def ifftshift(data: torch.Tensor, dim: Tuple[Union[str, int], ...] = None) -> to Parameters ---------- - data : torch.Tensor - dim : tuple, list or int + data: torch.Tensor + dim: tuple, list or int Returns ------- torch.Tensor + """ if dim is None: dim = tuple(range(data.dim())) @@ -347,19 +372,18 @@ def complex_multiplication(input_tensor: torch.Tensor, other_tensor: torch.Tenso Parameters ---------- - input_tensor : torch.Tensor + input_tensor: torch.Tensor Input data - other_tensor : torch.Tensor + other_tensor: torch.Tensor Input data Returns ------- torch.Tensor + """ assert_complex(input_tensor, complex_last=True) assert_complex(other_tensor, complex_last=True) - # multiplication = torch.view_as_complex(x) * torch.view_as_complex(y) - # return torch.view_as_real(multiplication) complex_index = -1 @@ -383,14 +407,15 @@ def _complex_matrix_multiplication(input_tensor, other_tensor, mult_func): Parameters ---------- - input_tensor : torch.Tensor - other_tensor : torch.Tensor - mult_func : Callable + input_tensor: torch.Tensor + other_tensor: torch.Tensor + mult_func: Callable Multiplication function e.g. torch.bmm or torch.mm Returns ------- torch.Tensor + """ if not input_tensor.is_complex() or not other_tensor.is_complex(): raise ValueError("Both input_tensor and other_tensor have to be complex-valued torch tensors.") @@ -411,12 +436,16 @@ def complex_mm(input_tensor, other_tensor): Parameters ---------- - input_tensor : torch.Tensor - other_tensor : torch.Tensor + input_tensor: torch.Tensor + Input 2D tensor. + other_tensor: torch.Tensor + Other 2D tensor. Returns ------- - torch.Tensor + out: torch.Tensor + Complex-multiplied 2D output tensor. + """ return _complex_matrix_multiplication(input_tensor, other_tensor, torch.mm) @@ -427,11 +456,16 @@ def complex_bmm(input_tensor, other_tensor): Parameters ---------- - input_tensor : torch.Tensor - other_tensor : torch.Tensor + input_tensor: torch.Tensor + Input tensor. + other_tensor: torch.Tensor + Other tensor. + Returns ------- - torch.Tensor + out: torch.Tensor + Batch complex-multiplied output tensor. + """ return _complex_matrix_multiplication(input_tensor, other_tensor, torch.bmm) @@ -443,11 +477,12 @@ def conjugate(data: torch.Tensor) -> torch.Tensor: Parameters ---------- - data : torch.Tensor + data: torch.Tensor Returns ------- - torch.Tensor + conjugate_tensor: torch.Tensor + """ assert_complex(data, complex_last=True) data = data.clone() # Clone is required as the data in the next line is changed in-place. @@ -467,19 +502,20 @@ def apply_mask( Parameters ---------- - kspace : torch.Tensor + kspace: torch.Tensor k-space as a complex-valued tensor. - mask_func : callable or torch.tensor + mask_func: callable or torch.tensor Masking function, taking a shape and returning a mask with this shape or can be broadcast as such Can also be a sampling mask. 
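The arithmetic behind ``complex_multiplication`` (and the conjugation used further down) on tensors with a trailing real/imaginary axis is the textbook identity; a standalone sketch:

.. code-block:: python

    import torch

    def complex_mul_last_dim(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """(a + bi)(c + di) = (ac - bd) + (ad + bc)i, with the complex axis last (size 2)."""
        a, b = x[..., 0], x[..., 1]
        c, d = y[..., 0], y[..., 1]
        return torch.stack((a * c - b * d, a * d + b * c), dim=-1)

    x, y = torch.randn(3, 5, 2), torch.randn(3, 5, 2)
    assert torch.allclose(
        torch.view_as_complex(complex_mul_last_dim(x, y)),
        torch.view_as_complex(x) * torch.view_as_complex(y),
    )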
- seed : int + seed: int Seed for the random number generator - return_mask : bool + return_mask: bool If true, mask will be returned Returns ------- - masked data (torch.Tensor), mask (torch.Tensor) + masked data, mask: (torch.Tensor, torch.Tensor) + """ # TODO: Split the function to apply_mask_func and apply_mask @@ -506,12 +542,14 @@ def tensor_to_complex_numpy(data: torch.Tensor) -> np.ndarray: Parameters ---------- - data : torch.Tensor + data: torch.Tensor Input data Returns ------- - Complex valued np.ndarray + out: np.array + Complex valued np.ndarray + """ assert_complex(data) data = data.detach().cpu().numpy() @@ -519,25 +557,25 @@ def tensor_to_complex_numpy(data: torch.Tensor) -> np.ndarray: def root_sum_of_squares(data: torch.Tensor, dim: int = 0, complex_dim: int = -1) -> torch.Tensor: - """ + r""" Compute the root sum of squares (RSS) transform along a given dimension of the input tensor. .. math:: - x_{\textrm{rss}} = \sqrt{\sum_{i \in \textrm{coil}} |x_i|^2} + x_{\textrm{RSS}} = \sqrt{\sum_{i \in \textrm{coil}} |x_i|^2} Parameters ---------- - data : torch.Tensor + data: torch.Tensor Input tensor - - dim : int + dim: int Coil dimension. Default is 0 as the first dimension is always the coil dimension. - - complex_dim : int + complex_dim: int Complex channel dimension. Default is -1. If data not complex this is ignored. + Returns ------- - torch.Tensor : RSS of the input tensor. + torch.Tensor: RSS of the input tensor. + """ if is_complex_data(data): return torch.sqrt((data ** 2).sum(complex_dim).sum(dim)) @@ -551,13 +589,13 @@ def center_crop(data: torch.Tensor, shape: Tuple[int, int]) -> torch.Tensor: Parameters ---------- - data : torch.Tensor - shape : Tuple[int, int] + data: torch.Tensor + shape: Tuple[int, int] The output shape, should be smaller than the corresponding data dimensions. Returns ------- - torch.Tensor : The center cropped data. + torch.Tensor: The center cropped data. """ # TODO: Make dimension independent. if not (0 < shape[0] <= data.shape[-2]) or not (0 < shape[1] <= data.shape[-1]): @@ -577,15 +615,15 @@ def complex_center_crop(data_list, shape, offset=1, contiguous=False): Parameters ---------- - data_list : List[torch.Tensor] or torch.Tensor + data_list: List[torch.Tensor] or torch.Tensor The complex input tensor to be center cropped. It should have at least 3 dimensions and the cropping is applied along dimensions didx and didx+1 and the last dimensions should have a size of 2. - shape : Tuple[int, int] + shape: Tuple[int, int] The output shape. The shape should be smaller than the corresponding dimensions of data. If one value is None, this is filled in by the image shape. - offset : int + offset: int Starting dimension for cropping. - contiguous : bool + contiguous: bool Return as a contiguous array. Useful for fast reshaping or viewing. Returns @@ -635,18 +673,18 @@ def complex_random_crop( Parameters ---------- - data_list : Union[List[torch.Tensor], torch.Tensor] + data_list: Union[List[torch.Tensor], torch.Tensor] The complex input tensor to be center cropped. It should have at least 3 dimensions and the cropping is applied along dimensions -3 and -2 and the last dimensions should have a size of 2. - crop_shape : Tuple[int, ...] + crop_shape: Tuple[int, ...] The output shape. The shape should be smaller than the corresponding dimensions of data. - offset : int + offset: int Starting dimension for cropping. - contiguous : bool + contiguous: bool Return as a contiguous array. Useful for fast reshaping or viewing. 
- sampler : str + sampler: str Select the random indices from either a `uniform` or `gaussian` distribution (around the center) - sigma : float or list of float or None + sigma: float or list of float or None Standard variance of the gaussian when sampler is `gaussian`. If not set will take 1/3th of image shape Returns @@ -717,17 +755,18 @@ def reduce_operator( sensitivity_map: torch.Tensor, dim: int = 0, ) -> torch.Tensor: - """ - Given zero-filled reconstructions from multiple coils :math: \{x_i\}_{i=1}^{N_c} and coil sensitivity maps - :math: \{S_i\}_{i=1}^{N_c} it returns - .. math:: - R(x_1, .., x_{N_c}, S_1, .., S_{N_c}) = \sum_{i=1}^{N_c} {S_i}^{*} \times x_i. + r""" + Given zero-filled reconstructions from multiple coils :math:`\{x_i\}_{i=1}^{N_c}` and + coil sensitivity maps :math:`\{S_i\}_{i=1}^{N_c}` it returns: - From paper End-to-End Variational Networks for Accelerated MRI Reconstruction. + .. math:: + R(x_{1}, .., x_{N_c}, S_1, .., S_{N_c}) = \sum_{i=1}^{N_c} {S_i}^{*} \times x_i. + + Adapted from [1]_. Parameters ---------- - coil_data : torch.Tensor + coil_data: torch.Tensor Zero-filled reconstructions from coils. Should be a complex tensor (with complex dim of size 2). sensitivity_map: torch.Tensor Coil sensitivity maps. Should be complex tensor (with complex dim of size 2). @@ -738,6 +777,12 @@ def reduce_operator( ------- torch.Tensor: Combined individual coil images. + + References + ---------- + + .. [1] Sriram, Anuroop, et al. “End-to-End Variational Networks for Accelerated MRI Reconstruction.” ArXiv:2004.06688 [Cs, Eess], Apr. 2020. arXiv.org, http://arxiv.org/abs/2004.06688. + """ assert_complex(coil_data, complex_last=True) @@ -751,16 +796,17 @@ def expand_operator( sensitivity_map: torch.Tensor, dim: int = 0, ) -> torch.Tensor: - """ - Given a reconstructed image x and coil sensitivity maps :math: \{S_i\}_{i=1}^{N_c}, it returns + r""" + Given a reconstructed image :math:`x` and coil sensitivity maps :math:`\{S_i\}_{i=1}^{N_c}`, it returns + .. math:: - \Epsilon(x) = (S_1 \times x, .., S_{N_c} \times x) = (x_1, .., x_{N_c}). + E(x) = (S_1 \times x, .., S_{N_c} \times x) = (x_1, .., x_{N_c}). - From paper End-to-End Variational Networks for Accelerated MRI Reconstruction. + Adapted from [1]_. Parameters ---------- - data : torch.Tensor + data: torch.Tensor Image data. Should be a complex tensor (with complex dim of size 2). sensitivity_map: torch.Tensor Coil sensitivity maps. Should be complex tensor (with complex dim of size 2). @@ -771,6 +817,12 @@ def expand_operator( ------- torch.Tensor: Zero-filled reconstructions from each coil. + + References + ---------- + + .. [1] Sriram, Anuroop, et al. “End-to-End Variational Networks for Accelerated MRI Reconstruction.” ArXiv:2004.06688 [Cs, Eess], Apr. 2020. arXiv.org, http://arxiv.org/abs/2004.06688. + """ assert_complex(data, complex_last=True) diff --git a/direct/engine.py b/direct/engine.py index aaf3c206..b3783aab 100644 --- a/direct/engine.py +++ b/direct/engine.py @@ -658,7 +658,7 @@ def view_as_complex(data): Parameters ---------- - data : torch.Tensor + data: torch.Tensor Tensor with non-complex torch.dtype and final axis is complex (shape 2). 
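The newly referenced expand/reduce operators are counterparts: :math:`E` projects one image onto the coils with the sensitivity maps, and :math:`R` combines coil images back using the conjugate maps. An illustrative sketch with the signatures shown above (shapes are only indicative; inside the engines these are called on batched data)::

    import torch

    from direct.data.transforms import expand_operator, reduce_operator

    image = torch.randn(320, 320, 2)                 # x, (height, width, complex=2)
    sensitivity_map = torch.randn(8, 320, 320, 2)    # S_i, (coil, height, width, complex=2)

    coil_images = expand_operator(image, sensitivity_map, dim=0)     # E(x): per-coil images
    combined = reduce_operator(coil_images, sensitivity_map, dim=0)  # R(...): coil-combined image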
Returns @@ -681,7 +681,7 @@ def view_as_real(data): Parameters ---------- - data : torch.Tensor + data: torch.Tensor Tensor with complex torch.dtype Returns diff --git a/direct/environment.py b/direct/environment.py index 60c81966..f656d4eb 100644 --- a/direct/environment.py +++ b/direct/environment.py @@ -36,7 +36,7 @@ def load_model_config_from_name(model_name): Parameters ---------- - model_name : path to model relative to direct.nn + model_name: path to model relative to direct.nn Returns ------- @@ -86,18 +86,18 @@ def setup_logging(machine_rank, output_directory, run_name, cfg_filename, cfg, d filename=log_file, log_level=("INFO" if not debug else "DEBUG"), ) - logger.info(f"Machine rank: {machine_rank}.") - logger.info(f"Local rank: {communication.get_local_rank()}.") - logger.info(f"Logging: {log_file}.") - logger.info(f"Saving to: {output_directory}.") - logger.info(f"Run name: {run_name}.") - logger.info(f"Config file: {cfg_filename}.") - logger.info(f"CUDA {torch.version.cuda} - cuDNN {torch.backends.cudnn.version()}.") - logger.info(f"Environment information: {collect_env.get_pretty_env_info()}.") - logger.info(f"DIRECT version: {direct.__version__}.") + logger.info("Machine rank: %s", machine_rank) + logger.info("Local rank: %s", communication.get_local_rank()) + logger.info("Logging: %s", log_file) + logger.info("Saving to: %s", output_directory) + logger.info("Run name: %s", run_name) + logger.info("Config file: %s", cfg_filename) + logger.info("CUDA %s - cuDNN %s", torch.version.cuda, torch.backends.cudnn.version()) + logger.info("Environment information: %s", collect_env.get_pretty_env_info()) + logger.info("DIRECT version: %s", direct.__version__) git_hash = direct.utils.git_hash() - logger.info(f"Git hash: {git_hash if git_hash else 'N/A'}.") - logger.info(f"Configuration: {OmegaConf.to_yaml(cfg)}.") + logger.info("Git hash: %s", git_hash if git_hash else "N/A") + logger.info("Configuration: %s", OmegaConf.to_yaml(cfg)) def load_models_into_environment_config(cfg_from_file): @@ -310,9 +310,9 @@ def setup_training_environment( ) # Write config file to experiment directory. 
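The logging changes in ``setup_logging`` swap f-strings for %-style arguments, so interpolation is deferred to the logging framework and pylint's lazy-formatting check passes. The pattern in isolation::

    import logging

    logger = logging.getLogger(__name__)
    run_name = "example_run"

    logger.info(f"Run name: {run_name}.")   # interpolated eagerly, even if INFO is disabled
    logger.info("Run name: %s", run_name)   # deferred, as used in setup_logging above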
config_file_in_project_folder = env.experiment_dir / "config.yaml" - logger.info(f"Writing configuration file to: {config_file_in_project_folder}.") + logger.info("Writing configuration file to: %s", config_file_in_project_folder) if communication.is_main_process(): - with open(config_file_in_project_folder, "w") as f: + with open(config_file_in_project_folder, "w", encoding="utf-8") as f: f.write(OmegaConf.to_yaml(env.cfg)) communication.synchronize() diff --git a/direct/exceptions.py b/direct/exceptions.py index 7132cda6..b55dc871 100644 --- a/direct/exceptions.py +++ b/direct/exceptions.py @@ -16,7 +16,7 @@ def __init__(self, signal_id: int, signal_name: str): """ Parameters ---------- - signal_id : str + signal_id: str signal_name: str """ super().__init__() diff --git a/direct/functionals/psnr.py b/direct/functionals/psnr.py index fcd76ecd..db3f07ea 100644 --- a/direct/functionals/psnr.py +++ b/direct/functionals/psnr.py @@ -12,9 +12,9 @@ def batch_psnr(input, target, reduction="mean"): Parameters ---------- - input : torch.Tensor - target : torch.Tensor - reduction : str + input: torch.Tensor + target: torch.Tensor + reduction: str Returns ------- diff --git a/direct/functionals/ssim.py b/direct/functionals/ssim.py index 32985070..21d6c8db 100644 --- a/direct/functionals/ssim.py +++ b/direct/functionals/ssim.py @@ -23,10 +23,14 @@ class SSIMLoss(nn.Module): def __init__(self, win_size=7, k1=0.01, k2=0.03): """ - Args: - win_size (int, default=7): Window size for SSIM calculation. - k1 (float, default=0.1): k1 parameter for SSIM calculation. - k2 (float, default=0.03): k2 parameter for SSIM calculation. + Parameters + ---------- + win_size: int + Window size for SSIM calculation. Default: 7. + k1: float + k1 parameter for SSIM calculation. Default: 0.1. + k2: float + k2 parameter for SSIM calculation. Default: 0.03. """ super().__init__() self.win_size = win_size diff --git a/direct/inference.py b/direct/inference.py index 6b257dd4..1229bff7 100644 --- a/direct/inference.py +++ b/direct/inference.py @@ -4,7 +4,7 @@ import logging import sys from functools import partial -from typing import Callable, Optional +from typing import Optional import torch @@ -39,20 +39,20 @@ def setup_inference_save_to_h5( Parameters ---------- - get_inference_settings : Callable - run_name : - data_root : - base_directory : - output_directory : - filenames_filter : - checkpoint : - device : - num_workers : - machine_rank : - cfg_file : - process_per_chunk : - mixed_precision : - debug : + get_inference_settings: Callable + run_name: + data_root: + base_directory: + output_directory: + filenames_filter: + checkpoint: + device: + num_workers: + machine_rank: + cfg_file: + process_per_chunk: + mixed_precision: + debug: Returns ------- diff --git a/direct/launch.py b/direct/launch.py index eb28dc8a..c7e384d7 100644 --- a/direct/launch.py +++ b/direct/launch.py @@ -53,20 +53,20 @@ def launch_distributed( Parameters ---------- - main_func : Callable + main_func: Callable A function that will be called by `main_func(*args)`. - num_gpus_per_machine : int + num_gpus_per_machine: int The number of GPUs per machine. num_machines : The number of machines. - machine_rank : int + machine_rank: int The rank of this machine (one per machine). - dist_url : str + dist_url: str url to connect to for distributed training, including protocol e.g. "tcp://127.0.0.1:8686". Can be set to auto to automatically select a free port on localhost - timeout : timedelta + timeout: timedelta Timeout of the distributed workers. 
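For the functionals touched above, the call signatures are visible in the hunks; a small construction sketch (shapes are illustrative, and the ``SSIMLoss`` forward signature is not part of this hunk, so only instantiation is shown)::

    import torch

    from direct.functionals.psnr import batch_psnr
    from direct.functionals.ssim import SSIMLoss

    prediction = torch.rand(4, 1, 320, 320)
    target = torch.rand(4, 1, 320, 320)

    psnr = batch_psnr(prediction, target, reduction="mean")
    ssim_criterion = SSIMLoss(win_size=7, k1=0.01, k2=0.03)  # defaults from __init__ above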
- args : tuple + args: tuple arguments passed to main_func. """ @@ -161,13 +161,13 @@ def launch( Parameters ---------- - func : callable + func: callable function to launch - num_machines : int - num_gpus : int - machine_rank : int - dist_url : str - args : arguments to pass to func + num_machines: int + num_gpus: int + machine_rank: int + dist_url: str + args: arguments to pass to func Returns ------- diff --git a/direct/nn/conv/conv.py b/direct/nn/conv/conv.py index 2da353e7..6a1e98ee 100644 --- a/direct/nn/conv/conv.py +++ b/direct/nn/conv/conv.py @@ -15,17 +15,17 @@ def __init__(self, in_channels, out_channels, hidden_channels, n_convs=3, activa Parameters ---------- - in_channels : int + in_channels: int Number of input channels. - out_channels : int + out_channels: int Number of output channels. - hidden_channels : int + hidden_channels: int Number of hidden channels. - n_convs : int + n_convs: int Number of convolutional layers. - activation : nn.Module + activation: nn.Module Activation function. - batchnorm : bool + batchnorm: bool If True a batch normalization layer is applied after every convolution. """ super().__init__() diff --git a/direct/nn/crossdomain/crossdomain.py b/direct/nn/crossdomain/crossdomain.py index 18129e5a..38ae4412 100644 --- a/direct/nn/crossdomain/crossdomain.py +++ b/direct/nn/crossdomain/crossdomain.py @@ -30,23 +30,23 @@ def __init__( Parameters ---------- - forward_operator : Callable + forward_operator: Callable Forward Operator. - backward_operator : Callable + backward_operator: Callable Backward Operator. - image_model_list : nn.Module + image_model_list: nn.Module Image domain model list. - kspace_model_list : Optional[nn.Module] + kspace_model_list: Optional[nn.Module] K-space domain model list. If set to None, a correction step is applied. Default: None. - domain_sequence : str + domain_sequence: str Domain sequence containing only "K" (k-space domain) and/or "I" (image domain). Default: "KIKI". - image_buffer_size : int + image_buffer_size: int Image buffer size. Default: 1. - kspace_buffer_size : int + kspace_buffer_size: int K-space buffer size. Default: 1. - normalize_image : bool + normalize_image: bool If True, input is normalized. Default: False. - kwargs : dict + kwargs: dict Keyword Arguments. """ super().__init__() @@ -148,18 +148,18 @@ def forward( Parameters ---------- - masked_kspace : torch.Tensor + masked_kspace: torch.Tensor Masked k-space of shape (N, coil, height, width, complex=2). - sampling_mask : torch.Tensor + sampling_mask: torch.Tensor Sampling mask of shape (N, 1, height, width, 1). - sensitivity_map : torch.Tensor + sensitivity_map: torch.Tensor Sensitivity map of shape (N, coil, height, width, complex=2). - scaling_factor : Optional[torch.Tensor] + scaling_factor: Optional[torch.Tensor] Scaling factor of shape (N,). If None, no scaling is applied. Default: None. Returns ------- - out_image : torch.Tensor + out_image: torch.Tensor Output image of shape (N, height, width, complex=2). """ input_image = self._backward_operator(masked_kspace, sampling_mask, sensitivity_map) diff --git a/direct/nn/crossdomain/multicoil.py b/direct/nn/crossdomain/multicoil.py index f8e0fb37..ed543d8c 100644 --- a/direct/nn/crossdomain/multicoil.py +++ b/direct/nn/crossdomain/multicoil.py @@ -17,11 +17,11 @@ def __init__(self, model: nn.Module, coil_dim: int = 1, coil_to_batch: bool = Fa Parameters ---------- - model : nn.Module + model: nn.Module Any nn.Module that takes as input with 4D data (N, H, W, C). 
Typically a convolutional-like model. - coil_dim : int + coil_dim: int Coil dimension. Default: 1. - coil_to_batch : bool + coil_to_batch: bool If True batch and coil dimensions are merged when forwarded by the model and unmerged when outputted. Otherwise, input is forwarded to the model per coil. """ @@ -45,12 +45,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Parameters ---------- - x : torch.Tensor + x: torch.Tensor Multi-coil input of shape (N, coil, height, width, in_channels). Returns ------- - out : torch.Tensor + out: torch.Tensor Multi-coil output of shape (N, coil, height, width, out_channels). """ if self.coil_to_batch: diff --git a/direct/nn/didn/didn.py b/direct/nn/didn/didn.py index 0642fff3..af35cbf0 100644 --- a/direct/nn/didn/didn.py +++ b/direct/nn/didn/didn.py @@ -13,9 +13,7 @@ class Subpixel(nn.Module): References ---------- - .. [1] Yu, Songhyun, et al. “Deep Iterative Down-Up CNN for Image Denoising.” - 2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition Workshops - (CVPRW), 2019, pp. 2095–103. IEEE Xplore, https://doi.org/10.1109/CVPRW.2019.00262. + .. [1] Yu, Songhyun, et al. “Deep Iterative Down-Up CNN for Image Denoising.” 2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition Workshops (CVPRW), 2019, pp. 2095–103. IEEE Xplore, https://doi.org/10.1109/CVPRW.2019.00262. """ @@ -37,9 +35,7 @@ class ReconBlock(nn.Module): References ---------- - .. [1] Yu, Songhyun, et al. “Deep Iterative Down-Up CNN for Image Denoising.” - 2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition Workshops - (CVPRW), 2019, pp. 2095–103. IEEE Xplore, https://doi.org/10.1109/CVPRW.2019.00262. + .. [1] Yu, Songhyun, et al. “Deep Iterative Down-Up CNN for Image Denoising.” 2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition Workshops (CVPRW), 2019, pp. 2095–103. IEEE Xplore, https://doi.org/10.1109/CVPRW.2019.00262. """ @@ -75,9 +71,7 @@ class DUB(nn.Module): References ---------- - .. [1] Yu, Songhyun, et al. “Deep Iterative Down-Up CNN for Image Denoising.” - 2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition Workshops - (CVPRW), 2019, pp. 2095–103. IEEE Xplore, https://doi.org/10.1109/CVPRW.2019.00262. + .. [1] Yu, Songhyun, et al. “Deep Iterative Down-Up CNN for Image Denoising.” 2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition Workshops (CVPRW), 2019, pp. 2095–103. IEEE Xplore, https://doi.org/10.1109/CVPRW.2019.00262. """ @@ -179,9 +173,7 @@ class DIDN(nn.Module): References ---------- - .. [1] Yu, Songhyun, et al. “Deep Iterative Down-Up CNN for Image Denoising.” - 2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition Workshops - (CVPRW), 2019, pp. 2095–103. IEEE Xplore, https://doi.org/10.1109/CVPRW.2019.00262. + .. [1] Yu, Songhyun, et al. “Deep Iterative Down-Up CNN for Image Denoising.” 2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition Workshops (CVPRW), 2019, pp. 2095–103. IEEE Xplore, https://doi.org/10.1109/CVPRW.2019.00262. """ @@ -198,17 +190,17 @@ def __init__( Parameters ---------- - in_channels : int + in_channels: int Number of input channels. - out_channels : int + out_channels: int Number of output channels. - hidden_channels : int + hidden_channels: int Number of hidden channels. First convolution out_channels. Default: 128. - num_dubs : int + num_dubs: int Number of DUB networks. Default: 6. - num_convs_recon : int + num_convs_recon: int Number of ReconBlock convolutions. Default: 9. 
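``MultiCoil`` reuses a 2D model over the coil axis, either looping per coil or folding coil into batch. A rough sketch with a no-op inner model so the tensor layout is easy to follow (whether the wrapper permutes channels for the inner model is not visible in this hunk)::

    import torch
    from torch import nn

    from direct.nn.crossdomain.multicoil import MultiCoil

    per_coil_model = nn.Identity()                    # stand-in for any 2D model
    wrapped = MultiCoil(per_coil_model, coil_dim=1, coil_to_batch=True)

    x = torch.randn(2, 8, 320, 320, 2)                # (N, coil, height, width, channels)
    out = wrapped(x)                                  # processed coil-by-coil (or batched)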
- skip_connection : bool + skip_connection: bool Use skip connection. Default: False. """ super().__init__() @@ -263,14 +255,14 @@ def forward(self, x, channel_dim=1): Parameters ---------- - x : torch.Tensor + x: torch.Tensor Input tensor. - channel_dim : int + channel_dim: int Channel dimension. Default: 1. Returns ------- - out : torch.Tensor + out: torch.Tensor Output tensor. """ out = self.conv_in(x) diff --git a/direct/nn/jointicnet/jointicnet.py b/direct/nn/jointicnet/jointicnet.py index 143d0a3d..8c584cbf 100644 --- a/direct/nn/jointicnet/jointicnet.py +++ b/direct/nn/jointicnet/jointicnet.py @@ -12,8 +12,11 @@ class JointICNet(nn.Module): """ - Joint-ICNet implementation as in "Joint Deep Model-based MR Image and Coil Sensitivity Reconstruction Network - (Joint-ICNet) for Fast MRI" submitted to the fastmri challenge. + Joint Deep Model-Based MR Image and Coil Sensitivity Reconstruction Network (Joint-ICNet) implementation as presented in [1]_. + + References + ---------- + .. [1] Jun, Yohan, et al. “Joint Deep Model-Based MR Image and Coil Sensitivity Reconstruction Network (Joint-ICNet) for Fast MRI.” 2021 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), IEEE, 2021, pp. 5266–75. DOI.org (Crossref), https://doi.org/10.1109/CVPR46437.2021.00523. """ @@ -29,13 +32,13 @@ def __init__( Parameters ---------- - forward_operator : Callable + forward_operator: Callable Forward Transform. - backward_operator : Callable + backward_operator: Callable Backward Transform. - num_iter : int + num_iter: int Number of unrolled iterations. Default: 10. - use_norm_unet : bool + use_norm_unet: bool If True, a Normalized U-Net is used. Default: False. kwargs: dict Image, k-space and sensitivity-map U-Net models keyword-arguments. @@ -138,16 +141,16 @@ def forward( Parameters ---------- - masked_kspace : torch.Tensor + masked_kspace: torch.Tensor Masked k-space of shape (N, coil, height, width, complex=2). - sampling_mask : torch.Tensor + sampling_mask: torch.Tensor Sampling mask of shape (N, 1, height, width, 1). - sensitivity_map : torch.Tensor + sensitivity_map: torch.Tensor Sensitivity map of shape (N, coil, height, width, complex=2). Returns ------- - out_image : torch.Tensor + out_image: torch.Tensor Output image of shape (N, height, width, complex=2). """ diff --git a/direct/nn/jointicnet/jointicnet_engine.py b/direct/nn/jointicnet/jointicnet_engine.py index baa6e22b..f39a85f7 100644 --- a/direct/nn/jointicnet/jointicnet_engine.py +++ b/direct/nn/jointicnet/jointicnet_engine.py @@ -146,10 +146,12 @@ def l1_loss(source, reduction="mean", **data): """ Calculate L1 loss given source and target. - Parameters: - ----------- - Source: shape (batch, complex=2, height, width) - Data: Contains key "target" with value a tensor of shape (batch, height, width) + Parameters + ---------- + source: torch.Tensor + Has shape (batch, complex=2, height, width) + data: torch.Tensor + Contains key "target" with value a tensor of shape (batch, height, width) """ resolution = get_resolution(**data) @@ -161,10 +163,12 @@ def l2_loss(source, reduction="mean", **data): """ Calculate L2 loss (MSE) given source and target. 
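With the defaults documented above, a ``DIDN`` denoiser for 2-channel (real/imaginary) images can be sketched as follows; the import path is assumed from the file name::

    import torch

    from direct.nn.didn.didn import DIDN

    model = DIDN(
        in_channels=2,
        out_channels=2,
        hidden_channels=128,
        num_dubs=6,
        num_convs_recon=9,
        skip_connection=False,
    )

    x = torch.randn(1, 2, 256, 256)   # (N, C, H, W); channel_dim=1 as in forward()
    out = model(x)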
- Parameters: - ----------- - Source: shape (batch, complex=2, height, width) - Data: Contains key "target" with value a tensor of shape (batch, height, width) + Parameters + ---------- + source: torch.Tensor + Has shape (batch, complex=2, height, width) + data: torch.Tensor + Contains key "target" with value a tensor of shape (batch, height, width) """ resolution = get_resolution(**data) @@ -176,10 +180,12 @@ def ssim_loss(source, reduction="mean", **data): """ Calculate SSIM loss given source and target. - Parameters: - ----------- - Source: shape (batch, complex=2, height, width) - Data: Contains key "target" with value a tensor of shape (batch, height, width) + Parameters + ---------- + source: torch.Tensor + Has shape (batch, complex=2, height, width) + data: torch.Tensor + Contains key "target" with value a tensor of shape (batch, height, width) """ resolution = get_resolution(**data) @@ -225,11 +231,11 @@ def evaluate( Parameters ---------- - data_loader : DataLoader - loss_fns : Dict[str, Callable], optional - regularizer_fns : Dict[str, Callable], optional - crop : str, optional - is_validation_process : bool + data_loader: DataLoader + loss_fns: Dict[str, Callable], optional + regularizer_fns: Dict[str, Callable], optional + crop: str, optional + is_validation_process: bool Returns ------- @@ -429,10 +435,12 @@ def cropper(self, source, target, resolution): """ 2D source/target cropper - Parameters: - ----------- - Source has shape (batch, height, width) - Target has shape (batch, height, width) + Parameters + ---------- + source: torch.Tensor + Has shape (batch, height, width) + target: torch.Tensor + Has shape (batch, height, width) """ diff --git a/direct/nn/kikinet/kikinet.py b/direct/nn/kikinet/kikinet.py index 6198dbcc..5abf70c6 100644 --- a/direct/nn/kikinet/kikinet.py +++ b/direct/nn/kikinet/kikinet.py @@ -21,9 +21,7 @@ class KIKINet(nn.Module): References ---------- - .. [1] Eo, Taejoon, et al. “KIKI-Net: Cross-Domain Convolutional Neural Networks for Reconstructing - Undersampled Magnetic Resonance Images.” Magnetic Resonance in Medicine, vol. 80, no. 5, Nov. 2018, - pp. 2188–201. PubMed, https://doi.org/10.1002/mrm.27201. + .. [1] Eo, Taejoon, et al. “KIKI-Net: Cross-Domain Convolutional Neural Networks for Reconstructing Undersampled Magnetic Resonance Images.” Magnetic Resonance in Medicine, vol. 80, no. 5, Nov. 2018, pp. 2188–201. PubMed, https://doi.org/10.1002/mrm.27201. """ @@ -41,19 +39,19 @@ def __init__( Parameters ---------- - forward_operator : Callable + forward_operator: Callable Forward Operator. - backward_operator : Callable + backward_operator: Callable Backward Operator. - image_model_architecture : str + image_model_architecture: str Image model architecture. Currently only implemented for MWCNN and (NORM)UNET. Default: 'MWCNN'. - kspace_model_architecture : str + kspace_model_architecture: str Kspace model architecture. Currently only implemented for CONV and DIDN and (NORM)UNET. Default: 'DIDN'. - num_iter : int + num_iter: int Number of unrolled iterations. - normalize : bool + normalize: bool If true, input is normalised based on input scaling_factor. - kwargs : dict + kwargs: dict Keyword arguments for model architectures. """ super().__init__() @@ -135,18 +133,18 @@ def forward( Parameters ---------- - masked_kspace : torch.Tensor + masked_kspace: torch.Tensor Masked k-space of shape (N, coil, height, width, complex=2). - sampling_mask : torch.Tensor + sampling_mask: torch.Tensor Sampling mask of shape (N, 1, height, width, 1). 
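All of these engine losses share one calling convention: the prediction is passed positionally and the whole batch dictionary is splatted as keyword arguments, so the function can pick out ``target`` (and the keys consumed by ``get_resolution``) while ignoring the rest. A stand-in with the same signature, purely to illustrate the idiom (the body is hypothetical, not the engine's cropped/modulus loss)::

    import torch

    def l1_like_loss(source, reduction="mean", **data):
        target = data["target"]                       # other batch keys are simply ignored
        return torch.nn.functional.l1_loss(source.mean(dim=1), target, reduction=reduction)

    batch = {"target": torch.rand(4, 320, 320), "filename": ["a.h5"] * 4}
    prediction = torch.rand(4, 2, 320, 320)
    loss = l1_like_loss(prediction, **batch)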
- sensitivity_map : torch.Tensor + sensitivity_map: torch.Tensor Sensitivity map of shape (N, coil, height, width, complex=2). - scaling_factor : Optional[torch.Tensor] + scaling_factor: Optional[torch.Tensor] Scaling factor of shape (N,). If None, no scaling is applied. Default: None. Returns ------- - out_image : torch.Tensor + out_image: torch.Tensor Output image of shape (N, height, width, complex=2). """ diff --git a/direct/nn/kikinet/kikinet_engine.py b/direct/nn/kikinet/kikinet_engine.py index 7832cac6..60b96b9e 100644 --- a/direct/nn/kikinet/kikinet_engine.py +++ b/direct/nn/kikinet/kikinet_engine.py @@ -156,10 +156,12 @@ def l1_loss(source, reduction="mean", **data): """ Calculate L1 loss given source and target. - Parameters: - ----------- - Source: shape (batch, complex=2, height, width) - Data: Contains key "target" with value a tensor of shape (batch, height, width) + Parameters + ---------- + source: torch.Tensor + Has shape (batch, complex=2, height, width) + data: torch.Tensor + Contains key "target" with value a tensor of shape (batch, height, width) """ resolution = get_resolution(**data) @@ -171,10 +173,12 @@ def l2_loss(source, reduction="mean", **data): """ Calculate L2 loss (MSE) given source and target. - Parameters: - ----------- - Source: shape (batch, complex=2, height, width) - Data: Contains key "target" with value a tensor of shape (batch, height, width) + Parameters + ---------- + source: torch.Tensor + Has shape (batch, complex=2, height, width) + data: torch.Tensor + Contains key "target" with value a tensor of shape (batch, height, width) """ resolution = get_resolution(**data) @@ -186,10 +190,12 @@ def ssim_loss(source, reduction="mean", **data): """ Calculate SSIM loss given source and target. - Parameters: - ----------- - Source: shape (batch, complex=2, height, width) - Data: Contains key "target" with value a tensor of shape (batch, height, width) + Parameters + ---------- + source: torch.Tensor + Has shape (batch, complex=2, height, width) + data: torch.Tensor + Contains key "target" with value a tensor of shape (batch, height, width) """ resolution = get_resolution(**data) @@ -235,11 +241,11 @@ def evaluate( Parameters ---------- - data_loader : DataLoader - loss_fns : Dict[str, Callable], optional - regularizer_fns : Dict[str, Callable], optional - crop : str, optional - is_validation_process : bool + data_loader: DataLoader + loss_fns: Dict[str, Callable], optional + regularizer_fns: Dict[str, Callable], optional + crop: str, optional + is_validation_process: bool Returns ------- @@ -439,10 +445,12 @@ def cropper(self, source, target, resolution): """ 2D source/target cropper - Parameters: - ----------- - Source has shape (batch, height, width) - Target has shape (batch, height, width) + Parameters + ---------- + source: torch.Tensor + Has shape (batch, height, width) + target: torch.Tensor + Has shape (batch, height, width) """ diff --git a/direct/nn/lpd/lpd.py b/direct/nn/lpd/lpd.py index b4d33a25..d6bf227d 100644 --- a/direct/nn/lpd/lpd.py +++ b/direct/nn/lpd/lpd.py @@ -90,8 +90,7 @@ class LPDNet(nn.Module): References ---------- - .. [1] Adler, Jonas, and Ozan Öktem. “Learned Primal-Dual Reconstruction.” IEEE Transactions on Medical Imaging, - vol. 37, no. 6, June 2018, pp. 1322–32. arXiv.org, https://doi.org/10.1109/TMI.2018.2799231. + .. [1] Adler, Jonas, and Ozan Öktem. “Learned Primal-Dual Reconstruction.” IEEE Transactions on Medical Imaging, vol. 37, no. 6, June 2018, pp. 1322–32. arXiv.org, https://doi.org/10.1109/TMI.2018.2799231. 
""" @@ -110,21 +109,21 @@ def __init__( Parameters ---------- - forward_operator : Callable + forward_operator: Callable Forward Operator. - backward_operator : Callable + backward_operator: Callable Backward Operator. - num_iter : int + num_iter: int Number of unrolled iterations. - num_primal : int + num_primal: int Number of primal networks. - num_dual : int + num_dual: int Number of dual networks. - primal_model_architecture : str + primal_model_architecture: str Primal model architecture. Currently only implemented for MWCNN and (NORM)UNET. Default: 'MWCNN'. - dual_model_architecture : str + dual_model_architecture: str Dual model architecture. Currently only implemented for CONV and DIDN and (NORM)UNET. Default: 'DIDN'. - kwargs : dict + kwargs: dict Keyword arguments for model architectures. """ super().__init__() @@ -236,16 +235,16 @@ def forward( Parameters ---------- - masked_kspace : torch.Tensor + masked_kspace: torch.Tensor Masked k-space of shape (N, coil, height, width, complex=2). - sensitivity_map : torch.Tensor + sensitivity_map: torch.Tensor Sensitivity map of shape (N, coil, height, width, complex=2). - sampling_mask : torch.Tensor + sampling_mask: torch.Tensor Sampling mask of shape (N, 1, height, width, 1). Returns ------- - output : torch.Tensor + output: torch.Tensor Output image of shape (N, height, width, complex=2). """ input_image = self._backward_operator(masked_kspace, sampling_mask, sensitivity_map) diff --git a/direct/nn/lpd/lpd_engine.py b/direct/nn/lpd/lpd_engine.py index 56045308..fab14096 100644 --- a/direct/nn/lpd/lpd_engine.py +++ b/direct/nn/lpd/lpd_engine.py @@ -155,10 +155,12 @@ def l1_loss(source, reduction="mean", **data): """ Calculate L1 loss given source and target. - Parameters: - ----------- - Source: shape (batch, complex=2, height, width) - Data: Contains key "target" with value a tensor of shape (batch, height, width) + Parameters + ---------- + source: torch.Tensor + Has shape (batch, complex=2, height, width) + data: torch.Tensor + Contains key "target" with value a tensor of shape (batch, height, width) """ resolution = get_resolution(**data) @@ -170,10 +172,12 @@ def l2_loss(source, reduction="mean", **data): """ Calculate L2 loss (MSE) given source and target. - Parameters: - ----------- - Source: shape (batch, complex=2, height, width) - Data: Contains key "target" with value a tensor of shape (batch, height, width) + Parameters + ---------- + source: torch.Tensor + Has shape (batch, complex=2, height, width) + data: torch.Tensor + Contains key "target" with value a tensor of shape (batch, height, width) """ resolution = get_resolution(**data) @@ -185,10 +189,12 @@ def ssim_loss(source, reduction="mean", **data): """ Calculate SSIM loss given source and target. 
- Parameters: - ----------- - Source: shape (batch, complex=2, height, width) - Data: Contains key "target" with value a tensor of shape (batch, height, width) + Parameters + ---------- + source: torch.Tensor + Has shape (batch, complex=2, height, width) + data: torch.Tensor + Contains key "target" with value a tensor of shape (batch, height, width) """ resolution = get_resolution(**data) @@ -234,11 +240,11 @@ def evaluate( Parameters ---------- - data_loader : DataLoader - loss_fns : Dict[str, Callable], optional - regularizer_fns : Dict[str, Callable], optional - crop : str, optional - is_validation_process : bool + data_loader: DataLoader + loss_fns: Dict[str, Callable], optional + regularizer_fns: Dict[str, Callable], optional + crop: str, optional + is_validation_process: bool Returns ------- @@ -438,10 +444,12 @@ def cropper(self, source, target, resolution): """ 2D source/target cropper - Parameters: - ----------- - Source has shape (batch, height, width) - Target has shape (batch, height, width) + Parameters + ---------- + source: torch.Tensor + Has shape (batch, height, width) + target: torch.Tensor + Has shape (batch, height, width) """ diff --git a/direct/nn/mobilenet/mobilenet.py b/direct/nn/mobilenet/mobilenet.py index 44930e31..5cf5f991 100644 --- a/direct/nn/mobilenet/mobilenet.py +++ b/direct/nn/mobilenet/mobilenet.py @@ -3,7 +3,7 @@ # Taken and adapted from: https://raw.githubusercontent.com/pytorch/vision/master/torchvision/models/mobilenet.py -from typing import Any, Callable, Optional +from typing import Any, Callable from torch import nn @@ -105,19 +105,19 @@ def __init__( Parameters ---------- - num_channels : int + num_channels: int Number of channels. - num_classes : int + num_classes: int Number of classes. - width_mult : float + width_mult: float Width multiplier - adjusts number of channels in each layer by this amount. - inverted_residual_setting : Network structure - round_nearest : int + inverted_residual_setting: Network structure + round_nearest: int Round the number of channels in each layer to be a multiple of this number Set to 1 to turn off rounding - block : str + block: str Module specifying inverted residual building block for mobilenet. - norm_layer : str + norm_layer: str Module specifying the normalization layer to use. """ diff --git a/direct/nn/multidomainnet/config.py b/direct/nn/multidomainnet/config.py index c46ab526..4609aeb8 100644 --- a/direct/nn/multidomainnet/config.py +++ b/direct/nn/multidomainnet/config.py @@ -1,7 +1,6 @@ # coding=utf-8 # Copyright (c) DIRECT Contributors from dataclasses import dataclass -from typing import Optional, Tuple from direct.config.defaults import ModelConfig diff --git a/direct/nn/multidomainnet/multidomain.py b/direct/nn/multidomainnet/multidomain.py index 51505285..b914efb3 100644 --- a/direct/nn/multidomainnet/multidomain.py +++ b/direct/nn/multidomainnet/multidomain.py @@ -106,11 +106,11 @@ def __init__( Parameters ---------- - in_channels : int + in_channels: int Number of input channels. - out_channels : int + out_channels: int Number of output channels. - dropout_probability : float + dropout_probability: float Dropout probability. 
""" super().__init__() @@ -134,18 +134,18 @@ def __init__( nn.Dropout2d(dropout_probability), ) - def forward(self, input: torch.Tensor): + def forward(self, _input: torch.Tensor): """ Parameters ---------- - input : torch.Tensor + _input: torch.Tensor Returns ------- torch.Tensor """ - return self.layers(input) + return self.layers(_input) def __repr__(self): return ( @@ -164,9 +164,9 @@ def __init__(self, forward_operator, backward_operator, in_channels: int, out_ch """ Parameters ---------- - in_channels : int + in_channels: int Number of input channels. - out_channels : int + out_channels: int Number of output channels. """ super().__init__() @@ -185,7 +185,7 @@ def forward(self, input: torch.Tensor): Parameters ---------- - input : torch.Tensor + input: torch.Tensor Returns ------- @@ -216,19 +216,19 @@ def __init__( Parameters ---------- - forward_operator : Callable + forward_operator: Callable Forward Operator. - backward_operator : Callable + backward_operator: Callable Backward Operator. - in_channels : int + in_channels: int Number of input channels to the u-net. - out_channels : int + out_channels: int Number of output channels to the u-net. - num_filters : int + num_filters: int Number of output channels of the first convolutional layer. - num_pool_layers : int + num_pool_layers: int Number of down-sampling and up-sampling layers (depth). - dropout_probability : float + dropout_probability: float Dropout probability. """ super().__init__() @@ -272,7 +272,7 @@ def forward(self, input: torch.Tensor): Parameters ---------- - input : torch.Tensor + input: torch.Tensor Returns ------- diff --git a/direct/nn/multidomainnet/multidomainnet.py b/direct/nn/multidomainnet/multidomainnet.py index 1ebec51c..03f8ddb1 100644 --- a/direct/nn/multidomainnet/multidomainnet.py +++ b/direct/nn/multidomainnet/multidomainnet.py @@ -11,14 +11,15 @@ class StandardizationLayer(nn.Module): - """ + r""" Multi-channel data standardization method. Inspired by AIRS model submission to the Fast MRI 2020 challenge. - Given individual coil images :math: {x_i}_{i=1}^{N_c} and sensitivity coil maps :math: {S_i}_{i=1}^{N_c} + Given individual coil images :math:`\{x_i\}_{i=1}^{N_c}` and sensitivity coil maps :math:`\{S_i\}_{i=1}^{N_c}` it returns + .. math:: - {xres_i}_{i=1}^{N_c}, - where :math: xres_i = [x_{sense}, xi - S_i \times x_{sense}] - and :math: x_{sense} = \sum_{i=1}^{N_c} {S_i}^{*} \times x_i. + [x_{\text{sense}}, {x_{\text{res}}}_1, ..., {x_{\text{res}}}_{N_c}] + + where :math:`{x_{\text{res}}}_i = xi - S_i \times x_{\text{sense}}` and :math:`x_{\text{sense}} = \sum_{i=1}^{N_c} {S_i}^{*} \times x_i`. """ @@ -63,17 +64,17 @@ def __init__( Parameters ---------- - forward_operator : Callable + forward_operator: Callable Forward Operator. - backward_operator : Callable + backward_operator: Callable Backward Operator. - standardization : bool + standardization: bool If True standardization is used. Default: True. - num_filters : int + num_filters: int Number of filters for the MultiDomainUnet module. Default: 16. - num_pool_layers : int + num_pool_layers: int Number of pooling layers for the MultiDomainUnet module. Default: 4. - dropout_probability : float + dropout_probability: float Dropout probability for the MultiDomainUnet module. Default: 0.0. 
""" super().__init__() @@ -112,14 +113,14 @@ def forward(self, masked_kspace: torch.Tensor, sensitivity_map: torch.Tensor) -> Parameters ---------- - masked_kspace : torch.Tensor + masked_kspace: torch.Tensor Masked k-space of shape (N, coil, height, width, complex=2). - sensitivity_map : torch.Tensor + sensitivity_map: torch.Tensor Sensitivity map of shape (N, coil, height, width, complex=2). Returns ------- - output_image : torch.Tensor + output_image: torch.Tensor Multi-coil output image of shape (N, coil, height, width, complex=2). """ input_image = self.backward_operator(masked_kspace, dim=self._spatial_dims) diff --git a/direct/nn/multidomainnet/multidomainnet_engine.py b/direct/nn/multidomainnet/multidomainnet_engine.py index b00b69d1..caf115d8 100644 --- a/direct/nn/multidomainnet/multidomainnet_engine.py +++ b/direct/nn/multidomainnet/multidomainnet_engine.py @@ -158,10 +158,12 @@ def l1_loss(source, reduction="mean", **data): """ Calculate L1 loss given source and target. - Parameters: - ----------- - Source: shape (batch, complex=2, height, width) - Data: Contains key "target" with value a tensor of shape (batch, height, width) + Parameters + ---------- + source: torch.Tensor + Has shape (batch, complex=2, height, width) + data: torch.Tensor + Contains key "target" with value a tensor of shape (batch, height, width) """ resolution = get_resolution(**data) @@ -173,10 +175,12 @@ def l2_loss(source, reduction="mean", **data): """ Calculate L2 loss (MSE) given source and target. - Parameters: - ----------- - Source: shape (batch, complex=2, height, width) - Data: Contains key "target" with value a tensor of shape (batch, height, width) + Parameters + ---------- + source: torch.Tensor + Has shape (batch, complex=2, height, width) + data: torch.Tensor + Contains key "target" with value a tensor of shape (batch, height, width) """ resolution = get_resolution(**data) @@ -188,10 +192,12 @@ def ssim_loss(source, reduction="mean", **data): """ Calculate SSIM loss given source and target. - Parameters: - ----------- - Source: shape (batch, complex=2, height, width) - Data: Contains key "target" with value a tensor of shape (batch, height, width) + Parameters + ---------- + source: torch.Tensor + Has shape (batch, complex=2, height, width) + data: torch.Tensor + Contains key "target" with value a tensor of shape (batch, height, width) """ resolution = get_resolution(**data) @@ -237,11 +243,11 @@ def evaluate( Parameters ---------- - data_loader : DataLoader - loss_fns : Dict[str, Callable], optional - regularizer_fns : Dict[str, Callable], optional - crop : str, optional - is_validation_process : bool + data_loader: DataLoader + loss_fns: Dict[str, Callable], optional + regularizer_fns: Dict[str, Callable], optional + crop: str, optional + is_validation_process: bool Returns ------- @@ -441,10 +447,12 @@ def cropper(self, source, target, resolution): """ 2D source/target cropper - Parameters: - ----------- - Source has shape (batch, height, width) - Target has shape (batch, height, width) + Parameters + ---------- + source: torch.Tensor + Has shape (batch, height, width) + target: torch.Tensor + Has shape (batch, height, width) """ diff --git a/direct/nn/mwcnn/mwcnn.py b/direct/nn/mwcnn/mwcnn.py index cd800a93..1852be59 100644 --- a/direct/nn/mwcnn/mwcnn.py +++ b/direct/nn/mwcnn/mwcnn.py @@ -16,8 +16,7 @@ class DWT(nn.Module): References ---------- - .. [1] Liu, Pengju, et al. “Multi-Level Wavelet-CNN for Image Restoration.” ArXiv:1805.07071 [Cs], May 2018. 
- arXiv.org, http://arxiv.org/abs/1805.07071. + .. [1] Liu, Pengju, et al. “Multi-Level Wavelet-CNN for Image Restoration.” ArXiv:1805.07071 [Cs], May 2018. arXiv.org, http://arxiv.org/abs/1805.07071. """ @@ -47,8 +46,7 @@ class IWT(nn.Module): References ---------- - .. [1] Liu, Pengju, et al. “Multi-Level Wavelet-CNN for Image Restoration.” ArXiv:1805.07071 [Cs], May 2018. - arXiv.org, http://arxiv.org/abs/1805.07071. + .. [1] Liu, Pengju, et al. “Multi-Level Wavelet-CNN for Image Restoration.” ArXiv:1805.07071 [Cs], May 2018. arXiv.org, http://arxiv.org/abs/1805.07071. """ @@ -84,8 +82,7 @@ class ConvBlock(nn.Module): References ---------- - .. [1] Liu, Pengju, et al. “Multi-Level Wavelet-CNN for Image Restoration.” ArXiv:1805.07071 [Cs], May 2018. - arXiv.org, http://arxiv.org/abs/1805.07071. + .. [1] Liu, Pengju, et al. “Multi-Level Wavelet-CNN for Image Restoration.” ArXiv:1805.07071 [Cs], May 2018. arXiv.org, http://arxiv.org/abs/1805.07071. """ @@ -130,8 +127,7 @@ class DilatedConvBlock(nn.Module): References ---------- - .. [1] Liu, Pengju, et al. “Multi-Level Wavelet-CNN for Image Restoration.” ArXiv:1805.07071 [Cs], May 2018. - arXiv.org, http://arxiv.org/abs/1805.07071. + .. [1] Liu, Pengju, et al. “Multi-Level Wavelet-CNN for Image Restoration.” ArXiv:1805.07071 [Cs], May 2018. arXiv.org, http://arxiv.org/abs/1805.07071. """ @@ -192,8 +188,7 @@ class MWCNN(nn.Module): References ---------- - .. [1] Liu, Pengju, et al. “Multi-Level Wavelet-CNN for Image Restoration.” ArXiv:1805.07071 [Cs], May 2018. - arXiv.org, http://arxiv.org/abs/1805.07071. + .. [1] Liu, Pengju, et al. “Multi-Level Wavelet-CNN for Image Restoration.” ArXiv:1805.07071 [Cs], May 2018. arXiv.org, http://arxiv.org/abs/1805.07071. """ @@ -210,17 +205,17 @@ def __init__( Parameters ---------- - input_channels : int + input_channels: int Input channels dimension. - first_conv_hidden_channels : int + first_conv_hidden_channels: int First convolution output channels dimension. - num_scales : int + num_scales: int Number of scales. Default: 4. - bias : bool + bias: bool Convolution bias. If True, adds a learnable bias to the output. Default: True. - batchnorm : bool + batchnorm: bool If True, a batchnorm layer is added after each convolution. Default: False. - activation : nn.Module + activation: nn.Module Activation function applied after each convolution. Default: nn.ReLU(). """ super().__init__() @@ -329,9 +324,9 @@ def forward(self, input: torch.Tensor, res: bool = False) -> torch.Tensor: Parameters ---------- - input : torch.Tensor + input: torch.Tensor Input tensor. - res : bool + res: bool If True, residual connection is applied to the output. Default: False. Returns diff --git a/direct/nn/recurrent/recurrent.py b/direct/nn/recurrent/recurrent.py index eedf994d..73f1f27f 100644 --- a/direct/nn/recurrent/recurrent.py +++ b/direct/nn/recurrent/recurrent.py @@ -29,23 +29,23 @@ def __init__( Parameters ---------- - in_channels : int + in_channels: int Number of input channels. - hidden_channels : int + hidden_channels: int Number of hidden channels. - out_channels : Optional[int] + out_channels: Optional[int] Number of output channels. If None, same as in_channels. Default: None. - num_layers : int + num_layers: int Number of layers. Default: 2. - gru_kernel_size : int + gru_kernel_size: int Size of the GRU kernel. Default: 1. - orthogonal_initialization : bool + orthogonal_initialization: bool Orthogonal initialization is used if set to True. Default: True. 
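With the parameters documented above, an ``MWCNN`` instance for 2-channel images looks roughly as follows (import path assumed from the file name; the spatial size should be divisible by ``2 ** num_scales`` for the wavelet down/up stages)::

    import torch
    from torch import nn

    from direct.nn.mwcnn.mwcnn import MWCNN

    model = MWCNN(
        input_channels=2,
        first_conv_hidden_channels=32,
        num_scales=4,
        bias=True,
        batchnorm=False,
        activation=nn.ReLU(),
    )

    x = torch.randn(1, 2, 256, 256)
    out = model(x, res=True)          # res=True applies a residual connection to the output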
- instance_norm : bool + instance_norm: bool Instance norm is used if set to True. Default: False. - dense_connect : int + dense_connect: int Number of dense connections. - replication_padding : bool + replication_padding: bool If set to true replication padding is applied. """ super().__init__() @@ -118,9 +118,9 @@ def forward( Parameters ---------- - cell_input : torch.Tensor + cell_input: torch.Tensor Reconstruction input - previous_state : torch.Tensor + previous_state: torch.Tensor Tensor of previous states. Returns diff --git a/direct/nn/recurrentvarnet/recurrentvarnet.py b/direct/nn/recurrentvarnet/recurrentvarnet.py index 2c4eaead..f65a408b 100644 --- a/direct/nn/recurrentvarnet/recurrentvarnet.py +++ b/direct/nn/recurrentvarnet/recurrentvarnet.py @@ -15,15 +15,13 @@ class RecurrentInit(nn.Module): """ Recurrent State Initializer (RSI) module of Recurrent Variational Network as presented in [1]_. - The RSI module learns to initialize the recurrent hidden state h_0, input of the first RecurrentVarNet + The RSI module learns to initialize the recurrent hidden state :math:`h_0`, input of the first RecurrentVarNet Block of the RecurrentVarNet. References ---------- - .. [1] Yiasemis, George, et al. “Recurrent Variational Network: A Deep Learning Inverse Problem Solver - Applied to the Task of Accelerated MRI Reconstruction.” ArXiv:2111.09639 [Physics], Nov. 2021. - arXiv.org, http://arxiv.org/abs/2111.09639. + .. [1] Yiasemis, George, et al. “Recurrent Variational Network: A Deep Learning Inverse Problem Solver Applied to the Task of Accelerated MRI Reconstruction.” ArXiv:2111.09639 [Physics], Nov. 2021. arXiv.org, http://arxiv.org/abs/2111.09639. """ @@ -40,17 +38,17 @@ def __init__( Parameters ---------- - in_channels : int + in_channels: int Input channels. - out_channels : int + out_channels: int Number of hidden channels of the recurrent unit of RecurrentVarNet Block. - channels : tuple + channels: tuple Channels :math:`n_d` in the convolutional layers of initializer. dilations: tuple Dilations :math:`p` of the convolutional layers of the initializer. - depth : int + depth: int RecurrentVarNet Block number of layers :math:`n_l`. - multiscale_depth : 1 + multiscale_depth: 1 Number of feature layers to aggregate for the output, if 1, multi-scale context aggregation is disabled. """ @@ -97,9 +95,7 @@ class RecurrentVarNet(nn.Module): References ---------- - .. [1] Yiasemis, George, et al. “Recurrent Variational Network: A Deep Learning Inverse Problem Solver - Applied to the Task of Accelerated MRI Reconstruction.” ArXiv:2111.09639 [Physics], Nov. 2021. - arXiv.org, http://arxiv.org/abs/2111.09639. + .. [1] Yiasemis, George, et al. “Recurrent Variational Network: A Deep Learning Inverse Problem Solver Applied to the Task of Accelerated MRI Reconstruction.” ArXiv:2111.09639 [Physics], Nov. 2021. arXiv.org, http://arxiv.org/abs/2111.09639. """ @@ -123,30 +119,30 @@ def __init__( Parameters ---------- - forward_operator : Callable + forward_operator: Callable Forward Operator. - backward_operator : Callable + backward_operator: Callable Backward Operator. - num_steps : int + num_steps: int Number of iterations :math:`T`. - in_channels : int + in_channels: int Input channel number. Default is 2 for complex data. - recurrent_hidden_channels : int + recurrent_hidden_channels: int Hidden channels number for the recurrent unit of the RecurrentVarNet Blocks. Default: 64. 
- recurrent_num_layers : int + recurrent_num_layers: int Number of layers for the recurrent unit of the RecurrentVarNet Block (:math:`n_l`). Default: 4. - no_parameter_sharing : bool + no_parameter_sharing: bool If False, the same RecurrentVarNet Block is used for all num_steps. Default: True. - learned_initializer : bool + learned_initializer: bool If True an RSI module is used. Default: False. - initializer_initialization : str, Optional + initializer_initialization: str, Optional Type of initialization for the RSI module. Can be either 'sense', 'zero-filled' or 'input-image'. Default: None. - initializer_channels : tuple + initializer_channels: tuple Channels :math:`n_d` in the convolutional layers of the RSI module. Default: (32, 32, 64, 64). - initializer_dilations : tuple + initializer_dilations: tuple Dilations :math:`p` of the convolutional layers of the RSI module. Default: (1, 1, 2, 4). - initializer_multiscale : int + initializer_multiscale: int RSI module number of feature layers to aggregate for the output, if 1, multi-scale context aggregation is disabled. Default: 1. @@ -224,11 +220,11 @@ def forward( """ Parameters ---------- - masked_kspace : torch.Tensor + masked_kspace: torch.Tensor Masked k-space of shape (N, coil, height, width, complex=2). - sampling_mask : torch.Tensor + sampling_mask: torch.Tensor Sampling mask of shape (N, 1, height, width, 1). - sensitivity_map : torch.Tensor + sensitivity_map: torch.Tensor Coil sensitivities of shape (N, coil, height, width, complex=2). Returns @@ -280,15 +276,13 @@ def forward( class RecurrentVarNetBlock(nn.Module): - """ - Recurrent Variational Network Block as presented in [1]_. + r""" + Recurrent Variational Network Block :math:`\mathcal{H}_{\theta_{t}}` as presented in [1]_. References ---------- - .. [1] Yiasemis, George, et al. “Recurrent Variational Network: A Deep Learning Inverse Problem Solver - Applied to the Task of Accelerated MRI Reconstruction.” ArXiv:2111.09639 [Physics], Nov. 2021. - arXiv.org, http://arxiv.org/abs/2111.09639. + .. [1] Yiasemis, George, et al. “Recurrent Variational Network: A Deep Learning Inverse Problem Solver Applied to the Task of Accelerated MRI Reconstruction.” ArXiv:2111.09639 [Physics], Nov. 2021. arXiv.org, http://arxiv.org/abs/2111.09639. """ @@ -301,8 +295,8 @@ def __init__( num_layers: int = 4, ): """ - Parameters: - ----------- + Parameters + ---------- forward_operator: Callable Forward Fourier Transform. backward_operator: Callable @@ -343,11 +337,11 @@ def forward( ---------- current_kspace: torch.Tensor Current k-space prediction of shape (N, coil, height, width, complex=2). - masked_kspace : torch.Tensor + masked_kspace: torch.Tensor Masked k-space of shape (N, coil, height, width, complex=2). - sampling_mask : torch.Tensor + sampling_mask: torch.Tensor Sampling mask of shape (N, 1, height, width, 1). - sensitivity_map : torch.Tensor + sensitivity_map: torch.Tensor Coil sensitivities of shape (N, coil, height, width, complex=2). hidden_state: torch.Tensor or None ConvGRU hidden state of shape (N, hidden_channels, height, width, num_layers) if not None. Optional. diff --git a/direct/nn/recurrentvarnet/recurrentvarnet_engine.py b/direct/nn/recurrentvarnet/recurrentvarnet_engine.py index 57ea3668..0fd2e487 100644 --- a/direct/nn/recurrentvarnet/recurrentvarnet_engine.py +++ b/direct/nn/recurrentvarnet/recurrentvarnet_engine.py @@ -159,10 +159,12 @@ def l1_loss(source, reduction="mean", **data): """ Calculate L1 loss given source and target. 
- Parameters: - ----------- - Source: shape (batch, complex=2, height, width) - Data: Contains key "target" with value a tensor of shape (batch, height, width) + Parameters + ---------- + source: torch.Tensor + Has shape (batch, complex=2, height, width) + Data: torch.Tensor + Contains key "target" with value a tensor of shape (batch, height, width) """ resolution = get_resolution(**data) @@ -174,10 +176,12 @@ def l2_loss(source, reduction="mean", **data): """ Calculate L2 loss (MSE) given source and target. - Parameters: - ----------- - Source: shape (batch, complex=2, height, width) - Data: Contains key "target" with value a tensor of shape (batch, height, width) + Parameters + ---------- + source: torch.Tensor + Has shape (batch, complex=2, height, width) + Data: torch.Tensor + Contains key "target" with value a tensor of shape (batch, height, width) """ resolution = get_resolution(**data) @@ -189,10 +193,12 @@ def ssim_loss(source, reduction="mean", **data): """ Calculate SSIM loss given source and target. - Parameters: - ----------- - Source: shape (batch, complex=2, height, width) - Data: Contains key "target" with value a tensor of shape (batch, height, width) + Parameters + ---------- + source: torch.Tensor + Has shape (batch, complex=2, height, width) + Data: torch.Tensor + Contains key "target" with value a tensor of shape (batch, height, width) """ resolution = get_resolution(**data) @@ -238,11 +244,11 @@ def evaluate( Parameters ---------- - data_loader : DataLoader - loss_fns : Dict[str, Callable], optional - regularizer_fns : Dict[str, Callable], optional - crop : str, optional - is_validation_process : bool + data_loader: DataLoader + loss_fns: Dict[str, Callable], optional + regularizer_fns: Dict[str, Callable], optional + crop: str, optional + is_validation_process: bool Returns ------- @@ -445,10 +451,12 @@ def cropper(self, source, target, resolution): """ 2D source/target cropper - Parameters: - ----------- - Source has shape (batch, height, width) - Target has shape (batch, height, width) + Parameters + ---------- + source: torch.Tensor + Has shape (batch, height, width) + target: torch.Tensor + Has shape (batch, height, width) """ diff --git a/direct/nn/rim/rim.py b/direct/nn/rim/rim.py index a6edd16a..9d9ef3b1 100644 --- a/direct/nn/rim/rim.py +++ b/direct/nn/rim/rim.py @@ -15,12 +15,13 @@ class MRILogLikelihood(nn.Module): - """ + r""" Defines the MRI loglikelihood assuming one noise vector for the complex images for all coils. + .. math:: - \frac{1}{\sigma^2} \sum_{i}^{\text{num coils}} - {S}_i^\{text{H}} \mathcal{F}^{-1} P^T (P \mathcal{F} S_i x_\tau - y_\tau) - for each time step :math: \tau. + \frac{1}{\sigma^2} \sum_{i}^{N_c} {S}_i^{\text{H}} \mathcal{F}^{-1} P^{*} (P \mathcal{F} S_i x_{\tau} - y_{\tau}) + + for each time step :math:`\tau`. """ def __init__( @@ -48,15 +49,15 @@ def forward( Parameters ---------- - input_image : torch.tensor + input_image: torch.tensor Initial or previous iteration of image with complex first of shape (N, complex, height, width). - masked_kspace : torch.tensor + masked_kspace: torch.tensor Masked k-space of shape (N, coil, height, width, complex). - sensitivity_map : torch.tensor + sensitivity_map: torch.tensor Sensitivity Map of shape (N, coil, height, width, complex). - sampling_mask : torch.tensor - loglikelihood_scaling : torch.tensor + sampling_mask: torch.tensor + loglikelihood_scaling: torch.tensor Multiplier for loglikelihood, for instance for the k-space noise, of shape (1,). 
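A rough numerical reading of the corrected log-likelihood term above, with hypothetical orthonormal FFT stand-ins for the forward/backward operators (the operators and normalisation actually used by ``MRILogLikelihood`` are configured elsewhere, so this only sketches the math with :math:`\sigma^2 = 1`)::

    import torch

    from direct.data.transforms import complex_multiplication, conjugate

    def fft2c(x):   # hypothetical forward operator: orthonormal 2D FFT on a (..., h, w, 2) tensor
        return torch.view_as_real(torch.fft.fft2(torch.view_as_complex(x.contiguous()), norm="ortho"))

    def ifft2c(x):  # hypothetical backward operator
        return torch.view_as_real(torch.fft.ifft2(torch.view_as_complex(x.contiguous()), norm="ortho"))

    sensitivity_map = torch.randn(8, 320, 320, 2)                    # S_i
    sampling_mask = (torch.rand(320, 320) < 0.25)[None, ..., None]   # P, broadcast over coil/complex
    x = torch.randn(320, 320, 2)                                     # current estimate x_tau
    y = torch.randn(8, 320, 320, 2) * sampling_mask                  # measured masked k-space y_tau

    coil_images = complex_multiplication(sensitivity_map, x.expand_as(sensitivity_map))
    residual = sampling_mask * fft2c(coil_images) - y
    grad = complex_multiplication(conjugate(sensitivity_map), ifft2c(residual)).sum(0)  # (320, 320, 2)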
Returns @@ -116,8 +117,7 @@ class RIMInit(nn.Module): References ---------- - .. [1] Yu, Fisher, and Vladlen Koltun. “Multi-Scale Context Aggregation by Dilated Convolutions.” - ArXiv:1511.07122 [Cs], Apr. 2016. arXiv.org, http://arxiv.org/abs/1511.07122. + .. [1] Yu, Fisher, and Vladlen Koltun. “Multi-Scale Context Aggregation by Dilated Convolutions.” ArXiv:1511.07122 [Cs], Apr. 2016. arXiv.org, http://arxiv.org/abs/1511.07122. """ def __init__( @@ -133,17 +133,17 @@ def __init__( Parameters ---------- - x_ch : int + x_ch: int Input channels. - out_ch : int + out_ch: int Number of hidden channels in the RIM. - channels : tuple + channels: tuple Channels in the convolutional layers of initializer. Typical it could be e.g. (32, 32, 64, 64). dilations: tuple Dilations of the convolutional layers of the initializer. Typically it could be e.g. (1, 1, 2, 4). - depth : int + depth: int RIM depth - multiscale_depth : 1 + multiscale_depth: 1 Number of feature layers to aggregate for the output, if 1, multi-scale context aggregation is disabled. """ @@ -190,8 +190,7 @@ class RIM(nn.Module): References ---------- - .. [1] Putzky, Patrick, and Max Welling. “Recurrent Inference Machines for Solving Inverse Problems.” - ArXiv:1706.04008 [Cs], June 2017. arXiv.org, http://arxiv.org/abs/1706.04008. + .. [1] Putzky, Patrick, and Max Welling. “Recurrent Inference Machines for Solving Inverse Problems.” ArXiv:1706.04008 [Cs], June 2017. arXiv.org, http://arxiv.org/abs/1706.04008. """ @@ -305,16 +304,16 @@ def forward( """ Parameters ---------- - input_image : torch.Tensor + input_image: torch.Tensor Initial or intermediate guess of input. Has shape (N, height, width, complex=2). - masked_kspace : torch.Tensor + masked_kspace: torch.Tensor Masked k-space of shape (N, coil, height, width, complex=2). - sensitivity_map : torch.Tensor + sensitivity_map: torch.Tensor Sensitivity map of shape (N, coil, height, width, complex=2). - sampling_mask : torch.Tensor + sampling_mask: torch.Tensor Sampling mask of shape (N, 1, height, width, 1). - previous_state : torch.Tensor - loglikelihood_scaling : torch.Tensor + previous_state: torch.Tensor + loglikelihood_scaling: torch.Tensor Float tensor of shape (1,). Returns diff --git a/direct/nn/rim/rim_engine.py b/direct/nn/rim/rim_engine.py index 438da1ba..f8701471 100644 --- a/direct/nn/rim/rim_engine.py +++ b/direct/nn/rim/rim_engine.py @@ -202,10 +202,10 @@ def l1_loss(source, reduction="mean", **data): """ Calculate L1 loss given source and target. - Parameters: - ----------- - Source: shape (batch, complex=2, height, width) - Data: Contains key "target" with value a tensor of shape (batch, height, width) + Parameters + ---------- + Source: shape (batch, complex=2, height, width) + Data: Contains key "target" with value a tensor of shape (batch, height, width) """ resolution = get_resolution(**data) @@ -217,10 +217,10 @@ def ssim_loss(source, reduction="mean", **data): """ Calculate SSIM loss given source and target. 
- Parameters: - ----------- - Source: shape (batch, complex=2, height, width) - Data: Contains key "target" with value a tensor of shape (batch, height, width) + Parameters + ---------- + Source: shape (batch, complex=2, height, width) + Data: Contains key "target" with value a tensor of shape (batch, height, width) """ resolution = get_resolution(**data) @@ -264,11 +264,11 @@ def evaluate( Parameters ---------- - data_loader : DataLoader - loss_fns : Dict[str, Callable], optional - regularizer_fns : Dict[str, Callable], optional - crop : str, optional - is_validation_process : bool + data_loader: DataLoader + loss_fns: Dict[str, Callable], optional + regularizer_fns: Dict[str, Callable], optional + crop: str, optional + is_validation_process: bool Returns ------- @@ -479,10 +479,12 @@ def cropper(self, source, target, resolution): """ 2D source/target cropper - Parameters: - ----------- - Source has shape (batch, complex=2, height, width) - Target has shape (batch, height, width) + Parameters + ---------- + source: torch.Tensor + Has shape (batch, complex=2, height, width) + target: torch.Tensor + Has shape (batch, height, width) """ source_abs = T.modulus(source) # shape (batch, height, width) @@ -509,86 +511,3 @@ def compute_model_per_coil(self, model_name, data): # output is of shape (batch, coil, complex=2, [slice], height, width) return output - - -class RIM3dEngine(RIMEngine): - """ - Recurrent Inference Machine Engine for 3D data. - """ - - def __init__( - self, - cfg: BaseConfig, - model: nn.Module, - device: int, - forward_operator: Optional[Callable] = None, - backward_operator: Optional[Callable] = None, - mixed_precision: bool = False, - **models: nn.Module, - ): - super().__init__( - cfg, - model, - device, - forward_operator=forward_operator, - backward_operator=backward_operator, - mixed_precision=mixed_precision, - **models, - ) - self._slice_dim = -3 - - def process_output(self, data, scaling_factors=None, resolution=None): - # Data has shape (batch, complex, slice, height, width) - # TODO(gy): verify shape - - self._slice_dim = -3 - center_slice = data.size(self._slice_dim) // 2 - - if scaling_factors is not None: - data = data * scaling_factors.view(-1, *((1,) * (len(data.shape) - 1))).to(data.device) - - data = T.modulus_if_complex(data).select(self._slice_dim, center_slice) - - if len(data.shape) == 3: # (batch, height, width) - data = data.unsqueeze(1) # Added channel dimension. - - if resolution is not None: - data = T.center_crop(data, resolution).contiguous() - - return data - - def cropper(self, source, target, resolution=(320, 320)): - """ - 2D source/target cropper - - Parameters: - ----------- - Source has shape (batch, complex=2, slice, height, width) - Target has shape (batch, slice, height, width) - - """ - # TODO(gy): Verify target shape - self._slice_dim = -3 - - # TODO(gy): Why is this set to True and then have an if statement? - # TODO(jt): Because it might be the case we do it differently in say 3D. 
Just a placeholder really - use_center_slice = True - if use_center_slice: - # Source and target have a different number of slices when trimming in depth - source = source.select( - self._slice_dim, source.size(self._slice_dim) // 2 - ) # shape (batch, complex=2, height, width) - target = target.select(self._slice_dim, target.size(self._slice_dim) // 2).unsqueeze( - 1 - ) # shape (batch, complex=1, height, width) - else: - raise NotImplementedError("Only center slice cropping supported.") - - source_abs = T.modulus(source) # shape (batch, height, width) - - if not resolution or all(_ == 0 for _ in resolution): - return source_abs.unsqueeze(1), target - - source_abs = T.center_crop(source_abs, resolution).unsqueeze(1) - target_abs = T.center_crop(target, resolution) - return source_abs, target_abs diff --git a/direct/nn/unet/unet_2d.py b/direct/nn/unet/unet_2d.py index b54b6603..a96dae85 100644 --- a/direct/nn/unet/unet_2d.py +++ b/direct/nn/unet/unet_2d.py @@ -23,11 +23,11 @@ def __init__(self, in_channels: int, out_channels: int, dropout_probability: flo Parameters ---------- - in_channels : int + in_channels: int Number of input channels. - out_channels : int + out_channels: int Number of output channels. - dropout_probability : float + dropout_probability: float Dropout probability. """ super().__init__() @@ -52,7 +52,7 @@ def forward(self, input: torch.Tensor): Parameters ---------- - input : torch.Tensor + input: torch.Tensor Returns ------- @@ -77,9 +77,9 @@ def __init__(self, in_channels: int, out_channels: int): """ Parameters ---------- - in_channels : int + in_channels: int Number of input channels. - out_channels : int + out_channels: int Number of output channels. """ super().__init__() @@ -98,7 +98,7 @@ def forward(self, input: torch.Tensor): Parameters ---------- - input : torch.Tensor + input: torch.Tensor Returns ------- @@ -117,9 +117,7 @@ class UnetModel2d(nn.Module): References ---------- - .. [1] Ronneberger, Olaf, et al. “U-Net: Convolutional Networks for Biomedical Image Segmentation.” - Medical Image Computing and Computer-Assisted Intervention – MICCAI 2015, edited by Nassir Navab et al., - Springer International Publishing, 2015, pp. 234–41. Springer Link, https://doi.org/10.1007/978-3-319-24574-4_28. + .. [1] Ronneberger, Olaf, et al. “U-Net: Convolutional Networks for Biomedical Image Segmentation.” Medical Image Computing and Computer-Assisted Intervention – MICCAI 2015, edited by Nassir Navab et al., Springer International Publishing, 2015, pp. 234–41. Springer Link, https://doi.org/10.1007/978-3-319-24574-4_28. """ def __init__( @@ -134,15 +132,15 @@ def __init__( Parameters ---------- - in_channels : int + in_channels: int Number of input channels to the u-net. - out_channels : int + out_channels: int Number of output channels to the u-net. - num_filters : int + num_filters: int Number of output channels of the first convolutional layer. - num_pool_layers : int + num_pool_layers: int Number of down-sampling and up-sampling layers (depth). - dropout_probability : float + dropout_probability: float Dropout probability. """ super().__init__() @@ -180,7 +178,7 @@ def forward(self, input: torch.Tensor): Parameters ---------- - input : torch.Tensor + input: torch.Tensor Returns ------- @@ -235,17 +233,17 @@ def __init__( Parameters ---------- - in_channels : int + in_channels: int Number of input channels to the u-net. - out_channels : int + out_channels: int Number of output channels to the u-net. 
- num_filters : int + num_filters: int Number of output channels of the first convolutional layer. - num_pool_layers : int + num_pool_layers: int Number of down-sampling and up-sampling layers (depth). - dropout_probability : float + dropout_probability: float Dropout probability. - norm_groups : int, + norm_groups: int, Number of normalization groups. """ super().__init__() @@ -307,7 +305,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: Parameters ---------- - input : torch.Tensor + input: torch.Tensor Returns ------- @@ -345,21 +343,21 @@ def __init__( Parameters ---------- - forward_operator : Callable + forward_operator: Callable Forward Operator. - backward_operator : Callable + backward_operator: Callable Backward Operator. - num_filters : int + num_filters: int Number of first layer filters. - num_pool_layers : int + num_pool_layers: int Number of pooling layers. - dropout_probability : float + dropout_probability: float Dropout probability. - skip_connection : bool + skip_connection: bool If True, skip connection is used for the output. Default: False. - normalized : bool + normalized: bool If True, Normalized Unet is used. Default: False. - image_initialization : str + image_initialization: str Type of image initialization. Default: "zero-filled". kwargs: dict """ @@ -411,9 +409,9 @@ def forward( Parameters ---------- - masked_kspace : torch.Tensor + masked_kspace: torch.Tensor Masked k-space of shape (N, coil, height, width, complex=2). - sensitivity_map : torch.Tensor + sensitivity_map: torch.Tensor Sensitivity map of shape (N, coil, height, width, complex=2). Default: None. Returns diff --git a/direct/nn/unet/unet_engine.py b/direct/nn/unet/unet_engine.py index 39bb08ba..3502f9f5 100644 --- a/direct/nn/unet/unet_engine.py +++ b/direct/nn/unet/unet_engine.py @@ -155,10 +155,12 @@ def l1_loss(source, reduction="mean", **data): """ Calculate L1 loss given source and target. - Parameters: - ----------- - Source: shape (batch, complex=2, height, width) - Data: Contains key "target" with value a tensor of shape (batch, height, width) + Parameters + ---------- + source: torch.Tensor + Has shape (batch, complex=2, height, width) + data: torch.Tensor + Contains key "target" with value a tensor of shape (batch, height, width) """ resolution = get_resolution(**data) @@ -170,10 +172,12 @@ def l2_loss(source, reduction="mean", **data): """ Calculate L2 loss (MSE) given source and target. - Parameters: - ----------- - Source: shape (batch, complex=2, height, width) - Data: Contains key "target" with value a tensor of shape (batch, height, width) + Parameters + ---------- + source: torch.Tensor + Has shape (batch, complex=2, height, width) + data: torch.Tensor + Contains key "target" with value a tensor of shape (batch, height, width) """ resolution = get_resolution(**data) @@ -185,7 +189,7 @@ def ssim_loss(source, reduction="mean", **data): """ Calculate SSIM loss given source and target. 
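The ``UnetModel2d`` parameters documented in the ``unet_2d.py`` hunks above describe a plain image-to-image module. A hypothetical usage sketch follows; the channel counts, filter count and the 320 by 320 input size are illustrative assumptions, as is the expectation that the output keeps the input's spatial shape (which holds when height and width are divisible by ``2**num_pool_layers``).

.. code-block:: python

    import torch

    from direct.nn.unet.unet_2d import UnetModel2d

    # Illustrative settings: two input/output channels (e.g. real and imaginary part as channels).
    model = UnetModel2d(
        in_channels=2,
        out_channels=2,
        num_filters=16,
        num_pool_layers=4,
        dropout_probability=0.0,
    )

    image = torch.randn(1, 2, 320, 320)  # (batch, in_channels, height, width)
    output = model(image)                # expected shape: (1, 2, 320, 320)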
- Parameters: + Parameters ----------- Source: shape (batch, complex=2, height, width) Data: Contains key "target" with value a tensor of shape (batch, height, width) @@ -234,11 +238,11 @@ def evaluate( Parameters ---------- - data_loader : DataLoader - loss_fns : Dict[str, Callable], optional - regularizer_fns : Dict[str, Callable], optional - crop : str, optional - is_validation_process : bool + data_loader: DataLoader + loss_fns: Dict[str, Callable], optional + regularizer_fns: Dict[str, Callable], optional + crop: str, optional + is_validation_process: bool Returns ------- @@ -438,10 +442,12 @@ def cropper(self, source, target, resolution): """ 2D source/target cropper - Parameters: - ----------- - Source has shape (batch, height, width) - Target has shape (batch, height, width) + Parameters + ---------- + source: torch.Tensor + Has shape (batch, height, width). + target: torch.Tensor + Has has shape (batch, height, width). """ diff --git a/direct/nn/varnet/varnet.py b/direct/nn/varnet/varnet.py index 0956fc6a..19ebdc10 100644 --- a/direct/nn/varnet/varnet.py +++ b/direct/nn/varnet/varnet.py @@ -17,8 +17,7 @@ class EndToEndVarNet(nn.Module): References ---------- - .. [1] Sriram, Anuroop, et al. “End-to-End Variational Networks for Accelerated MRI Reconstruction.” - ArXiv:2004.06688 [Cs, Eess], Apr. 2020. arXiv.org, http://arxiv.org/abs/2004.06688. + .. [1] Sriram, Anuroop, et al. “End-to-End Variational Networks for Accelerated MRI Reconstruction.” ArXiv:2004.06688 [Cs, Eess], Apr. 2020. arXiv.org, http://arxiv.org/abs/2004.06688. """ @@ -34,19 +33,19 @@ def __init__( **kwargs, ): """ - Parameters: - ----------- - forward_operator : Callable + Parameters + ---------- + forward_operator: Callable Forward Operator. - backward_operator : Callable + backward_operator: Callable Backward Operator. - num_layers : int + num_layers: int Number of cascades. - regularizer_num_filters : int + regularizer_num_filters: int Regularizer model number of filters. - regularizer_num_pull_layers : int + regularizer_num_pull_layers: int Regularizer model number of pulling layers. - regularizer_dropout : float + regularizer_dropout: float Regularizer model dropout probability. """ @@ -81,16 +80,16 @@ def forward( """ Parameters ---------- - masked_kspace : torch.Tensor + masked_kspace: torch.Tensor Masked k-space of shape (N, coil, height, width, complex=2). - sampling_mask : torch.Tensor + sampling_mask: torch.Tensor Sampling mask of shape (N, 1, height, width, 1). - sensitivity_map : torch.Tensor + sensitivity_map: torch.Tensor Sensitivity map of shape (N, coil, height, width, complex=2). Returns ------- - kspace_prediction : torch.Tensor + kspace_prediction: torch.Tensor K-space prediction of shape (N, coil, height, width, complex=2). """ @@ -113,13 +112,13 @@ def __init__( ): """ - Parameters: - ----------- - forward_operator : Callable + Parameters + ---------- + forward_operator: Callable Forward Operator. - backward_operator : Callable + backward_operator: Callable Backward Operator. - regularizer_model : nn.Module + regularizer_model: nn.Module Regularizer model. """ super().__init__() @@ -142,13 +141,13 @@ def forward( Parameters ---------- - current_kspace : torch.Tensor + current_kspace: torch.Tensor Current k-space prediction of shape (N, coil, height, width, complex=2). - masked_kspace : torch.Tensor + masked_kspace: torch.Tensor Masked k-space of shape (N, coil, height, width, complex=2). 
- sampling_mask : torch.Tensor + sampling_mask: torch.Tensor Sampling mask of shape (N, 1, height, width, 1). - sensitivity_map : torch.Tensor + sensitivity_map: torch.Tensor Sensitivity map of shape (N, coil, height, width, complex=2). Returns diff --git a/direct/nn/varnet/varnet_engine.py b/direct/nn/varnet/varnet_engine.py index c595d651..7ae7e573 100644 --- a/direct/nn/varnet/varnet_engine.py +++ b/direct/nn/varnet/varnet_engine.py @@ -159,8 +159,8 @@ def l1_loss(source, reduction="mean", **data): """ Calculate L1 loss given source and target. - Parameters: - ----------- + Parameters + ---------- Source: shape (batch, complex=2, height, width) Data: Contains key "target" with value a tensor of shape (batch, height, width) @@ -174,8 +174,8 @@ def l2_loss(source, reduction="mean", **data): """ Calculate L2 loss (MSE) given source and target. - Parameters: - ----------- + Parameters + ---------- Source: shape (batch, complex=2, height, width) Data: Contains key "target" with value a tensor of shape (batch, height, width) @@ -189,8 +189,8 @@ def ssim_loss(source, reduction="mean", **data): """ Calculate SSIM loss given source and target. - Parameters: - ----------- + Parameters + ---------- Source: shape (batch, complex=2, height, width) Data: Contains key "target" with value a tensor of shape (batch, height, width) @@ -238,11 +238,11 @@ def evaluate( Parameters ---------- - data_loader : DataLoader - loss_fns : Dict[str, Callable], optional - regularizer_fns : Dict[str, Callable], optional - crop : str, optional - is_validation_process : bool + data_loader: DataLoader + loss_fns: Dict[str, Callable], optional + regularizer_fns: Dict[str, Callable], optional + crop: str, optional + is_validation_process: bool Returns ------- @@ -442,10 +442,12 @@ def cropper(self, source, target, resolution): """ 2D source/target cropper - Parameters: - ----------- - Source has shape (batch, height, width) - Target has shape (batch, height, width) + Parameters + ---------- + source: torch.Tensor + Has shape (batch, height, width) + target: torch.Tensor + Has shape (batch, height, width) """ diff --git a/direct/nn/xpdnet/xpdnet.py b/direct/nn/xpdnet/xpdnet.py index e9221dd5..0927870e 100644 --- a/direct/nn/xpdnet/xpdnet.py +++ b/direct/nn/xpdnet/xpdnet.py @@ -19,8 +19,7 @@ class XPDNet(CrossDomainNetwork): References ---------- - .. [1] Ramzi, Zaccharie, et al. “XPDNet for MRI Reconstruction: An Application to the 2020 FastMRI Challenge.” - ArXiv:2010.07290 [Physics, Stat], July 2021. arXiv.org, http://arxiv.org/abs/2010.07290. + .. [1] Ramzi, Zaccharie, et al. “XPDNet for MRI Reconstruction: An Application to the 2020 FastMRI Challenge.” ArXiv:2010.07290 [Physics, Stat], July 2021. arXiv.org, http://arxiv.org/abs/2010.07290. """ @@ -41,25 +40,25 @@ def __init__( Parameters ---------- - forward_operator : Callable + forward_operator: Callable Forward Operator. - backward_operator : Callable + backward_operator: Callable Backward Operator. - num_primal : int + num_primal: int Number of primal networks. - num_dual : int + num_dual: int Number of dual networks. - num_iter : int + num_iter: int Number of unrolled iterations. - use_primal_only : bool + use_primal_only: bool If set to True no dual-kspace model is used. Default: True. - image_model_architecture : str + image_model_architecture: str Primal-image model architecture. Currently only implemented for MWCNN. Default: 'MWCNN'. - kspace_model_architecture : str + kspace_model_architecture: str Dual-kspace model architecture. 
Currently only implemented for CONV and DIDN. - normalize : bool + normalize: bool Normalize input. Default: False. - kwargs : dict + kwargs: dict Keyword arguments for model architectures. """ if use_primal_only: diff --git a/direct/nn/xpdnet/xpdnet_engine.py b/direct/nn/xpdnet/xpdnet_engine.py index 7a24f30b..aec64650 100644 --- a/direct/nn/xpdnet/xpdnet_engine.py +++ b/direct/nn/xpdnet/xpdnet_engine.py @@ -158,10 +158,12 @@ def l1_loss(source, reduction="mean", **data): """ Calculate L1 loss given source and target. - Parameters: - ----------- - Source: shape (batch, complex=2, height, width) - Data: Contains key "target" with value a tensor of shape (batch, height, width) + Parameters + ---------- + source: torch.Tensor + Has shape (batch, complex=2, height, width) + data: torch.Tensor + Contains key "target" with value a tensor of shape (batch, height, width) """ resolution = get_resolution(**data) @@ -173,10 +175,12 @@ def l2_loss(source, reduction="mean", **data): """ Calculate L2 loss (MSE) given source and target. - Parameters: - ----------- - Source: shape (batch, complex=2, height, width) - Data: Contains key "target" with value a tensor of shape (batch, height, width) + Parameters + ---------- + source: torch.Tensor + Has shape (batch, complex=2, height, width) + data: torch.Tensor + Contains key "target" with value a tensor of shape (batch, height, width) """ resolution = get_resolution(**data) @@ -188,10 +192,12 @@ def ssim_loss(source, reduction="mean", **data): """ Calculate SSIM loss given source and target. - Parameters: - ----------- - Source: shape (batch, complex=2, height, width) - Data: Contains key "target" with value a tensor of shape (batch, height, width) + Parameters + ---------- + source: torch.Tensor + Has shape (batch, complex=2, height, width) + data: torch.Tensor + Contains key "target" with value a tensor of shape (batch, height, width) """ resolution = get_resolution(**data) @@ -237,11 +243,11 @@ def evaluate( Parameters ---------- - data_loader : DataLoader - loss_fns : Dict[str, Callable], optional - regularizer_fns : Dict[str, Callable], optional - crop : str, optional - is_validation_process : bool + data_loader: DataLoader + loss_fns: Dict[str, Callable], optional + regularizer_fns: Dict[str, Callable], optional + crop: str, optional + is_validation_process: bool Returns ------- @@ -441,10 +447,12 @@ def cropper(self, source, target, resolution): """ 2D source/target cropper - Parameters: - ----------- - Source has shape (batch, height, width) - Target has shape (batch, height, width) + Parameters + ---------- + source: torch.Tensor + Has shape (batch, height, width) + target: torch.Tensor + Has shape (batch, height, width) """ diff --git a/direct/predict.py b/direct/predict.py index a034b6a3..76a4736f 100644 --- a/direct/predict.py +++ b/direct/predict.py @@ -9,7 +9,6 @@ import torch from direct.common.subsample import build_masking_function -from direct.environment import Args from direct.inference import build_inference_transforms, setup_inference_save_to_h5 from direct.launch import launch from direct.utils import set_all_seeds diff --git a/direct/train.py b/direct/train.py index c68ad002..74119e6e 100644 --- a/direct/train.py +++ b/direct/train.py @@ -59,7 +59,7 @@ def get_root_of_file(filename: Union[pathlib.Path, str]): Parameters ---------- - filename : pathlib.Path or str + filename: pathlib.Path or str Returns ------- diff --git a/direct/utils/__init__.py b/direct/utils/__init__.py index f46254f4..4147af2f 100644 --- 
a/direct/utils/__init__.py +++ b/direct/utils/__init__.py @@ -24,12 +24,12 @@ def is_complex_data(data: torch.Tensor, complex_last: bool = True) -> bool: Parameters ---------- - data : torch.Tensor + data: torch.Tensor For 2D data the shape is assumed ([batch], [coil], height, width, [complex]) or ([batch], [coil], [complex], height, width). For 3D data the shape is assumed ([batch], [coil], slice, height, width, [complex]) or ([batch], [coil], [complex], slice, height, width). - complex_last : bool + complex_last: bool If true, will require complex axis to be at the last axis. Returns ------- @@ -69,7 +69,7 @@ def is_power_of_two(number: int) -> bool: Parameters ---------- - number : int + number: int Returns ------- @@ -84,7 +84,7 @@ def ensure_list(data: Any) -> List: Parameters ---------- - data : object + data: object Returns ------- @@ -105,7 +105,7 @@ def cast_as_path(data: Optional[Union[pathlib.Path, str]]) -> Optional[pathlib.P Parameters ---------- - data : str or pathlib.Path + data: str or pathlib.Path Returns ------- @@ -126,8 +126,8 @@ def str_to_class(module_name: str, function_name: str) -> Callable: set to 2. - Example - ------- + Examples + -------- >>> def mult(f, mul=2): >>> return f*mul @@ -140,9 +140,9 @@ def str_to_class(module_name: str, function_name: str) -> Callable: Parameters ---------- - module_name : str + module_name: str e.g. direct.data.transforms - function_name : str + function_name: str e.g. Identity Returns ------- @@ -173,9 +173,9 @@ def dict_to_device( Parameters ---------- - data : Dict[str, torch.Tensor] - device : torch.device, str - keys : List, Tuple + data: Dict[str, torch.Tensor] + device: torch.device, str + keys: List, Tuple Subselection of keys to copy. Returns @@ -193,8 +193,8 @@ def detach_dict(data: Dict[str, torch.Tensor], keys: Optional[Union[List, Tuple, Parameters ---------- - data : Dict[str, torch.Tensor] - keys : List, Tuple + data: Dict[str, torch.Tensor] + keys: List, Tuple Subselection of keys to detach Returns @@ -213,15 +213,15 @@ def reduce_list_of_dicts(data: List[Dict[str, torch.Tensor]], mode="average", di Parameters ---------- - data : List[Dict[str, torch.Tensor]]) - mode : str + data: List[Dict[str, torch.Tensor]]) + mode: str Which reduction mode, average reduces the dictionary, sum just adds while average computes the average. - divisor : None or int + divisor: None or int If given values are divided by this factor. Returns ------- - Dict[str, torch.Tensor] : Reduced dictionary. + Dict[str, torch.Tensor]: Reduced dictionary. """ if not data: return {} @@ -249,7 +249,7 @@ def merge_list_of_dicts(list_of_dicts): Parameters ---------- - list_of_dicts : List[Dict] + list_of_dicts: List[Dict] Returns ------- @@ -265,8 +265,8 @@ def evaluate_dict(fns_dict, source, target, reduction="mean"): """ Evaluate a dictionary of functions. 
- Example - ------- + Examples + -------- > evaluate_dict({'l1_loss: F.l1_loss, 'l2_loss': F.l2_loss}, a, b) Will return @@ -274,10 +274,10 @@ def evaluate_dict(fns_dict, source, target, reduction="mean"): Parameters ---------- - fns_dict : Dict[str, Callable] - source : torch.Tensor - target : torch.Tensor - reduction : str + fns_dict: Dict[str, Callable] + source: torch.Tensor + target: torch.Tensor + reduction: str Returns ------- @@ -292,8 +292,8 @@ def prefix_dict_keys(data: Dict[str, Any], prefix: str) -> Dict[str, Any]: Parameters ---------- - data : Dict[str, Any] - prefix : str + data: Dict[str, Any] + prefix: str Returns ------- @@ -308,7 +308,7 @@ def git_hash() -> str: Returns ------- - str : the current git hash. + str: the current git hash. """ try: _git_hash = subprocess.check_output(["git", "rev-parse", "HEAD"], stderr=subprocess.PIPE).decode().strip() @@ -329,8 +329,8 @@ def normalize_image(image: torch.Tensor, eps: float = 0.00001) -> torch.Tensor: Parameters ---------- - image : torch.Tensor - eps : float + image: torch.Tensor + eps: float Returns ------- @@ -358,9 +358,9 @@ def multiply_function(multiplier: float, func: Callable) -> Callable: Parameters ---------- - multiplier : float + multiplier: float Number to multiply with. - func : callable + func: callable Function to multiply. Returns @@ -419,7 +419,7 @@ def count_parameters(models: dict) -> None: Parameters ---------- - models : dict + models: dict Dictionary mapping model name to model. Returns diff --git a/direct/utils/asserts.py b/direct/utils/asserts.py index e4748033..2a983800 100644 --- a/direct/utils/asserts.py +++ b/direct/utils/asserts.py @@ -14,8 +14,8 @@ def assert_positive_integer(*variables, strict: bool = False) -> None: Parameters ---------- - variables : Any - strict : bool + variables: Any + strict: bool If true, will allow zero values. """ if not strict: @@ -37,7 +37,7 @@ def assert_same_shape(data_list: List[torch.Tensor]): Parameters ---------- - data_list : list + data_list: list List of tensors """ shape_list = set(_.shape for _ in data_list) @@ -51,12 +51,12 @@ def assert_complex(data: torch.Tensor, complex_last: bool = True) -> None: Parameters ---------- - data : torch.Tensor + data: torch.Tensor For 2D data the shape is assumed ([batch], [coil], height, width, [complex]) or ([batch], [coil], [complex], height, width). For 3D data the shape is assumed ([batch], [coil], slice, height, width, [complex]) or ([batch], [coil], [complex], slice, height, width). - complex_last : bool + complex_last: bool If true, will require complex axis to be at the last axis. Returns ------- diff --git a/direct/utils/bbox.py b/direct/utils/bbox.py index 4771af78..28709418 100644 --- a/direct/utils/bbox.py +++ b/direct/utils/bbox.py @@ -13,12 +13,12 @@ def crop_to_bbox( Parameters ---------- - data : np.ndarray or torch.tensor + data: np.ndarray or torch.tensor nD array or torch tensor. - bbox : list or tuple + bbox: list or tuple bbox of the form (coordinates, size), for instance (4, 4, 2, 1) is a patch starting at row 4, col 4 with height 2 and width 1. - pad_value : number + pad_value: number if bounding box would be out of the image, this is value the patch will be padded with. 
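The ``(coordinates, size)`` convention of ``crop_to_bbox`` described above can be made concrete with a small, hypothetical example; the array contents are made up and only illustrate the indexing.

.. code-block:: python

    import numpy as np

    from direct.utils.bbox import crop_to_bbox

    data = np.arange(64).reshape(8, 8)

    # bbox = (row, column, height, width): a patch starting at row 4, column 4
    # with height 2 and width 1, i.e. the same region as data[4:6, 4:5].
    patch = crop_to_bbox(data, bbox=[4, 4, 2, 1], pad_value=0)
    # patch.shape == (2, 1)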
Returns @@ -73,8 +73,8 @@ def crop_to_largest( Parameters ---------- - data : List[Union[np.ndarray, torch.Tensor]] - pad_value : int + data: List[Union[np.ndarray, torch.Tensor]] + pad_value: int Returns ------- diff --git a/direct/utils/communication.py b/direct/utils/communication.py index 2f155f6c..1f8a82b2 100644 --- a/direct/utils/communication.py +++ b/direct/utils/communication.py @@ -74,7 +74,7 @@ def get_local_rank() -> int: Returns ------- - int : The rank of the current process within the local (per-machine) process group. + int: The rank of the current process within the local (per-machine) process group. """ if not torch.distributed.is_available(): @@ -162,8 +162,8 @@ def _pad_to_largest_tensor( """ Parameters ---------- - tensor : torch.Tensor - group : torch.distributed.group + tensor: torch.Tensor + group: torch.distributed.group Returns ------- @@ -196,7 +196,7 @@ def all_gather(data: object, group: Optional[torch.distributed.group] = None): Parameters ---------- - data : object + data: object Any pickleable object. group : A torch process group. By default, will use a group which contains all ranks on gloo backend. @@ -239,9 +239,9 @@ def gather( Parameters ---------- - data : object + data: object Any pickleable object - destination_rank : int + destination_rank: int Destination rank group : A torch process group. By default, will use a group which contains all ranks on gloo backend. @@ -301,11 +301,11 @@ def reduce_tensor_dict(tensors_dict: Dict[str, torch.Tensor]) -> Dict[str, torch Parameters ---------- - tensors_dict : dict + tensors_dict: dict dictionary with str keys mapping to torch tensors. Returns ------- - dict : the reduced dict. + dict: the reduced dict. """ if not tensors_dict: diff --git a/direct/utils/dataset.py b/direct/utils/dataset.py index 9ba2ec59..21ea1052 100644 --- a/direct/utils/dataset.py +++ b/direct/utils/dataset.py @@ -12,10 +12,10 @@ def get_filenames_for_datasets(cfg, files_root, data_root): Parameters ---------- - cfg : cfg-object + cfg: cfg-object cfg object having property lists having the relative paths compared to files root. - files_root : Union[str, pathlib.Path] - data_root : pathlib.Path + files_root: Union[str, pathlib.Path] + data_root: pathlib.Path Returns ------- diff --git a/direct/utils/events.py b/direct/utils/events.py index 6a0042f6..e7f64d42 100644 --- a/direct/utils/events.py +++ b/direct/utils/events.py @@ -101,15 +101,15 @@ def __init__(self, json_file: Union[Path, str], window_size: int = 2): Parameters ---------- - json_file : Union[Path, str] + json_file: Union[Path, str] Path to the JSON file. Data will be appended if it exists - window_size : int + window_size: int Window size of median smoothing for variables for which `smoothing_hint` is True. - validation : bool + validation: bool If true, will only log keys starting with val_ """ - self._file_handle = open(json_file, "a") + self._file_handle = open(json_file, "a", encoding="utf-8") self._window_size = window_size def write(self): @@ -136,11 +136,11 @@ def __init__(self, log_dir: Union[Path, str], window_size: int = 20, **kwargs): """ Parameters ---------- - log_dir : Union[Path, str] + log_dir: Union[Path, str] The directory to save the output events. - window_size : int + window_size: int The scalars will be median-smoothed by this window size. 
- kwargs : dict + kwargs: dict other arguments passed to `torch.utils.tensorboard.SummaryWriter(...)` """ self._window_size = window_size @@ -247,7 +247,7 @@ def __init__(self, start_iter=0): """ Parameters ---------- - start_iter : int + start_iter: int The index to start with. """ self._history = defaultdict(HistoryBuffer) @@ -261,13 +261,13 @@ def add_image(self, img_name, img_tensor): """ Add an `img_tensor` to the `_vis_data` associated with `img_name`. - Args: - img_name (str): The name of the input_image to put into tensorboard. - img_tensor (torch.Tensor or numpy.array): An `uint8` or `float` - Tensor of shape `[channel, height, width]` where `channel` is - 3. The input_image format should be RGB. The elements in img_tensor - can either have values in [0, 1] (float32) or [0, 255] (uint8). - The `img_tensor` will be visualized in tensorboard. + Parameters + ---------- + img_name: str + The name of the input_image to put into tensorboard. + img_tensor: torch.Tensor or numpy.array + An `uint8` or `float` Tensor of shape `[channel, height, width]` where `channel` is 3. The input_image format should be RGB. The elements in img_tensor can either have values in [0, 1] (float32) or [0, 255] (uint8). The `img_tensor` will be visualized in tensorboard. + """ self._vis_data.append((img_name, img_tensor, self._iter)) @@ -284,15 +284,10 @@ def add_scalar(self, name, value, smoothing_hint=True): Parameters ---------- - name : - value : - smoothing_hint : bool - A 'hint' on whether this scalar is noisy and should be - smoothed when logged. The hint will be accessible through - `EventStorage.smoothing_hints`. A writer may ignore the hint - and apply custom smoothing rule. - It defaults to True because most scalars we save need to be smoothed to - provide any useful signal. + name: str + value: float + smoothing_hint: bool + A 'hint' on whether this scalar is noisy and should be smoothed when logged. The hint will be accessible through `EventStorage.smoothing_hints`. A writer may ignore the hint and apply custom smoothing rule. It defaults to True because most scalars we save need to be smoothed to provide any useful signal. Returns ------- @@ -315,8 +310,8 @@ def add_scalars(self, *, smoothing_hint=True, **kwargs): """ Put multiple scalars from keyword arguments. - Examples: - + Examples + -------- storage.add_scalars(loss=my_loss, accuracy=my_accuracy, smoothing_hint=True) """ for k, v in kwargs.items(): @@ -326,13 +321,13 @@ def add_graph(self, img_name, img_tensor): """ Add an `img_tensor` to the `_vis_data` associated with `img_name`. - Args: - img_name (str): The name of the input_image to put into tensorboard. - img_tensor (torch.Tensor or numpy.array): An `uint8` or `float` - Tensor of shape `[channel, height, width]` where `channel` is - 3. The input_image format should be RGB. The elements in img_tensor - can either have values in [0, 1] (float32) or [0, 255] (uint8). - The `img_tensor` will be visualized in tensorboard. + Parameters + ---------- + img_name: str + The name of the input_image to put into tensorboard. + img_tensor: torch.Tensor or numpy.array + An `uint8` or `float` Tensor of shape `[channel, height, width]` where `channel` is 3. The input_image format should be RGB. The elements in img_tensor can either have values in [0, 1] (float32) or [0, 255] (uint8). The `img_tensor` will be visualized in tensorboard. 
+ """ self._vis_data.append((img_name, img_tensor, self._iter)) diff --git a/direct/utils/io.py b/direct/utils/io.py index bfcab4f4..99bef6a0 100644 --- a/direct/utils/io.py +++ b/direct/utils/io.py @@ -38,7 +38,7 @@ def read_json(fn: Union[Dict, str, pathlib.Path]) -> Dict: Parameters ---------- - fn : Union[Dict, str, pathlib.Path] + fn: Union[Dict, str, pathlib.Path] Returns @@ -49,7 +49,7 @@ def read_json(fn: Union[Dict, str, pathlib.Path]) -> Dict: if isinstance(fn, dict): return fn - with open(fn, "r") as f: + with open(fn, "r", encoding="utf-8") as f: data = json.load(f) return data @@ -75,15 +75,15 @@ def write_json(fn: Union[str, pathlib.Path], data: Dict, indent=2) -> None: Parameters ---------- - fn : Path or str - data : dict + fn: Path or str + data: dict indent: int Returns ------- None """ - with open(fn, "w") as f: + with open(fn, "w", encoding="utf-8") as f: json.dump(data, f, indent=indent, cls=ArrayEncoder) @@ -93,7 +93,7 @@ def read_list(fn: Union[List, str, pathlib.Path]) -> List: Parameters ---------- - fn : Union[[list, str, pathlib.Path]] + fn: Union[[list, str, pathlib.Path]] Input text file or list, or a URL to a text file. Returns @@ -106,7 +106,7 @@ def read_list(fn: Union[List, str, pathlib.Path]) -> List: data = read_text_from_url(fn) return [_.strip() for _ in data.split("\n") if not _.startswith("#") and _ != ""] else: - with open(fn) as f: + with open(fn, "r", encoding="utf-8") as f: data = f.readlines() return [_.strip() for _ in data if not _.startswith("#")] return fn @@ -118,14 +118,14 @@ def write_list(fn: Union[str, pathlib.Path], data) -> None: Parameters ---------- - fn : Union[[list, str, pathlib.Path]] + fn: Union[[list, str, pathlib.Path]] Input text file or list - data : list or tuple + data: list or tuple Returns ------- None """ - with open(fn, "w") as f: + with open(fn, "w", encoding="utf-8") as f: for line in data: f.write(f"{line}\n") @@ -196,15 +196,15 @@ def download_url( Parameters ---------- - url : str + url: str URL to download file from - root : str + root: str Directory to place downloaded file in - filename : str, optional: + filename: str, optional: Name to save the file under. If None, use the basename of the URL - md5 : str, optional + md5: str, optional MD5 checksum of the download. If None, do not check - max_redirect_hops : int, optional) + max_redirect_hops: int, optional) Maximum number of redirect hops allowed """ root = os.path.expanduser(root) @@ -253,8 +253,8 @@ def _extract_tar(from_path: str, to_path: str, compression: Optional[str]) -> No def _extract_zip(from_path: str, to_path: str, compression: Optional[str]) -> None: with zipfile.ZipFile( from_path, "r", compression=_ZIP_COMPRESSION_MAP[compression] if compression else zipfile.ZIP_STORED - ) as zip: - zip.extractall(to_path) + ) as zip_file_handler: + zip_file_handler.extractall(to_path) _ARCHIVE_EXTRACTORS: Dict[str, Callable[[str, str, Optional[str]], None]] = { @@ -323,9 +323,9 @@ def _decompress(from_path: str, to_path: Optional[str] = None, remove_finished: Parameters ---------- - from_path : str + from_path: str Path to the file to be decompressed. - to_path : str + to_path: str Path to the decompressed file. If omitted, ``from_path`` without compression extension is used. remove_finished (bool): If ``True``, remove the file after the extraction. @@ -358,14 +358,18 @@ def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finish The archive type and a possible compression is automatically detected from the file name. 
If the file is compressed but not an archive the call is dispatched to :func:`decompress`. - Args: - from_path (str): Path to the file to be extracted. - to_path (str): Path to the directory the file will be extracted to. If omitted, the directory of the file is - used. - remove_finished (bool): If ``True``, remove the file after the extraction. + Parameters + ---------- + from_path: str + Path to the file to be extracted. + to_path: str + Path to the directory the file will be extracted to. If omitted, the directory of the file is used. + remove_finished (bool): If ``True``, remove the file after the extraction. - Returns: - (str): Path to the directory the file was extracted to. + Returns + ------- + str + Path to the directory the file was extracted to. """ if to_path is None: to_path = os.path.dirname(from_path) @@ -413,8 +417,8 @@ def read_text_from_url(url, chunk_size: int = 1024): Parameters ---------- - url : str - chunk_size : int + url: str + chunk_size: int Returns ------- @@ -447,7 +451,7 @@ def check_is_valid_url(path: str) -> bool: Parameters ---------- - path : str + path: str Returns ------- diff --git a/direct/utils/logging.py b/direct/utils/logging.py index c2b3b9f7..82c1d45f 100644 --- a/direct/utils/logging.py +++ b/direct/utils/logging.py @@ -17,11 +17,11 @@ def setup( Parameters ---------- - use_stdout : bool + use_stdout: bool Write output to standard out. - filename : PathLike + filename: PathLike Filename to write log to. - log_level : str + log_level: str Logging level as in the `python.logging` library. Returns @@ -55,3 +55,5 @@ def setup( fh.setLevel(log_level) fh.setFormatter(formatter) root.addHandler(fh) + + logging.warning("DIRECT is not intended for clinical use.") diff --git a/direct/utils/models.py b/direct/utils/models.py index 162fd11e..3f461ee2 100644 --- a/direct/utils/models.py +++ b/direct/utils/models.py @@ -11,7 +11,7 @@ def fix_state_dict_module_prefix(state_dict): Parameters ---------- - state_dict : dict + state_dict: dict state_dict of a network module Returns ------- diff --git a/direct/utils/writers.py b/direct/utils/writers.py index b039b28d..375c31e3 100644 --- a/direct/utils/writers.py +++ b/direct/utils/writers.py @@ -23,15 +23,15 @@ def write_output_to_h5( Parameters ---------- - output : dict + output: dict Dictionary with keys filenames and values torch.Tensor's with shape [depth, num_channels, ...] where num_channels is typically 1 for MRI. - output_directory : pathlib.Path - volume_processing_func : callable + output_directory: pathlib.Path + volume_processing_func: callable Function which postprocesses the volume array before saving. - output_key : str + output_key: str Name of key to save the output to. - create_dirs_if_needed : bool + create_dirs_if_needed: bool If true, the output directory and all its parents will be created. Notes diff --git a/docs/authors.rst b/docs/authors.rst new file mode 100644 index 00000000..f8bf1b1b --- /dev/null +++ b/docs/authors.rst @@ -0,0 +1 @@ +.. include:: ../authors.rst diff --git a/docs/conf.py b/docs/conf.py index 9e1de3bc..373f13ea 100755 --- a/docs/conf.py +++ b/docs/conf.py @@ -75,7 +75,7 @@ # General information about the project. 
project = "direct" copyright = "2021, direct contributors" -author = "Jonas Teuwen" +author = "Jonas Teuwen, George Yiasemis" # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the @@ -182,7 +182,7 @@ # (source start file, target name, title, author, documentclass # [howto, manual, or own class]). latex_documents = [ - (master_doc, "direct.tex", "Direct Documentation", "Jonas Teuwen", "manual"), + (master_doc, "direct.tex", "Direct Documentation", "Jonas Teuwen", "George Yiasemis", "manual"), ] diff --git a/docs/config.rst b/docs/config.rst new file mode 100644 index 00000000..c48ff540 --- /dev/null +++ b/docs/config.rst @@ -0,0 +1,111 @@ +.. highlight:: shell + +============= +Configuration +============= + +To perform experiments for training, validation or inference, a configuration file with an extension ``.yaml`` must be defined which includes all experiments parameters such as models, datasets, etc. The following is a template for the configuration file. Accepted arguments are the parameters as defined in the ``config<>.py`` file for each function/class. For instance, accepted arguments for training are the parameters as defined in ``TrainingConfig``. A list of our configuration files can be found in the `projects <../projects/>`_ folder. + +.. code-block:: yaml + + model: + model_name: + model_parameter_1: + model_parameter_2: + ... + + additional_models: + sensitivity_model: + model_name: + ... + + physics: + forward_operator: fft2(centered=False) + backward_operator: ifft2(centered=False) + ... + + training: + datasets: + - name: Dataset1 + lists: + - + - + transforms: + estimate_sensitivity_maps: + scaling_key: + image_center_crop: + masking: + name: MaskingFunctionName + accelerations: [acceleration_1, accelaration_2, ...] + ... + ... + - name: Dataset2 + lists: + ... + transforms: + ... + masking: + name: MaskingFunctionName + accelerations: [acceleration_1, accelaration_2, ...] + ... + ... + optimizer: + lr: + batch_size: + lr_step_size: + lr_gamma: + lr_warmup_iter: + num_iterations: + validation_steps: + loss: + losses: + - function: + multiplier: + - function: + multiplier: + checkpointer: + checkpoint_steps: + metrics: [ + ... + - name: ValDataset2 + transforms: + ... + masking: + ... + text_description: + ... + - name: ... + ... + batch_size: + metrics: + - val_metric_1 + - val_metric_2 + - ... + ... + + inference: + dataset: + name: InferenceDataset + lists: ... + transforms: + masking: + ... + ... + text_description: + ... + batch_size: + ... + + logging: + tensorboard: + num_images: diff --git a/docs/datasets.rst b/docs/datasets.rst new file mode 100644 index 00000000..86693c5c --- /dev/null +++ b/docs/datasets.rst @@ -0,0 +1,117 @@ +.. highlight:: shell + +======================= +Adding your own dataset +======================= +Transforms in :code:`DIRECT` currently support only gridded data (data acquired on an equispaced grid). +Any compatible dataset should inherit from PyTorch's dataset class :code:`torch.utils.data.Dataset`. +Follow the steps below: + +- Implement your custom dataset under :code:`direct/data/datasets.py` following the template: + +.. code-block:: python + + import pathlib + + from torch.utils.data import Dataset + + logger = logging.getLogger(__name__) + + class MyNewDataset(Dataset): + """ + Information about the Dataset. 
+ """ + + def __init__( + self, + root: pathlib.Path, + transform: Optional[Callable] = None, + filenames_filter: Optional[List[PathOrString]] = None, + text_description: Optional[str] = None, + ... + ) -> None: + """ + Initialize the dataset. + + Parameters + ---------- + root : pathlib.Path + Root directory to saved data. + transform : Optional[Callable] + Callable function that transforms the loaded data. + filenames_filter : List + List of filenames to include in the dataset. + text_description : str + Description of dataset, can be useful for logging. + ... + ... + """ + super().__init__() + + self.logger = logging.getLogger(type(self).__name__) + self.root = root + self.transform = transform + if filenames_filter: + self.logger.info(f"Attempting to load {len(filenames_filter)} filenames from list.") + filenames = filenames_filter + else: + self.logger.info(f"Parsing directory {self.root} for files.") + filenames = list(self.root.glob("*.")) + self.filenames_filter = filenames_filter + + self.text_description = text_description + + ... + + def self.get_dataset_len(self): + ... + + def __len__(self): + return self.get_dataset_len() + + def __getitem__(self, idx: int) -> Dict[str, Any]: + ... + sample = ... + ... + if self.transform: + sample = self.transform(sample) + return sample + + +Note that the :code:`__getitem__` method should output dictionaries which contain keys with values either torch.Tensors or +other metadata. Current implemented models and transforms can work with multi-coil two-dimensional data. Therefore, new datasets +should split three-dimensional data to slices of two-dimensional data. + + +- Register the new dataset in :code:`direct/data/datasets_config.py` + +.. code-block:: python + + @dataclass + class MyDatasetConfig(BaseConfig): + ... + name: str = "MyNew" + lists: List[str] = field(default_factory=lambda: []) + transforms: BaseConfig = TransformsConfig() + text_description: Optional[str] = None + ... + + +- To use your dataset, you have to request it in the :code:`config.yaml` file. The following shows an example for training. + + +.. code-block:: yaml + + training: + datasets: + - name: MyNew + lists: + - .lst + - .lst + - ... + transforms: + ... + masking: + ... + ... + diff --git a/docs/getting_started.rst b/docs/getting_started.rst new file mode 100644 index 00000000..4488f76f --- /dev/null +++ b/docs/getting_started.rst @@ -0,0 +1,30 @@ +Quick Start +=========== +This gives a brief quick start - introduction on how to download public datasets such as the Calgary-Campinas and FastMRI multi-coil MRI data and train models implemented in ``DIRECT``. + +1. Downloading and Preparing MRI datasets +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The Multi-coil Calgary-Campinas dataset can be obtained following the instructions `here `_ and the FastMRI dataset can be obtained from `here `_ by filling in their form. +Data should be arranged into training and validation folders. The testing set is not strictly required, and definitely not during training, if you do not want to compute the +test set results. + +**Note:** Preferably use a fast drive, for instance an SSD to store these files to make sure to get the maximal performance. + +2. Install ``DIRECT`` +^^^^^^^^^^^^^^^^^^^^^ + +Follow the instructions in `installation docs `_. + +3. Training and Inference +^^^^^^^^^^^^^^^^^^^^^^^^^ + +3.1 Preparing a configuration file +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +To run experiments a configuration file must be created. 
For a sample configuration file please refer to our `docs `_. + +3.2 Projects +~~~~~~~~~~~~ +In the `projects folder `_ folder you can find examples of baseline configurations for our experiments. + +Instructions on how to train a model or perform inference can be found in the `docs `_. diff --git a/docs/history.rst b/docs/history.rst new file mode 100644 index 00000000..77c2b6de --- /dev/null +++ b/docs/history.rst @@ -0,0 +1,2 @@ +History +======= diff --git a/docs/index.rst b/docs/index.rst index 29b60f89..b56bf32b 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,18 +1,54 @@ -Welcome to DIRECT's documentation! -====================================== +:github_url: https://github.com/NKI-AI/direct/ + + +DIRECT documentation +==================== +DIRECT is a Python, end-to-end pipeline for solving Inverse Problems emerging in medical imaging. +It is built with `PyTorch `_ and stores state-of-the-art Deep Learning imaging inverse problem solvers such as denoising, +dealiasing and reconstruction. +By defining a base forward linear or non-linear operator, DIRECT can be used for training models for recovering +images such as MRIs from partially observed or noisy input data. + + .. toctree:: - :maxdepth: 2 - :caption: Contents: + :maxdepth: 1 + :caption: Index readme installation getting_started - modules - contributing authors history +.. toctree:: + :maxdepth: 1 + :caption: Training and inference + + training + inference + config + +.. toctree:: + :maxdepth: 1 + :caption: Add more Datasets and Sub-Samplers + + datasets + samplers + +.. toctree:: + :maxdepth: 1 + :caption: Model Zoo + + model_zoo.rst + +.. toctree:: + :maxdepth: 2 + :caption: API Documentation + + modules + + Indices and tables ================== * :ref:`genindex` diff --git a/docs/inference.rst b/docs/inference.rst new file mode 100644 index 00000000..2e6df708 --- /dev/null +++ b/docs/inference.rst @@ -0,0 +1,37 @@ +.. highlight:: shell + +========= +Inference +========= + +After training a model, you can use the ``direct predict`` command to perform inference. + +To perform inference on a single machine run the following code block in your linux machine: + +.. code-block:: bash + + $ direct predict --cfg --checkpoint --num-gpus [ --cfg .yaml --other-flags ] + +To predict using multiple machines run the following code (one command on each machine): + +.. code-block:: bash + + (machine0)$ direct predict --cfg --checkpoint --machine-rank 0 --num-machines 2 --dist-url [--other-flags] + (machine1)$ direct predict --cfg --checkpoint --machine-rank 1 --num-machines 2 --dist-url [--other-flags] + +The ``cfg_path_or_url`` should point to a configuration file that includes all the model parameters used for the trained model checkpoint ``checkpoint_path_or_url`` and should also include an inference configuration as follows: + +.. code-block:: yaml + + inference: + dataset: + name: InferenceDataset + lists: ... + transforms: + masking: + ... + ... + text_description: + ... + batch_size: + ... diff --git a/docs/installation.rst b/docs/installation.rst index dfcd67ad..fa8f194e 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -1,51 +1 @@ -.. highlight:: shell - -============ -Installation -============ - - -Stable release --------------- - -To install DIRECT, run this command in your terminal: - -.. code-block:: console - - $ pip install direct - -This is the preferred method to install DIRECT, as it will always install the most recent stable release. 
- -If you don't have `pip`_ installed, this `Python installation guide`_ can guide -you through the process. - -.. _pip: https://pip.pypa.io -.. _Python installation guide: http://docs.python-guide.org/en/latest/starting/installation/ - - -From sources ------------- - -The sources for DIRECT can be downloaded from the `Github repo`_. - -You can either clone the public repository: - -.. code-block:: console - - $ git clone git://github.com/NKI-AI/direct - -Or download the `tarball`_: - -.. code-block:: console - - $ curl -OJL https://github.com/NKI-AI/direct/tarball/master - -Once you have a copy of the source, you can install it with: - -.. code-block:: console - - $ python setup.py install - - -.. _Github repo: https://github.com/NKI-AI/direct -.. _tarball: https://github.com/NKI-AI/direct/tarball/master +.. include:: ../installation.rst diff --git a/docs/model_zoo.rst b/docs/model_zoo.rst new file mode 100644 index 00000000..7285bd15 --- /dev/null +++ b/docs/model_zoo.rst @@ -0,0 +1,175 @@ +.. role:: raw-html-m2r(raw) + :format: html + + +DIRECT Model Zoo and Baselines +============================== + +Introduction +------------ + +This file documents baselines created with the DIRECT project. You can download the parameters and weights of these +models in a ``.zip`` file by pressing on the hyperlink of the checkpoint. Each file contains the model checkpoint(s), a +configuration file ``config.yaml`` with the model parameters used to load the model for inference and validation metrics. + +How to read the tables +---------------------- + + +* "Name" refers to the name of the config file which is saved in ``projects/{project_name}/configs/{name}.yaml`` +* Checkpoint is the integer representing the model weights saved in ``model_{iteration}.pt`` as that iteration. + +License +------- + +All models made available through this page are licensed under the\ :raw-html-m2r:`
` +`Creative Commons Attribution-ShareAlike 3.0 license `_. + +Baselines +--------- + +Calgary-Campinas MR Image Reconstruction `Challenge `_ +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Models were trained on the Calgary-Campinas brain dataset. Training included 47 multicoil (12 coils) volumes that were either 5x or 10x accelerated by retrospectively applying masks provided by the Calgary-Campinas team. + +Validation Set (12 coils, 20 Volumes) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. list-table:: + :header-rows: 1 + + * - Model + - Name + - Acceleration + - Checkpoint + - SSIM + - pSNR + - VIF + * - RecurrentVarNet + - recurrentvarnet + - 5x + - `148500 <>`_ + - 0.943 + - 36.1 + - 0.964 + * - RecurrentVarNet + - recurrentvarnet + - 10x + - `107000 <>`_ + - 0.911 + - 33.0 + - 0.926 + * - LPDNet + - lpd + - 5x + - `96000 `_ + - 0.937 + - 35.6 + - 0.953 + * - LPDNet + - lpd + - 10x + - `97000 `_ + - 0.901 + - 32.2 + - 0.919 + * - RIM + - rim + - 5x + - `89000 `_ + - 0.932 + - 35.0 + - 0.964 + * - RIM + - rim + - 10x + - `63000 `_ + - 0.891 + - 31.7 + - 0.911 + * - VarNet + - varnet + - 5x + - `4000 `_ + - 0.917 + - 33.3 + - 0.937 + * - VarNet + - varnet + - 10x + - `3000 `_ + - 0.862 + - 29.9 + - 0.861 + * - Joint-ICNet + - jointicnet + - 5x + - `43000 `_ + - 0.904 + - 32.0 + - 0.940 + * - Joint-ICNet + - jointicnet + - 10x + - `42500 `_ + - 0.854 + - 29.4 + - 0.853 + * - XPDNet + - xpdnet + - 5x + - `16000 `_ + - 0.907 + - 32.3 + - 0.965 + * - XPDNet + - xpdnet + - 10x + - `14000 `_ + - 0.855 + - 29.7 + - 0.837 + * - KIKI-Net + - kikinet + - 5x + - `44500 `_ + - 0.888 + - 29.6 + - 0.919 + * - KIKI-Net + - kikinet + - 10x + - `44500 `_ + - 0.833 + - 27.5 + - 0.856 + * - MultiDomainNet + - multidomainnet + - 5x + - `50000 `_ + - 0.864 + - 28.7 + - 0.912 + * - MultiDomainNet + - multidomainnet + - 10x + - `50000 `_ + - 0.810 + - 26.8 + - 0.812 + * - U-Net + - unet + - 5x + - `10000 `_ + - 0.871 + - 29.5 + - 0.895 + * - U-Net + - unet + - 10x + - `6000 `_ + - 0.821 + - 27.8 + - 0.837 + diff --git a/docs/samplers.rst b/docs/samplers.rst new file mode 100644 index 00000000..3612d147 --- /dev/null +++ b/docs/samplers.rst @@ -0,0 +1,91 @@ +.. highlight:: shell + +======================= +Adding your own sampler +======================= + +:code:`DIRECT` currently supports sub-samplers only for gridded data (data acquired on an equispaced grid). +To add a new sub-sampler follow the steps below: + +- Implement your custom sampler under :code:`direct/common/subsample.py` following the template: + +.. code-block:: python + + class MyNewMaskFunc(BaseMaskFunc): + def __init__( + self, + accelerations: Tuple[Number, ...], + ... + ): + super().__init__( + accelerations=accelerations, + uniform_range=False, + ) + ... + + def mask_func(self, shape, return_acs=False, seed=None): + """ + Main function that outputs the sampling mask and acs_mask. + + Parameters + ---------- + + shape : iterable[int] + The shape of the mask to be created. The shape should at least 3 dimensions. + Samples are drawn along the second last dimension. + seed : int (optional) + Seed for the random number generator. Setting the seed ensures the same mask is generated + each time for the same shape. + return_acs : bool + Return the autocalibration signal region as a mask. 
+ + Returns + ------- + torch.Tensor : the sampling mask + + """ + if len(shape) < 3: + raise ValueError("Shape should have 3 or more dimensions") + + with temp_seed(self.rng, seed): + num_rows = shape[-3] + num_cols = shape[-2] + center_fraction, acceleration = self.choose_acceleration() + + # Create the mask of shape (1, nx, ny, 1) + mask = ... + + if return_acs: + acs_mask = ... + return torch.from_numpy(acs_mask) + ... + + return torch.from_numpy(mask) + + +Ideally, your sub-sampler should be able to initialise only with the :code:`accelerations` argument. Otherwise, update :code:`direct/common/subsample_config.py` accordingly with any new keys needed to initialise +your sub-sampler: + +.. code-block:: python + + @dataclass + class MaskingConfig(BaseConfig): + ... + : ... = ... + + +- To use your sub-sampler, you have to request it in the :code:`config.yaml` file. The following shows an example for training: + + +.. code-block:: yaml + + training: + datasets: + - name: ... + ... + transforms: + ... + masking: + name: MyNew + accelerations: [...] + ... diff --git a/docs/training.rst b/docs/training.rst new file mode 100644 index 00000000..4bb81ca3 --- /dev/null +++ b/docs/training.rst @@ -0,0 +1,30 @@ +.. highlight:: shell + +======== +Training +======== + +After `installing <../installation.rst>`_ the software and downloading the training and validation data to ``, to train a model you can run the following to train a model you can use the ``direct train`` command. + +To train on a single machine run the following code block in your linux machine: + +.. code-block:: bash + + direct train /Train/ /Val/ --num-gpus --cfg [--other-flags] + +To train on multiple machines run the following code (one command on each machine): + +.. code-block:: bash + + (machine0)$ direct train /Train/ /Val/ --machine-rank 0 --num-machines 2 --dist-url [--other-flags] + (machine1)$ direct train /Train/ /Val/ --machine-rank 1 --num-machines 2 --dist-url [--other-flags] + + +The above command will start the training and will create an experiment directory in ``/base_``. If you are performing an experiment on a CPU (not recommended) replace ``--num-gpus `` with ``--device 'cpu:0'``. + +In ``/base_`` there will be stored the logs of the experiment, model checkpoints (e.g. ``model_.pt``), training and validation metrics, and a ``config.yaml`` file which includes all the configuration parameters of the experiment (as stated in the ``yaml`` file ````). + + +Training model configurations can be found in the ``projects`` folder. + +During training, training loss, validation metrics and validation image predictions are logged. Additionally, Tensorboard allows for visualization of the above. diff --git a/getting_started.md b/getting_started.md deleted file mode 100644 index f334356b..00000000 --- a/getting_started.md +++ /dev/null @@ -1,37 +0,0 @@ -# Using DIRECT - -This document gives a brief introduction on how to train a Recurrent Inference Machine on the single coil -FastMRI knee dataset. - -- For general information about DIRECT, please see [`README.md`](README.md). -- For installation instructions for DIRECT, please see [`install.md`](install.md). - -## Notebooks -Example [notebooks](notebooks) are provided. -- [FastMRIDataset](notebooks/FastMRIDataset.ipynb): in this notebook the functionality of the `FastMRIDataset` class is -described. - - -## Training -### 1. 
Prepare dataset -The dataset can be obtained from https://fastmri.org by filling in their form, download the singlecoil knee train and validation - data using the `curl` command they provide in the e-mail you will receive. Unzip the files using: - -```shell -tar xvf singlecoil_train.tar.gz -tar xfv singlecoil_val.tar.gz -``` -The testing set is not strictly required, and definitely not during training, if you do not want to compute the -test set results. - -**Note:** Preferably use a fast drive, for instance an SSD to store these files to make sure to get the maximal performance. - -#### 1.1 Generate metadata -As you will likely train several models on the same dataset it might be convenient to compile a dataset description. - -**TODO:** Add dataset description. - - -### 2. Build docker engine -Follow the instructions in the [docker](docker) subfolder, and make sure to mount the data and output directory -(using `--volume`). diff --git a/install.md b/install.md deleted file mode 100644 index d69c8565..00000000 --- a/install.md +++ /dev/null @@ -1,42 +0,0 @@ -# Installation - -## Requirements -- CUDA 10.2 supported GPU. -- Linux with Python ≥ 3.8 -- PyTorch ≥ 1.6 - -## Install using Docker - -We provide a [Dockerfile](docker) which install DIRECT with a few commands. While recommended due to the use of specific -pytorch features, DIRECT should also work in a virtual environment. - -## Install using `conda` - - -1. First, install conda. Here is a guide on how to install conda on Linux if you don't already have it [here](https://docs.conda.io/projects/conda/en/latest/user-guide/install/linux.html). If you downloaded conda for the first time it is possible that you will need to restart your machine. Once you have conda, create a python 3.9 conda environment: -``` -conda create -n myenv python=3.9 -``` -Then, activate the virtual environment `myenv` you created where you will install the software: -``` -conda activate myenv -``` - -2. If you are using GPUs, cuda is required for the project to run. To install [PyTorch](https://pytorch.org/get-started/locally/) with cuda run: -``` -pip3 install torch==1.10.0+cu113 torchvision==0.11.1+cu113 torchaudio==0.10.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html -``` -**otherwise**, install the CPU PyTorch installation (not recommended): -``` -pip3 install torch==1.10.0+cpu torchvision==0.11.1+cpu torchaudio==0.10.0+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html -``` - -3. Clone the repository using `git clone` and navigate to `direct/direct/` and run -``` -python3 setup.py install -``` -This will install `direct` as a python module. - -## Common Installation Issues -If you met issues using DIRECT, please first update the repository to the latest version, and rebuild the docker. When -this does not work, create a GitHub issue so we can see whether this is a bug, or an installation problem. diff --git a/installation.rst b/installation.rst new file mode 100644 index 00000000..33a4bfd9 --- /dev/null +++ b/installation.rst @@ -0,0 +1,62 @@ + +Installation +============ + +Requirements +------------ + + +* CUDA 10.2 supported GPU. +* Linux with Python ≥ 3.8 +* PyTorch ≥ 1.6 + +Install using Docker +-------------------- + +We provide a `Dockerfile `_ which install DIRECT with a few commands. While recommended due to the use of specific +pytorch features, DIRECT should also work in a virtual environment. + +Install using ``conda`` +--------------------------- + + +#. + First, install conda. 
+
+Install using ``conda``
+---------------------------
+
+
+#.
+  First, install conda. If you don't already have conda, there is a guide on how to install it on Linux `here `_. If you installed conda for the first time, you may need to restart your machine. Once you have conda, create a Python 3.9 conda environment:
+
+  .. code-block::
+
+      conda create -n myenv python=3.9
+
+  Then, activate the virtual environment ``myenv`` you created, where you will install the software:
+
+  .. code-block::
+
+      conda activate myenv
+
+#.
+  If you are using GPUs, CUDA is required for the project to run. To install `PyTorch `_ with CUDA, run:
+
+  .. code-block::
+
+      pip3 install torch==1.10.0+cu113 torchvision==0.11.1+cu113 torchaudio==0.10.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
+
+  **otherwise**\ , install the CPU-only PyTorch build (not recommended):
+
+  .. code-block::
+
+      pip3 install torch==1.10.0+cpu torchvision==0.11.1+cpu torchaudio==0.10.0+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
+
+#.
+  Clone the repository using ``git clone``, navigate to ``direct/direct/``, and run:
+
+  .. code-block::
+
+      python3 setup.py install
+
+  This will install ``direct`` as a Python module.
+
+Common Installation Issues
+--------------------------
+
+If you encounter issues using DIRECT, please first update the repository to the latest version and rebuild the Docker image. If
+this does not work, create a GitHub issue so we can determine whether this is a bug or an installation problem.
diff --git a/logo/direct_logo_horizontal.svg b/logo/direct_logo_horizontal.svg
new file mode 100644
index 00000000..ec4bbdf1
--- /dev/null
+++ b/logo/direct_logo_horizontal.svg
@@ -0,0 +1,4589 @@
+
+
+
+
diff --git a/logo/direct_logo_square.svg b/logo/direct_logo_square.svg
new file mode 100644
index 00000000..4cdfb60f
--- /dev/null
+++ b/logo/direct_logo_square.svg
@@ -0,0 +1,1260 @@
+
+
+
+
diff --git a/model_zoo.md b/model_zoo.md
deleted file mode 100644
index ca9c933a..00000000
--- a/model_zoo.md
+++ /dev/null
@@ -1,44 +0,0 @@
-
-# DIRECT Model Zoo and Baselines
-
-## Introduction
-This file documents baselines created with the DIRECT project. You can download the parameters and weights of these
-models in a `.zip` file by pressing on the hyperlink of the checkpoint. Each file contains the model checkpoint(s), a
-configuration file `config.yaml` with the model parameters used to load the model for inference and validation metrics.
-
-## How to read the tables
-* "Name" refers to the name of the config file which is saved in `projects/{project_name}/configs/{name}.yaml`
-* Checkpoint is the integer representing the model weights saved in `model_{iteration}.pt` as that iteration.
-
-
-## License
-All models made available through this page are licensed under the
-[Creative Commons Attribution-ShareAlike 3.0 license](https://creativecommons.org/licenses/by-sa/3.0/).
-
-## Baselines
-### Calgary-Campinas MR Image Reconstruction [Challenge](https://sites.google.com/view/calgary-campinas-dataset/mr-reconstruction-challenge)
-
-Models were trained on the Calgary-Campinas brain dataset. Training included 47 multicoil (12 coils) volumes that were either 5x or 10x accelerated by retrospectively applying masks provided by the Calgary-Campinas team.
- -#### Validation Set (12 coils, 20 Volumes) - -| Model | Name | Acceleration | Checkpoint | SSIM | pSNR | VIF | -|---------------|---------------|--------------|-----------------------------------------------------------------------|-------|------|-------| -|RecurrentVarNet|recurrentvarnet| 5x | [148500]() | 0.943 | 36.1 | 0.964 | -|RecurrentVarNet|recurrentvarnet| 10x | [107000]() | 0.911 | 33.0 | 0.926 | -|LPDNet | lpd | 5x | [96000](https://s3.aiforoncology.nl/direct-project/lpdnet.zip) | 0.937 | 35.6 | 0.953 | -|LPDNet | lpd | 10x | [97000](https://s3.aiforoncology.nl/direct-project/lpdnet.zip) | 0.901 | 32.2 | 0.919 | -|RIM | rim | 5x | [89000](https://s3.aiforoncology.nl/direct-project/rim.zip) | 0.932 | 35.0 | 0.964 | -|RIM | rim | 10x | [63000](https://s3.aiforoncology.nl/direct-project/rim.zip) | 0.891 | 31.7 | 0.911 | -|VarNet | varnet | 5x | [4000](https://s3.aiforoncology.nl/direct-project/varnet.zip) | 0.917 | 33.3 | 0.937 | -|VarNet | varnet | 10x | [3000](https://s3.aiforoncology.nl/direct-project/varnet.zip) | 0.862 | 29.9 | 0.861 | -|Joint-ICNet | jointicnet | 5x | [43000](https://s3.aiforoncology.nl/direct-project/jointicnet.zip) | 0.904 | 32.0 | 0.940 | -|Joint-ICNet | jointicnet | 10x | [42500](https://s3.aiforoncology.nl/direct-project/jointicnet.zip) | 0.854 | 29.4 | 0.853 | -|XPDNet | xpdnet | 5x | [16000](https://s3.aiforoncology.nl/direct-project/xpdnet.zip) | 0.907 | 32.3 | 0.965 | -|XPDNet | xpdnet | 10x | [14000](https://s3.aiforoncology.nl/direct-project/xpdnet.zip) | 0.855 | 29.7 | 0.837 | -|KIKI-Net | kikinet | 5x | [44500](https://s3.aiforoncology.nl/direct-project/kikinet.zip) | 0.888 | 29.6 | 0.919 | -|KIKI-Net | kikinet | 10x | [44500](https://s3.aiforoncology.nl/direct-project/kikinet.zip) | 0.833 | 27.5 | 0.856 | -|MultiDomainNet |multidomainnet | 5x | [50000](https://s3.aiforoncology.nl/direct-project/multidomainnet.zip)| 0.864 | 28.7 | 0.912 | -|MultiDomainNet |multidomainnet | 10x | [50000](https://s3.aiforoncology.nl/direct-project/multidomainnet.zip)| 0.810 | 26.8 | 0.812 | -|U-Net | unet | 5x | [10000](https://s3.aiforoncology.nl/direct-project/unet.zip) | 0.871 | 29.5 | 0.895 | -|U-Net | unet | 10x | [6000](https://s3.aiforoncology.nl/direct-project/unet.zip) | 0.821 | 27.8 | 0.837 | diff --git a/projects/spie_radial_subsampling/plot_zoomed.py b/projects/spie_radial_subsampling/plot_zoomed.py index 3328b5fb..1bcb7fcc 100644 --- a/projects/spie_radial_subsampling/plot_zoomed.py +++ b/projects/spie_radial_subsampling/plot_zoomed.py @@ -4,40 +4,40 @@ def zoom_in_rectangle(img, ax, zoom, rectangle_xy, rectangle_width, rectangle_height, **kwargs): """ - Parameters: - ----------- - img: array-like - The image data. - ax: Axes - Axes to place the inset axes. - zoom: float - Scaling factor of the data axes. zoom > 1 will enlargen the coordinates (i.e., "zoomed in"), - while zoom < 1 will shrink the coordinates (i.e., "zoomed out"). - rectangle_xy: (float or int, float or int) - The anchor point of the rectangle to be zoomed. - rectangle_width: float or int - Rectangle to be zoomed width. - rectangle_height: float or int - Rectangle to be zoomed height. + Parameters + ---------- + img: array-like + The image data. + ax: Axes + Axes to place the inset axes. + zoom: float + Scaling factor of the data axes. zoom > 1 will enlargen the coordinates (i.e., "zoomed in"), + while zoom < 1 will shrink the coordinates (i.e., "zoomed out"). + rectangle_xy: (float or int, float or int) + The anchor point of the rectangle to be zoomed. 
+ rectangle_width: float or int + Rectangle to be zoomed width. + rectangle_height: float or int + Rectangle to be zoomed height. - Other Parameters: - ----------------- - cmap: str or Colormap, default 'gray' - The Colormap instance or registered colormap name used to map scalar data to colors. - zoomed_inset_loc: int or str, default: 'upper right' - Location to place the inset axes. - zoomed_inset_lw: float or None, default 1 - Zoomed inset axes linewidth. - zoomed_inset_col: float or None, default black - Zoomed inset axes color. - mark_inset_loc1: int or str, default is 1 - First location to place line connecting box and inset axes. - mark_inset_loc2: int or str, default is 3 - Second location to place line connecting box and inset axes. - mark_inset_lw: float or None, default None - Linewidth of lines connecting box and inset axes. - mark_inset_ec: color or None - Color of lines connecting box and inset axes. + Other Parameters + ---------------- + cmap: str or Colormap, default 'gray' + The Colormap instance or registered colormap name used to map scalar data to colors. + zoomed_inset_loc: int or str, default: 'upper right' + Location to place the inset axes. + zoomed_inset_lw: float or None, default 1 + Zoomed inset axes linewidth. + zoomed_inset_col: float or None, default black + Zoomed inset axes color. + mark_inset_loc1: int or str, default is 1 + First location to place line connecting box and inset axes. + mark_inset_loc2: int or str, default is 3 + Second location to place line connecting box and inset axes. + mark_inset_lw: float or None, default None + Linewidth of lines connecting box and inset axes. + mark_inset_ec: color or None + Color of lines connecting box and inset axes. """ axins = zoomed_inset_axes(ax, zoom, loc=kwargs.get("zoomed_inset_loc", 1)) diff --git a/setup.cfg b/setup.cfg index a15a3316..042cb36f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,16 +1,16 @@ [bumpversion] -current_version = 1.0.0-dev0 +current_version = 1.0.0 commit = True tag = False parse = (?P\d+)\.(?P\d+)\.(?P\d+)(\-(?P[a-z]+)(?P\d+))? -serialize = +serialize = {major}.{minor}.{patch}-{release}{build} {major}.{minor}.{patch} [bumpversion:part:release] optional_value = prod first_value = dev -values = +values = dev prod diff --git a/setup.py b/setup.py index 0e7de1f0..a2dc8d0c 100644 --- a/setup.py +++ b/setup.py @@ -11,13 +11,13 @@ version = ast.parse(line).body[0].value.s # type: ignore break -with open("README.md") as readme_file: +with open("README.rst") as readme_file: readme = readme_file.read() setup( - author="Jonas Teuwen", - author_email="j.teuwen@nki.nl", + author="Jonas Teuwen, George Yiasemis", + author_email="j.teuwen@nki.nl, g.yiasemis@nki.nl", python_requires=">=3.8", classifiers=[ "Development Status :: 5 - Production/Stable", @@ -35,16 +35,16 @@ ], }, install_requires=[ - "numpy>=1.20.0", - "h5py>=2.10.0", - "omegaconf>=2.0.0", + "numpy>=1.21.2", + "h5py>=3.6.0", + "omegaconf>=2.1.1", "torch==1.10.0", "torchvision", - "scikit-image>=0.18.1", - "scikit-learn>=0.24.2", + "scikit-image>=0.19.0", + "scikit-learn>=1.0.1", "pyxb==1.2.6", - "ismrmrd==1.9.1", - "tensorboard>=2.5.0", + "ismrmrd==1.9.5", + "tensorboard>=2.7.0", "tqdm", ], extras_require={ @@ -55,7 +55,6 @@ "myst_parser", "sphinx-book-theme", "pylint", - "sewar", "packaging", ], },
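
Since ``sewar`` is no longer part of the extras group shown above, it has to be installed separately when its image-quality metrics are needed. A minimal sketch, assuming an editable install from a clone of the repository:

.. code-block::

    pip install -e .      # core requirements only, including the pinned torch==1.10.0
    pip install sewar     # optional: extra image-quality metrics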