Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Download checkpoints #133

Merged
merged 33 commits into from
Dec 10, 2021
Merged
Show file tree
Hide file tree
Changes from 32 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
ff4f353
Allow URLs as input to checkpointer
jonasteuwen Nov 17, 2021
4be53ab
Merge upstream to download-checkpoints (#132)
jonasteuwen Nov 29, 2021
08dbeee
Added utilities to download models
jonasteuwen Nov 29, 2021
8618298
Preliminary fix for calgary models
jonasteuwen Nov 29, 2021
d8b42b6
Merge branch 'main' into download-checkpoints
georgeyiasemis Dec 1, 2021
0eb2a21
Added possibility to load config from URL
jonasteuwen Dec 6, 2021
5e5dcb4
move checkpointer io test to test_io
jonasteuwen Dec 6, 2021
8d49db6
move checkpointer io test to test_io
jonasteuwen Dec 6, 2021
c82ee52
Version bump for pytorch
Dec 6, 2021
8494ba6
Smoothing changes
jonasteuwen Dec 6, 2021
26a3e33
Allow read_list to read data from url
jonasteuwen Dec 7, 2021
cd32264
Allow to read lists from URL as well
jonasteuwen Dec 9, 2021
27c70b6
Small change
jonasteuwen Dec 9, 2021
9d5b0ea
Formatting and import fixes
jonasteuwen Dec 9, 2021
6e637a7
Formatting and import fixes
jonasteuwen Dec 9, 2021
be88e74
download_url already checks if all is okay
jonasteuwen Dec 9, 2021
fbd6ece
Add md5s to download
jonasteuwen Dec 10, 2021
2eca066
remove superfluous files
jonasteuwen Dec 10, 2021
99c93bf
Allow checkpoints and configs from remote url
jonasteuwen Dec 10, 2021
cfa8d0a
Fix conflicts
Dec 10, 2021
afcc7f4
Allow checkpoints and configs from remote url
jonasteuwen Dec 10, 2021
43f861a
Allow checkpoints and configs from remote url
jonasteuwen Dec 10, 2021
97c8876
Improved http error
Dec 10, 2021
eae29e5
Allow checkpoints and configs from remote url
jonasteuwen Dec 10, 2021
bd385c5
Merge branch 'download-checkpoints' of github.com:directgroup/direct …
jonasteuwen Dec 10, 2021
bdec58b
Checkpoint can be URL
jonasteuwen Dec 10, 2021
7ce3d08
dir->file
jonasteuwen Dec 10, 2021
9812548
dir->file
jonasteuwen Dec 10, 2021
bd926ac
Fix torch setup.py version
jonasteuwen Dec 10, 2021
f12f58f
Fix torch setup.py version
jonasteuwen Dec 10, 2021
85659aa
Placeholder
jonasteuwen Dec 10, 2021
085589f
Fix black
jonasteuwen Dec 10, 2021
f5326dd
Extend the help
jonasteuwen Dec 10, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 54 additions & 5 deletions direct/checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import logging
import pathlib
import re
import urllib.parse
import warnings
from pickle import UnpicklingError
from typing import Dict, Mapping, Optional, Union, get_args
Expand All @@ -14,7 +15,11 @@
from torch.nn import DataParallel
from torch.nn.parallel import DistributedDataParallel

from direct.environment import DIRECT_MODEL_DOWNLOAD_DIR
from direct.types import HasStateDict, PathOrString
from direct.utils.io import check_is_valid_url, download_url

logger = logging.getLogger(__name__)

# TODO: Rewrite Checkpointer
# There are too many issues with typing and mypy in the checkpointer.
Expand Down Expand Up @@ -91,19 +96,35 @@ def load(
return {}

checkpoint_path = self.save_directory / f"model_{iteration}.pt"
checkpoint = self.load_from_file(checkpoint_path, checkpointable_objects)
checkpoint = self.load_from_path(checkpoint_path, checkpointable_objects)
checkpoint["iteration"] = iteration

self.checkpoint_loaded = iteration
# Return whatever is left
return checkpoint

def load_from_file(
def load_from_path(
self,
checkpoint_path: PathOrString,
checkpointable_objects: Optional[Dict[str, nn.Module]] = None,
only_models=False,
only_models: bool = False,
) -> Dict:
"""
Load a checkpoint from a path
Parameters
----------
checkpoint_path : Path or str
Path to checkpoint, either a path to a file or a path to a URL where the file can be downloaded
checkpointable_objects : dict
Dictionary mapping names to nn.Module's
only_models : bool
If true will only load the models and no other objects in the checkpoint
Returns
-------
Dictionary with loaded models.
"""
checkpoint = self._load_checkpoint(checkpoint_path)
checkpointable_objects = self.checkpointables if not checkpointable_objects else checkpointable_objects

Expand Down Expand Up @@ -140,7 +161,7 @@ def _load_model(self, obj, state_dict):
self.logger.warning(f"Unexpected keys provided which cannot be loaded: {incompatible.unexpected_keys}.")

def load_models_from_file(self, checkpoint_path: PathOrString) -> None:
_ = self.load_from_file(checkpoint_path, only_models=True)
_ = self.load_from_path(checkpoint_path, only_models=True)

def save(self, iteration: int, **kwargs: Dict[str, str]) -> None:
# For instance useful to only have the rank 0 process write to disk.
Expand Down Expand Up @@ -174,6 +195,23 @@ def save(self, iteration: int, **kwargs: Dict[str, str]) -> None:
f.write(str(iteration)) # type: ignore

def _load_checkpoint(self, checkpoint_path: PathOrString) -> Dict:
"""
Load a checkpoint from path or string
Parameters
----------
checkpoint_path : Path or str
Path to checkpoint, either a path to a file or a path to a URL where the file can be downloaded
Returns
-------
Dict loaded from checkpoint.
"""
# Check if the path is an URL
if check_is_valid_url(str(checkpoint_path)):
self.logger.info(f"Initializing from remote checkpoint {checkpoint_path}...")
checkpoint_path = _download_or_load_from_cache(checkpoint_path)
self.logger.info(f"Loading downloaded checkpoint {checkpoint_path}.")

checkpoint_path = pathlib.Path(checkpoint_path)
if not checkpoint_path.exists():
raise FileNotFoundError(f"Requested to load {checkpoint_path}, but does not exist.")
Expand All @@ -185,10 +223,21 @@ def _load_checkpoint(self, checkpoint_path: PathOrString) -> Dict:

except UnpicklingError as exc:
self.logger.exception(
"Tried to load {checkpoint_path}, but was unable to unpickle: {exc}.",
f"Tried to load {checkpoint_path}, but was unable to unpickle: {exc}.",
checkpoint_path=checkpoint_path,
exc=exc,
)
raise

return checkpoint


def _download_or_load_from_cache(url: str) -> pathlib.Path:
    """
    Download the file behind ``url`` into ``DIRECT_MODEL_DOWNLOAD_DIR`` and return its local path.

    Parameters
    ----------
    url : str
        Remote location of the file. The local file name is the last component of the URL's path.

    Returns
    -------
    pathlib.Path
        ``DIRECT_MODEL_DOWNLOAD_DIR`` joined with the file name taken from the URL.
    """
    # Only the path component of the URL is used, so query strings and fragments do not end up in the name.
    remote_name = pathlib.Path(urllib.parse.urlparse(url).path).name

    # NOTE(review): presumably `download_url` skips the transfer when the file is already present,
    # which is what the "or load from cache" in the name relies on - confirm in direct.utils.io.
    download_url(url, DIRECT_MODEL_DOWNLOAD_DIR, max_redirect_hops=3)

    return DIRECT_MODEL_DOWNLOAD_DIR / remote_name
15 changes: 15 additions & 0 deletions direct/cli/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# coding=utf-8
# Copyright (c) DIRECT Contributors
import argparse
import pathlib

from direct.utils.io import check_is_valid_url


def file_or_url(path):
    """
    Validate that ``path`` is either a URL or an existing file.

    Parameters
    ----------
    path : str
        Value to validate.

    Returns
    -------
    str or pathlib.Path
        ``path`` unchanged when ``check_is_valid_url`` accepts it, otherwise ``path`` converted to a
        ``pathlib.Path`` pointing at an existing file.

    Raises
    ------
    argparse.ArgumentTypeError
        If ``path`` is neither accepted as a URL nor an existing file.
    """
    # NOTE(review): raising ArgumentTypeError suggests this is meant as an argparse `type=` callable - confirm at call sites.
    if not check_is_valid_url(path):
        candidate = pathlib.Path(path)
        if not candidate.is_file():
            raise argparse.ArgumentTypeError(f"{candidate} is not a valid file or url.")
        return candidate
    return path
Binary file removed direct/common/calgary_campinas_masks/R10_218x170.npy
Binary file not shown.
Binary file removed direct/common/calgary_campinas_masks/R10_218x174.npy
Binary file not shown.
Binary file removed direct/common/calgary_campinas_masks/R10_218x180.npy
Binary file not shown.
Binary file removed direct/common/calgary_campinas_masks/R5_218x170.npy
Binary file not shown.
Binary file removed direct/common/calgary_campinas_masks/R5_218x174.npy
Binary file not shown.
Binary file removed direct/common/calgary_campinas_masks/R5_218x180.npy
Binary file not shown.
19 changes: 18 additions & 1 deletion direct/common/subsample.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
import numpy as np
import torch

from direct.environment import DIRECT_CACHE_DIR
from direct.types import Number
from direct.utils import str_to_class
from direct.utils.io import download_url

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -281,6 +283,16 @@ def mask_func(self, shape, return_acs=False, seed=None):


class CalgaryCampinasMaskFunc(BaseMaskFunc):
BASE_URL = "https://s3.aiforoncology.nl/direct-project/calgary_campinas_masks/"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like https://s3.aiforoncology.nl/direct-project/ could be a global variable for direct given that we store everything here.

MASK_MD5S = {
"R10_218x170.npy": "6e1511c33dcfc4a960f526252676f7c3",
"R10_218x174.npy": "78fe23ae5eed2d3a8ff3ec128388dcc9",
"R10_218x180.npy": "5039a6c19ac2aa3472a94e4b015e5228",
"R5_218x170.npy": "6599715103cf3d71d6e87d09f865e7da",
"R5_218x174.npy": "5bd27d2da3bf1e78ad1b65c9b5e4b621",
"R5_218x180.npy": "717b51f3155c3a64cfaaddadbe90791d",
}

# TODO: Configuration improvements, so no **kwargs needed.
def __init__(self, accelerations: Tuple[int, ...], **kwargs): # noqa
super().__init__(accelerations=accelerations, uniform_range=False)
Expand Down Expand Up @@ -319,12 +331,17 @@ def mask_func(self, shape, return_acs=False):
return torch.from_numpy(mask[choice][np.newaxis, ..., np.newaxis])

def __load_masks(self, acceleration):
masks_path = pathlib.Path(pathlib.Path(__file__).resolve().parent / "calgary_campinas_masks")
masks_path = DIRECT_CACHE_DIR / "calgary_campinas_masks"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So you download everything into the cache directory. Does that mean that you have to download them every time?

paths = [
f"R{acceleration}_218x170.npy",
f"R{acceleration}_218x174.npy",
f"R{acceleration}_218x180.npy",
]

downloaded = [download_url(self.BASE_URL + _, masks_path, md5=self.MASK_MD5S[_]) is None for _ in paths]
if not all(downloaded):
raise RuntimeError(f"Failed to download all Calgary-Campinas masks from {self.BASE_URL}.")

output = {}
for path in paths:
shape = [int(_) for _ in path.split("_")[-1][:-4].split("x")]
Expand Down
2 changes: 1 addition & 1 deletion direct/data/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import torch.fft

from direct.data.bbox import crop_to_bbox
from direct.utils import ensure_list, is_power_of_two, is_complex_data
from direct.utils import ensure_list, is_complex_data, is_power_of_two
from direct.utils.asserts import assert_complex, assert_same_shape


Expand Down
34 changes: 23 additions & 11 deletions direct/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,24 @@
from torch.utils import collect_env

import direct.utils.logging
from direct.utils.logging import setup
from direct.config.defaults import DefaultConfig, InferenceConfig, TrainingConfig, ValidationConfig
from direct.utils import communication, count_parameters, str_to_class
from direct.utils.io import check_is_valid_url, read_text_from_url
from direct.utils.logging import setup

logger = logging.getLogger(__name__)

# Environmental variables
DIRECT_ROOT_DIR = pathlib.Path(pathlib.Path(__file__).resolve().parent.parent)
DIRECT_CACHE_DIR = pathlib.Path(os.environ.get("DIRECT_CACHE_DIR", str(DIRECT_ROOT_DIR)))
DIRECT_MODEL_DOWNLOAD_DIR = (
pathlib.Path(os.environ.get("DIRECT_MODEL_DOWNLOAD_DIR", str(DIRECT_ROOT_DIR))) / "downloaded_models"
)

Comment on lines +25 to +31
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where are these stored?


def load_model_config_from_name(model_name):
"""
Load specific configuration module for
Load specific configuration module for models based on their name.
Parameters
----------
Expand Down Expand Up @@ -195,7 +203,7 @@ def extract_names(cfg):
def setup_common_environment(
run_name,
base_directory,
cfg_filename,
cfg_pathname,
device,
machine_rank,
mixed_precision,
Expand All @@ -213,12 +221,16 @@ def setup_common_environment(
communication.synchronize() # Ensure folders are in place.

# Load configs from YAML file to check which model needs to be loaded.
cfg_from_file = OmegaConf.load(cfg_filename)
# Can also be loaded from a URL
if check_is_valid_url(cfg_pathname):
cfg_from_external_source = OmegaConf.create(read_text_from_url(cfg_pathname))
else:
cfg_from_external_source = OmegaConf.load(cfg_pathname)

# Load the default configs to ensure type safety
cfg = OmegaConf.structured(DefaultConfig)

models, models_config = load_models_into_environment_config(cfg_from_file)
models, models_config = load_models_into_environment_config(cfg_from_external_source)
cfg.model = models_config.model
del models_config["model"]
cfg.additional_models = models_config
Expand All @@ -228,25 +240,25 @@ def setup_common_environment(
cfg.validation = ValidationConfig
cfg.inference = InferenceConfig

cfg_from_file_new = cfg_from_file.copy()
for key in cfg_from_file:
cfg_from_file_new = cfg_from_external_source.copy()
for key in cfg_from_external_source:
# TODO: This does not really do a full validation.
# BODY: This will be handled once Hydra is implemented.
if key in ["models", "additional_models"]: # Still handled separately
continue

if key in ["training", "validation", "inference"]:
if not cfg_from_file[key]:
if not cfg_from_external_source[key]:
logger.info(f"key {key} missing in config.")
continue

if key in ["training", "validation"]:
dataset_cfg_from_file = extract_names(cfg_from_file[key].datasets)
dataset_cfg_from_file = extract_names(cfg_from_external_source[key].datasets)
for idx, (dataset_name, dataset_config) in enumerate(dataset_cfg_from_file):
cfg_from_file_new[key].datasets[idx] = dataset_config
cfg[key].datasets.append(load_dataset_config(dataset_name)) # pylint: disable = E1136
else:
dataset_name, dataset_config = extract_names(cfg_from_file[key].dataset)
dataset_name, dataset_config = extract_names(cfg_from_external_source[key].dataset)
cfg_from_file_new[key].dataset = dataset_config
cfg[key].dataset = load_dataset_config(dataset_name) # pylint: disable = E1136

Expand All @@ -255,7 +267,7 @@ def setup_common_environment(
# Make configuration read only.
# TODO(jt): Does not work when indexing config lists.
# OmegaConf.set_readonly(cfg, True)
setup_logging(machine_rank, experiment_dir, run_name, cfg_filename, cfg, debug)
setup_logging(machine_rank, experiment_dir, run_name, cfg_pathname, cfg, debug)
forward_operator, backward_operator = build_operators(cfg.physics)

model, additional_models = initialize_models_from_config(cfg, models, forward_operator, backward_operator, device)
Expand Down
2 changes: 1 addition & 1 deletion direct/nn/jointicnet/jointicnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch.nn as nn

import direct.data.transforms as T
from direct.nn.unet.unet_2d import UnetModel2d, NormUnetModel2d
from direct.nn.unet.unet_2d import NormUnetModel2d, UnetModel2d


class JointICNet(nn.Module):
Expand Down
4 changes: 2 additions & 2 deletions direct/nn/kikinet/kikinet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@

import direct.data.transforms as T
from direct.nn.conv.conv import Conv2d
from direct.nn.crossdomain.multicoil import MultiCoil
from direct.nn.didn.didn import DIDN
from direct.nn.mwcnn.mwcnn import MWCNN
from direct.nn.crossdomain.multicoil import MultiCoil
from direct.nn.unet.unet_2d import UnetModel2d, NormUnetModel2d
from direct.nn.unet.unet_2d import NormUnetModel2d, UnetModel2d


class KIKINet(nn.Module):
Expand Down
8 changes: 4 additions & 4 deletions direct/nn/lpd/lpd.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@

from typing import Callable

import torch
import torch.nn as nn

import direct.data.transforms as T
from direct.nn.conv.conv import Conv2d
from direct.nn.didn.didn import DIDN
from direct.nn.mwcnn.mwcnn import MWCNN
from direct.nn.unet.unet_2d import UnetModel2d, NormUnetModel2d

import torch
import torch.nn as nn
from direct.nn.unet.unet_2d import NormUnetModel2d, UnetModel2d


class DualNet(nn.Module):
Expand Down
6 changes: 3 additions & 3 deletions direct/nn/multidomainnet/multidomainnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@

from typing import Callable

import direct.data.transforms as T
from direct.nn.multidomainnet.multidomain import MultiDomainUnet2d

import torch
import torch.nn as nn

import direct.data.transforms as T
from direct.nn.multidomainnet.multidomain import MultiDomainUnet2d


class StandardizationLayer(nn.Module):
"""
Expand Down
1 change: 1 addition & 0 deletions direct/nn/recurrent/recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Copyright (c) DIRECT Contributors

from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
Expand Down
7 changes: 1 addition & 6 deletions direct/nn/recurrentvarnet/recurrentvarnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,7 @@
import torch.nn as nn
import torch.nn.functional as F

from direct.data.transforms import (
conjugate,
complex_multiplication,
reduce_operator,
expand_operator,
)
from direct.data.transforms import complex_multiplication, conjugate, expand_operator, reduce_operator
from direct.nn.recurrent.recurrent import Conv2dGRU


Expand Down
2 changes: 1 addition & 1 deletion direct/nn/rim/rim.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
import torch.nn.functional as F

from direct.data import transforms as T
from direct.utils.asserts import assert_positive_integer
from direct.nn.recurrent.recurrent import Conv2dGRU
from direct.utils.asserts import assert_positive_integer


class MRILogLikelihood(nn.Module):
Expand Down
2 changes: 1 addition & 1 deletion direct/nn/unet/tests/test_unet_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch

from direct.data.transforms import fft2, ifft2
from direct.nn.unet.unet_2d import Unet2d, NormUnetModel2d
from direct.nn.unet.unet_2d import NormUnetModel2d, Unet2d


def create_input(shape):
Expand Down
2 changes: 1 addition & 1 deletion direct/nn/xpdnet/xpdnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

import torch.nn as nn

from direct.nn.conv.conv import Conv2d
from direct.nn.crossdomain.crossdomain import CrossDomainNetwork
from direct.nn.crossdomain.multicoil import MultiCoil
from direct.nn.conv.conv import Conv2d
from direct.nn.didn.didn import DIDN
from direct.nn.mwcnn.mwcnn import MWCNN

Expand Down
2 changes: 2 additions & 0 deletions direct/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# coding=utf-8
# Copyright (c) DIRECT Contributors
Loading