Skip to content

Commit

Permalink
Download checkpoints and allow remote configs (#133, Closes #135)
Browse files Browse the repository at this point in the history
- Version bump Pytorch and torchvision
- Remove Calgary Campinas masks and download them when required
- Allow to initialise from checkpoint available at URL
- Allow to have configurations supplied as URL.

Closes #135.
  • Loading branch information
jonasteuwen authored Dec 10, 2021
1 parent 38f5af3 commit 2de00ed
Show file tree
Hide file tree
Showing 30 changed files with 600 additions and 60 deletions.
59 changes: 54 additions & 5 deletions direct/checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import logging
import pathlib
import re
import urllib.parse
import warnings
from pickle import UnpicklingError
from typing import Dict, Mapping, Optional, Union, get_args
Expand All @@ -14,7 +15,11 @@
from torch.nn import DataParallel
from torch.nn.parallel import DistributedDataParallel

from direct.environment import DIRECT_MODEL_DOWNLOAD_DIR
from direct.types import HasStateDict, PathOrString
from direct.utils.io import check_is_valid_url, download_url

logger = logging.getLogger(__name__)

# TODO: Rewrite Checkpointer
# There are too many issues with typing and mypy in the checkpointer.
Expand Down Expand Up @@ -91,19 +96,35 @@ def load(
return {}

checkpoint_path = self.save_directory / f"model_{iteration}.pt"
checkpoint = self.load_from_file(checkpoint_path, checkpointable_objects)
checkpoint = self.load_from_path(checkpoint_path, checkpointable_objects)
checkpoint["iteration"] = iteration

self.checkpoint_loaded = iteration
# Return whatever is left
return checkpoint

def load_from_file(
def load_from_path(
self,
checkpoint_path: PathOrString,
checkpointable_objects: Optional[Dict[str, nn.Module]] = None,
only_models=False,
only_models: bool = False,
) -> Dict:
"""
Load a checkpoint from a path
Parameters
----------
checkpoint_path : Path or str
Path to checkpoint, either a path to a file or a path to a URL where the file can be downloaded
checkpointable_objects : dict
Dictionary mapping names to nn.Module's
only_models : bool
If true will only load the models and no other objects in the checkpoint
Returns
-------
Dictionary with loaded models.
"""
checkpoint = self._load_checkpoint(checkpoint_path)
checkpointable_objects = self.checkpointables if not checkpointable_objects else checkpointable_objects

Expand Down Expand Up @@ -140,7 +161,7 @@ def _load_model(self, obj, state_dict):
self.logger.warning(f"Unexpected keys provided which cannot be loaded: {incompatible.unexpected_keys}.")

def load_models_from_file(self, checkpoint_path: PathOrString) -> None:
_ = self.load_from_file(checkpoint_path, only_models=True)
_ = self.load_from_path(checkpoint_path, only_models=True)

def save(self, iteration: int, **kwargs: Dict[str, str]) -> None:
# For instance useful to only have the rank 0 process write to disk.
Expand Down Expand Up @@ -174,6 +195,23 @@ def save(self, iteration: int, **kwargs: Dict[str, str]) -> None:
f.write(str(iteration)) # type: ignore

def _load_checkpoint(self, checkpoint_path: PathOrString) -> Dict:
"""
Load a checkpoint from path or string
Parameters
----------
checkpoint_path : Path or str
Path to checkpoint, either a path to a file or a path to a URL where the file can be downloaded
Returns
-------
Dict loaded from checkpoint.
"""
# Check if the path is an URL
if check_is_valid_url(str(checkpoint_path)):
self.logger.info(f"Initializing from remote checkpoint {checkpoint_path}...")
checkpoint_path = _download_or_load_from_cache(checkpoint_path)
self.logger.info(f"Loading downloaded checkpoint {checkpoint_path}.")

checkpoint_path = pathlib.Path(checkpoint_path)
if not checkpoint_path.exists():
raise FileNotFoundError(f"Requested to load {checkpoint_path}, but does not exist.")
Expand All @@ -185,10 +223,21 @@ def _load_checkpoint(self, checkpoint_path: PathOrString) -> Dict:

except UnpicklingError as exc:
self.logger.exception(
"Tried to load {checkpoint_path}, but was unable to unpickle: {exc}.",
f"Tried to load {checkpoint_path}, but was unable to unpickle: {exc}.",
checkpoint_path=checkpoint_path,
exc=exc,
)
raise

return checkpoint


def _download_or_load_from_cache(url: str) -> pathlib.Path:
# Get final part of url.
file_path = urllib.parse.urlparse(url).path
filename = pathlib.Path(file_path).name

cache_path = DIRECT_MODEL_DOWNLOAD_DIR / filename
download_url(url, DIRECT_MODEL_DOWNLOAD_DIR, max_redirect_hops=3)

return cache_path
15 changes: 15 additions & 0 deletions direct/cli/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# coding=utf-8
# Copyright (c) DIRECT Contributors
import argparse
import pathlib

from direct.utils.io import check_is_valid_url


def file_or_url(path):
if check_is_valid_url(path):
return path
path = pathlib.Path(path)
if path.is_file():
return path
raise argparse.ArgumentTypeError(f"{path} is not a valid file or url.")
Binary file removed direct/common/calgary_campinas_masks/R10_218x170.npy
Binary file not shown.
Binary file removed direct/common/calgary_campinas_masks/R10_218x174.npy
Binary file not shown.
Binary file removed direct/common/calgary_campinas_masks/R10_218x180.npy
Binary file not shown.
Binary file removed direct/common/calgary_campinas_masks/R5_218x170.npy
Binary file not shown.
Binary file removed direct/common/calgary_campinas_masks/R5_218x174.npy
Binary file not shown.
Binary file removed direct/common/calgary_campinas_masks/R5_218x180.npy
Binary file not shown.
19 changes: 18 additions & 1 deletion direct/common/subsample.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
import numpy as np
import torch

from direct.environment import DIRECT_CACHE_DIR
from direct.types import Number
from direct.utils import str_to_class
from direct.utils.io import download_url

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -281,6 +283,16 @@ def mask_func(self, shape, return_acs=False, seed=None):


class CalgaryCampinasMaskFunc(BaseMaskFunc):
BASE_URL = "https://s3.aiforoncology.nl/direct-project/calgary_campinas_masks/"
MASK_MD5S = {
"R10_218x170.npy": "6e1511c33dcfc4a960f526252676f7c3",
"R10_218x174.npy": "78fe23ae5eed2d3a8ff3ec128388dcc9",
"R10_218x180.npy": "5039a6c19ac2aa3472a94e4b015e5228",
"R5_218x170.npy": "6599715103cf3d71d6e87d09f865e7da",
"R5_218x174.npy": "5bd27d2da3bf1e78ad1b65c9b5e4b621",
"R5_218x180.npy": "717b51f3155c3a64cfaaddadbe90791d",
}

# TODO: Configuration improvements, so no **kwargs needed.
def __init__(self, accelerations: Tuple[int, ...], **kwargs): # noqa
super().__init__(accelerations=accelerations, uniform_range=False)
Expand Down Expand Up @@ -319,12 +331,17 @@ def mask_func(self, shape, return_acs=False):
return torch.from_numpy(mask[choice][np.newaxis, ..., np.newaxis])

def __load_masks(self, acceleration):
masks_path = pathlib.Path(pathlib.Path(__file__).resolve().parent / "calgary_campinas_masks")
masks_path = DIRECT_CACHE_DIR / "calgary_campinas_masks"
paths = [
f"R{acceleration}_218x170.npy",
f"R{acceleration}_218x174.npy",
f"R{acceleration}_218x180.npy",
]

downloaded = [download_url(self.BASE_URL + _, masks_path, md5=self.MASK_MD5S[_]) is None for _ in paths]
if not all(downloaded):
raise RuntimeError(f"Failed to download all Calgary-Campinas masks from {self.BASE_URL}.")

output = {}
for path in paths:
shape = [int(_) for _ in path.split("_")[-1][:-4].split("x")]
Expand Down
2 changes: 1 addition & 1 deletion direct/data/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import torch.fft

from direct.data.bbox import crop_to_bbox
from direct.utils import ensure_list, is_power_of_two, is_complex_data
from direct.utils import ensure_list, is_complex_data, is_power_of_two
from direct.utils.asserts import assert_complex, assert_same_shape


Expand Down
34 changes: 23 additions & 11 deletions direct/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,24 @@
from torch.utils import collect_env

import direct.utils.logging
from direct.utils.logging import setup
from direct.config.defaults import DefaultConfig, InferenceConfig, TrainingConfig, ValidationConfig
from direct.utils import communication, count_parameters, str_to_class
from direct.utils.io import check_is_valid_url, read_text_from_url
from direct.utils.logging import setup

logger = logging.getLogger(__name__)

# Environmental variables
DIRECT_ROOT_DIR = pathlib.Path(pathlib.Path(__file__).resolve().parent.parent)
DIRECT_CACHE_DIR = pathlib.Path(os.environ.get("DIRECT_CACHE_DIR", str(DIRECT_ROOT_DIR)))
DIRECT_MODEL_DOWNLOAD_DIR = (
pathlib.Path(os.environ.get("DIRECT_MODEL_DOWNLOAD_DIR", str(DIRECT_ROOT_DIR))) / "downloaded_models"
)


def load_model_config_from_name(model_name):
"""
Load specific configuration module for
Load specific configuration module for models based on their name.
Parameters
----------
Expand Down Expand Up @@ -195,7 +203,7 @@ def extract_names(cfg):
def setup_common_environment(
run_name,
base_directory,
cfg_filename,
cfg_pathname,
device,
machine_rank,
mixed_precision,
Expand All @@ -213,12 +221,16 @@ def setup_common_environment(
communication.synchronize() # Ensure folders are in place.

# Load configs from YAML file to check which model needs to be loaded.
cfg_from_file = OmegaConf.load(cfg_filename)
# Can also be loaded from a URL
if check_is_valid_url(cfg_pathname):
cfg_from_external_source = OmegaConf.create(read_text_from_url(cfg_pathname))
else:
cfg_from_external_source = OmegaConf.load(cfg_pathname)

# Load the default configs to ensure type safety
cfg = OmegaConf.structured(DefaultConfig)

models, models_config = load_models_into_environment_config(cfg_from_file)
models, models_config = load_models_into_environment_config(cfg_from_external_source)
cfg.model = models_config.model
del models_config["model"]
cfg.additional_models = models_config
Expand All @@ -228,25 +240,25 @@ def setup_common_environment(
cfg.validation = ValidationConfig
cfg.inference = InferenceConfig

cfg_from_file_new = cfg_from_file.copy()
for key in cfg_from_file:
cfg_from_file_new = cfg_from_external_source.copy()
for key in cfg_from_external_source:
# TODO: This does not really do a full validation.
# BODY: This will be handeled once Hydra is implemented.
if key in ["models", "additional_models"]: # Still handled separately
continue

if key in ["training", "validation", "inference"]:
if not cfg_from_file[key]:
if not cfg_from_external_source[key]:
logger.info(f"key {key} missing in config.")
continue

if key in ["training", "validation"]:
dataset_cfg_from_file = extract_names(cfg_from_file[key].datasets)
dataset_cfg_from_file = extract_names(cfg_from_external_source[key].datasets)
for idx, (dataset_name, dataset_config) in enumerate(dataset_cfg_from_file):
cfg_from_file_new[key].datasets[idx] = dataset_config
cfg[key].datasets.append(load_dataset_config(dataset_name)) # pylint: disable = E1136
else:
dataset_name, dataset_config = extract_names(cfg_from_file[key].dataset)
dataset_name, dataset_config = extract_names(cfg_from_external_source[key].dataset)
cfg_from_file_new[key].dataset = dataset_config
cfg[key].dataset = load_dataset_config(dataset_name) # pylint: disable = E1136

Expand All @@ -255,7 +267,7 @@ def setup_common_environment(
# Make configuration read only.
# TODO(jt): Does not work when indexing config lists.
# OmegaConf.set_readonly(cfg, True)
setup_logging(machine_rank, experiment_dir, run_name, cfg_filename, cfg, debug)
setup_logging(machine_rank, experiment_dir, run_name, cfg_pathname, cfg, debug)
forward_operator, backward_operator = build_operators(cfg.physics)

model, additional_models = initialize_models_from_config(cfg, models, forward_operator, backward_operator, device)
Expand Down
2 changes: 1 addition & 1 deletion direct/nn/jointicnet/jointicnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch.nn as nn

import direct.data.transforms as T
from direct.nn.unet.unet_2d import UnetModel2d, NormUnetModel2d
from direct.nn.unet.unet_2d import NormUnetModel2d, UnetModel2d


class JointICNet(nn.Module):
Expand Down
4 changes: 2 additions & 2 deletions direct/nn/kikinet/kikinet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@

import direct.data.transforms as T
from direct.nn.conv.conv import Conv2d
from direct.nn.crossdomain.multicoil import MultiCoil
from direct.nn.didn.didn import DIDN
from direct.nn.mwcnn.mwcnn import MWCNN
from direct.nn.crossdomain.multicoil import MultiCoil
from direct.nn.unet.unet_2d import UnetModel2d, NormUnetModel2d
from direct.nn.unet.unet_2d import NormUnetModel2d, UnetModel2d


class KIKINet(nn.Module):
Expand Down
8 changes: 4 additions & 4 deletions direct/nn/lpd/lpd.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@

from typing import Callable

import torch
import torch.nn as nn

import direct.data.transforms as T
from direct.nn.conv.conv import Conv2d
from direct.nn.didn.didn import DIDN
from direct.nn.mwcnn.mwcnn import MWCNN
from direct.nn.unet.unet_2d import UnetModel2d, NormUnetModel2d

import torch
import torch.nn as nn
from direct.nn.unet.unet_2d import NormUnetModel2d, UnetModel2d


class DualNet(nn.Module):
Expand Down
6 changes: 3 additions & 3 deletions direct/nn/multidomainnet/multidomainnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@

from typing import Callable

import direct.data.transforms as T
from direct.nn.multidomainnet.multidomain import MultiDomainUnet2d

import torch
import torch.nn as nn

import direct.data.transforms as T
from direct.nn.multidomainnet.multidomain import MultiDomainUnet2d


class StandardizationLayer(nn.Module):
"""
Expand Down
1 change: 1 addition & 0 deletions direct/nn/recurrent/recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Copyright (c) DIRECT Contributors

from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
Expand Down
7 changes: 1 addition & 6 deletions direct/nn/recurrentvarnet/recurrentvarnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,7 @@
import torch.nn as nn
import torch.nn.functional as F

from direct.data.transforms import (
conjugate,
complex_multiplication,
reduce_operator,
expand_operator,
)
from direct.data.transforms import complex_multiplication, conjugate, expand_operator, reduce_operator
from direct.nn.recurrent.recurrent import Conv2dGRU


Expand Down
2 changes: 1 addition & 1 deletion direct/nn/rim/rim.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
import torch.nn.functional as F

from direct.data import transforms as T
from direct.utils.asserts import assert_positive_integer
from direct.nn.recurrent.recurrent import Conv2dGRU
from direct.utils.asserts import assert_positive_integer


class MRILogLikelihood(nn.Module):
Expand Down
2 changes: 1 addition & 1 deletion direct/nn/unet/tests/test_unet_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch

from direct.data.transforms import fft2, ifft2
from direct.nn.unet.unet_2d import Unet2d, NormUnetModel2d
from direct.nn.unet.unet_2d import NormUnetModel2d, Unet2d


def create_input(shape):
Expand Down
2 changes: 1 addition & 1 deletion direct/nn/xpdnet/xpdnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

import torch.nn as nn

from direct.nn.conv.conv import Conv2d
from direct.nn.crossdomain.crossdomain import CrossDomainNetwork
from direct.nn.crossdomain.multicoil import MultiCoil
from direct.nn.conv.conv import Conv2d
from direct.nn.didn.didn import DIDN
from direct.nn.mwcnn.mwcnn import MWCNN

Expand Down
2 changes: 2 additions & 0 deletions direct/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# coding=utf-8
# Copyright (c) DIRECT Contributors
Loading

0 comments on commit 2de00ed

Please sign in to comment.