Skip to content

Commit

Permalink
Fix high memory consumption and code quality improvements (#174)
Browse files Browse the repository at this point in the history
* MRI model engines merged into `MRIModelEngine`. Each model engine implements a `_do_iteration` method.
* Fixes to offload memory during validation (Closes #171)
* Main engine `predict` and `evaluate` methods refactored
* Disables stochasticity - `set_all_seeds` function refactored (Closes #162)
* Memory leak identified and fixed (`h5py 3.6.0` -> `3.3.0`)
* Implement tests for the `train` method of `Engine`
* `Black 22.1.0` fixes

Co-authored:
* @jonasteuwen
  • Loading branch information
georgeyiasemis authored Feb 10, 2022
1 parent f624a66 commit e430521
Show file tree
Hide file tree
Showing 108 changed files with 2,312 additions and 4,839 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/black.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ jobs:
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
- uses: psf/black@stable
- uses: psf/black@22.1.0
17 changes: 17 additions & 0 deletions .prospector.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
pep257:
disable:
- D203 # Conflict with numpydocs
- D213 # Conflict with numpydocs
- D212 # Conflict with numpydocs

pep8:
disable:
- E501 # Handled by black
- W605 # Conflict with numpydocs - invalid escape sequence

pylint:
disable:
- import-outside-toplevel
options:
max-args: 20

27 changes: 27 additions & 0 deletions CITATION.cff
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
cff-version: 1.2.0
message: "If you use this software, please cite it as below."
authors:
- family-names: Yiasemis
given-names: George
email: [email protected]
orcid: https://orcid.org/0000-0002-1348-8987
- family-names: Moriakov
given-names: Nikita
email: [email protected]
orcid: https://orcid.org/0000-0002-7127-1006
- family-names: Karkalousos
given-names: Dimitrios
email: [email protected]
orcid: https://orcid.org/0000-0001-5983-0322
- family-names: Caan
given-names: Matthan
email: [email protected]
orcid: https://orcid.org/0000-0002-5162-8880
- family-names: Teuwen
given-names: Jonas
email: [email protected]
orcid: https://orcid.org/0000-0002-1825-1428
title: "DIRECT: Deep Image REConstruction Toolkit"
url: "https://github.com/NKI-AI/direct"
version: 1.0.0
date-released: 2022-07-01
4 changes: 4 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
:target: https://github.com/NKI-AI/direct/actions/workflows/black.yml
:alt: black

.. image:: https://api.codacy.com/project/badge/Grade/1c55d497dead4df69d6f256da51c98b7
:target: https://app.codacy.com/gh/NKI-AI/direct?utm_source=github.com&utm_medium=referral&utm_content=NKI-AI/direct&utm_campaign=Badge_Grade_Settings
:alt: codacy

.. image:: https://codecov.io/gh/NKI-AI/direct/branch/main/graph/badge.svg?token=STYAUFCKJY
:target: https://codecov.io/gh/NKI-AI/direct
:alt: codecov
Expand Down
10 changes: 10 additions & 0 deletions codecov.yml
Original file line number Diff line number Diff line change
@@ -1 +1,11 @@
coverage:
range: 50..90 # coverage below 50 is red, above 90 is green, in between is color-graded

status:
project: # settings affecting project coverage
default:
target: auto # auto % coverage target
threshold: 1% # allow for 1% reduction of coverage without failing

# do not run coverage on patch nor changes
patch: false
2 changes: 1 addition & 1 deletion direct/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
# Copyright (c) DIRECT Contributors

__author__ = """direct contributors"""
__version__ = "1.0.0"
__version__ = "1.0.1-dev0"
41 changes: 20 additions & 21 deletions direct/checkpointer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# coding=utf-8
# Copyright (c) DIRECT Contributors
"""Checkpointer module. Handles all logic related to checkpointing."""
"""Checkpointer module.
Handles all logic related to checkpointing.
"""
import datetime
import logging
import pathlib
Expand Down Expand Up @@ -31,7 +34,10 @@


class Checkpointer:
"""Main Checkpointer module. Handles writing and restoring from checkpoints of modules and submodels."""
"""Main Checkpointer module.
Handles writing and restoring from checkpoints of modules and submodules.
"""

def __init__(
self,
Expand Down Expand Up @@ -110,8 +116,7 @@ def load_from_path(
checkpointable_objects: Optional[Dict[str, nn.Module]] = None,
only_models: bool = False,
) -> Dict:
"""
Load a checkpoint from a path
"""Load a checkpoint from a path.
Parameters
----------
Expand All @@ -129,7 +134,6 @@ def load_from_path(
checkpoint = self._load_checkpoint(checkpoint_path)
checkpointable_objects = self.checkpointables if not checkpointable_objects else checkpointable_objects

# TODO: Model and other checkpointable objects should be treated on the same footing
self.logger.info("Loading model...")
self._load_model(self.model, checkpoint["model"])

Expand All @@ -138,17 +142,17 @@ def load_from_path(
continue

if key not in checkpoint:
self.logger.warning(f"Requested to load {key}, but this was not stored.")
self.logger.warning("Requested to load %s, but this was not stored.", key)
continue

if key.endswith("__") and key.startswith("__"):
continue

self.logger.info(f"Loading {key}...")
self.logger.info("Loading %s...", key)
obj = self.checkpointables[key]
state_dict = checkpoint.pop(key)
if re.match(self.model_regex, key):
self.logger.debug(f"key {key} matches regex {self.model_regex}.")
self.logger.debug("key %s matches regex %s.", key, self.model_regex)
self._load_model(obj, state_dict) # type: ignore
else:
obj.load_state_dict(state_dict) # type: ignore
Expand All @@ -162,7 +166,7 @@ def _load_model(self, obj, state_dict):
if incompatible.missing_keys:
raise NotImplementedError
if incompatible.unexpected_keys:
self.logger.warning(f"Unexpected keys provided which cannot be loaded: {incompatible.unexpected_keys}.")
self.logger.warning("Unexpected keys provided which cannot be loaded: %s.", incompatible.unexpected_keys)

def load_models_from_file(self, checkpoint_path: PathOrString) -> None:
_ = self.load_from_path(checkpoint_path, only_models=True)
Expand All @@ -182,12 +186,12 @@ def save(self, iteration: int, **kwargs: Dict[str, str]) -> None:
elif isinstance(obj, get_args(HasStateDict)):
data[key] = obj.state_dict() # type: ignore
else:
self.logger.warning(f"Value of key {key} has no state_dict.")
self.logger.warning("Value of key %s has no state_dict.", key)

data.update(kwargs)

checkpoint_path = self.save_directory / f"model_{iteration}.pt"
self.logger.info(f"Saving checkpoint to: {checkpoint_path}.")
self.logger.info("Saving checkpoint to: %s.", checkpoint_path)

data["__datetime__"] = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")

Expand All @@ -199,8 +203,7 @@ def save(self, iteration: int, **kwargs: Dict[str, str]) -> None:
f.write(str(iteration)) # type: ignore

def _load_checkpoint(self, checkpoint_path: PathOrString) -> Dict:
"""
Load a checkpoint from path or string
"""Load a checkpoint from path or string.
Parameters
----------
Expand All @@ -212,25 +215,21 @@ def _load_checkpoint(self, checkpoint_path: PathOrString) -> Dict:
"""
# Check if the path is a URL
if check_is_valid_url(str(checkpoint_path)):
self.logger.info(f"Initializing from remote checkpoint {checkpoint_path}...")
self.logger.info("Initializing from remote checkpoint %s...", checkpoint_path)
checkpoint_path = self._download_or_load_from_cache(checkpoint_path)
self.logger.info(f"Loading downloaded checkpoint {checkpoint_path}.")
self.logger.info("Loading downloaded checkpoint %s.", checkpoint_path)

checkpoint_path = pathlib.Path(checkpoint_path)
if not checkpoint_path.exists():
raise FileNotFoundError(f"Requested to load {checkpoint_path}, but does not exist.")

self.logger.info(f"Loaded checkpoint path: {checkpoint_path}.")
self.logger.info("Loaded checkpoint path: %s.", checkpoint_path)

try:
checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))

except UnpicklingError as exc:
self.logger.exception(
f"Tried to load {checkpoint_path}, but was unable to unpickle: {exc}.",
checkpoint_path=checkpoint_path,
exc=exc,
)
self.logger.exception("Tried to load %s, but was unable to unpickle: %s.", checkpoint_path, exc)
raise

return checkpoint
Expand Down
9 changes: 5 additions & 4 deletions direct/cli/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
# coding=utf-8
# Copyright (c) DIRECT Contributors
"""DIRECT Command-line interface. This is the file which builds the main parser."""
"""DIRECT Command-line interface.
This is the file which builds the main parser.
"""

import argparse


def main():
"""
Console script for direct.
"""
"""Console script for direct."""
# From https://stackoverflow.com/questions/17073688/how-to-use-argparse-subparsers-correctly
root_parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

Expand Down
2 changes: 1 addition & 1 deletion direct/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
def register_parser(parser: argparse._SubParsersAction):
"""Register train commands to a root parser."""

epilog = f"""
epilog = """
Examples:
---------
Run on single machine:
Expand Down
5 changes: 2 additions & 3 deletions direct/cli/upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Copyright (c) DIRECT Contributors
import argparse
import os

from direct.cli.utils import is_file
from direct.utils.io import upload_to_s3

Expand All @@ -19,9 +20,7 @@ def upload_from_argparse(args: argparse.Namespace): # pragma: no cover


class BaseArgs(argparse.ArgumentParser): # pragma: no cover
"""
Defines global default arguments.
"""
"""Defines global default arguments."""

def __init__(self, epilog=None, add_help=True, **overrides):
"""
Expand Down
Loading

0 comments on commit e430521

Please sign in to comment.