-
Notifications
You must be signed in to change notification settings - Fork 42
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Download checkpoints #133
Merged
Merged
Download checkpoints #133
Changes from all commits
Commits
Show all changes
33 commits
Select commit
Hold shift + click to select a range
ff4f353
Allow URLs as input to checkpointer
jonasteuwen 4be53ab
Merge upstream to download-checkpoints (#132)
jonasteuwen 08dbeee
Added utilities to download models
jonasteuwen 8618298
Preliminary fix for calgary models
jonasteuwen d8b42b6
Merge branch 'main' into download-checkpoints
georgeyiasemis 0eb2a21
Added possibility to load config from URL
jonasteuwen 5e5dcb4
move checkpointer io test to test_io
jonasteuwen 8d49db6
move checkpointer io test to test_io
jonasteuwen c82ee52
Version bump for pytorch
8494ba6
Smoothing changes
jonasteuwen 26a3e33
Allow read_list to read data from url
jonasteuwen cd32264
Allow to read lists from URL as well
jonasteuwen 27c70b6
Small change
jonasteuwen 9d5b0ea
Formatting and import fixes
jonasteuwen 6e637a7
Formatting and import fixes
jonasteuwen be88e74
download_url already checks if all is okay
jonasteuwen fbd6ece
Add md5s to download
jonasteuwen 2eca066
remove superflucuous files
jonasteuwen 99c93bf
Allow checkpoints and configs from remote url
jonasteuwen cfa8d0a
Fix conflicts
afcc7f4
Allow checkpoints and configs from remote url
jonasteuwen 43f861a
Allow checkpoints and configs from remote url
jonasteuwen 97c8876
Imrpoved http error
eae29e5
Allow checkpoints and configs from remote url
jonasteuwen bd385c5
Merge branch 'download-checkpoints' of github.com:directgroup/direct …
jonasteuwen bdec58b
Checkpoint can be URL
jonasteuwen 7ce3d08
dir->file
jonasteuwen 9812548
dir->file
jonasteuwen bd926ac
Fix torch setup.py version
jonasteuwen f12f58f
Fix torch setup.py version
jonasteuwen 85659aa
Placeholder
jonasteuwen 085589f
Fix black
jonasteuwen f5326dd
Extend the help
jonasteuwen File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# coding=utf-8 | ||
# Copyright (c) DIRECT Contributors | ||
import argparse | ||
import pathlib | ||
|
||
from direct.utils.io import check_is_valid_url | ||
|
||
|
||
def file_or_url(path): | ||
if check_is_valid_url(path): | ||
return path | ||
path = pathlib.Path(path) | ||
if path.is_file(): | ||
return path | ||
raise argparse.ArgumentTypeError(f"{path} is not a valid file or url.") |
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,8 +14,10 @@ | |
import numpy as np | ||
import torch | ||
|
||
from direct.environment import DIRECT_CACHE_DIR | ||
from direct.types import Number | ||
from direct.utils import str_to_class | ||
from direct.utils.io import download_url | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
@@ -281,6 +283,16 @@ def mask_func(self, shape, return_acs=False, seed=None): | |
|
||
|
||
class CalgaryCampinasMaskFunc(BaseMaskFunc): | ||
BASE_URL = "https://s3.aiforoncology.nl/direct-project/calgary_campinas_masks/" | ||
MASK_MD5S = { | ||
"R10_218x170.npy": "6e1511c33dcfc4a960f526252676f7c3", | ||
"R10_218x174.npy": "78fe23ae5eed2d3a8ff3ec128388dcc9", | ||
"R10_218x180.npy": "5039a6c19ac2aa3472a94e4b015e5228", | ||
"R5_218x170.npy": "6599715103cf3d71d6e87d09f865e7da", | ||
"R5_218x174.npy": "5bd27d2da3bf1e78ad1b65c9b5e4b621", | ||
"R5_218x180.npy": "717b51f3155c3a64cfaaddadbe90791d", | ||
} | ||
|
||
# TODO: Configuration improvements, so no **kwargs needed. | ||
def __init__(self, accelerations: Tuple[int, ...], **kwargs): # noqa | ||
super().__init__(accelerations=accelerations, uniform_range=False) | ||
|
@@ -319,12 +331,17 @@ def mask_func(self, shape, return_acs=False): | |
return torch.from_numpy(mask[choice][np.newaxis, ..., np.newaxis]) | ||
|
||
def __load_masks(self, acceleration): | ||
masks_path = pathlib.Path(pathlib.Path(__file__).resolve().parent / "calgary_campinas_masks") | ||
masks_path = DIRECT_CACHE_DIR / "calgary_campinas_masks" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So you download everything in cache memory. Does that mean that you have to download them everytime? |
||
paths = [ | ||
f"R{acceleration}_218x170.npy", | ||
f"R{acceleration}_218x174.npy", | ||
f"R{acceleration}_218x180.npy", | ||
] | ||
|
||
downloaded = [download_url(self.BASE_URL + _, masks_path, md5=self.MASK_MD5S[_]) is None for _ in paths] | ||
if not all(downloaded): | ||
raise RuntimeError(f"Failed to download all Calgary-Campinas masks from {self.BASE_URL}.") | ||
|
||
output = {} | ||
for path in paths: | ||
shape = [int(_) for _ in path.split("_")[-1][:-4].split("x")] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,16 +15,24 @@ | |
from torch.utils import collect_env | ||
|
||
import direct.utils.logging | ||
from direct.utils.logging import setup | ||
from direct.config.defaults import DefaultConfig, InferenceConfig, TrainingConfig, ValidationConfig | ||
from direct.utils import communication, count_parameters, str_to_class | ||
from direct.utils.io import check_is_valid_url, read_text_from_url | ||
from direct.utils.logging import setup | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
# Environmental variables | ||
DIRECT_ROOT_DIR = pathlib.Path(pathlib.Path(__file__).resolve().parent.parent) | ||
DIRECT_CACHE_DIR = pathlib.Path(os.environ.get("DIRECT_CACHE_DIR", str(DIRECT_ROOT_DIR))) | ||
DIRECT_MODEL_DOWNLOAD_DIR = ( | ||
pathlib.Path(os.environ.get("DIRECT_MODEL_DOWNLOAD_DIR", str(DIRECT_ROOT_DIR))) / "downloaded_models" | ||
) | ||
|
||
Comment on lines
+25
to
+31
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Where are these stored? |
||
|
||
def load_model_config_from_name(model_name): | ||
""" | ||
Load specific configuration module for | ||
Load specific configuration module for models based on their name. | ||
|
||
Parameters | ||
---------- | ||
|
@@ -195,7 +203,7 @@ def extract_names(cfg): | |
def setup_common_environment( | ||
run_name, | ||
base_directory, | ||
cfg_filename, | ||
cfg_pathname, | ||
device, | ||
machine_rank, | ||
mixed_precision, | ||
|
@@ -213,12 +221,16 @@ def setup_common_environment( | |
communication.synchronize() # Ensure folders are in place. | ||
|
||
# Load configs from YAML file to check which model needs to be loaded. | ||
cfg_from_file = OmegaConf.load(cfg_filename) | ||
# Can also be loaded from a URL | ||
if check_is_valid_url(cfg_pathname): | ||
cfg_from_external_source = OmegaConf.create(read_text_from_url(cfg_pathname)) | ||
else: | ||
cfg_from_external_source = OmegaConf.load(cfg_pathname) | ||
|
||
# Load the default configs to ensure type safety | ||
cfg = OmegaConf.structured(DefaultConfig) | ||
|
||
models, models_config = load_models_into_environment_config(cfg_from_file) | ||
models, models_config = load_models_into_environment_config(cfg_from_external_source) | ||
cfg.model = models_config.model | ||
del models_config["model"] | ||
cfg.additional_models = models_config | ||
|
@@ -228,25 +240,25 @@ def setup_common_environment( | |
cfg.validation = ValidationConfig | ||
cfg.inference = InferenceConfig | ||
|
||
cfg_from_file_new = cfg_from_file.copy() | ||
for key in cfg_from_file: | ||
cfg_from_file_new = cfg_from_external_source.copy() | ||
for key in cfg_from_external_source: | ||
# TODO: This does not really do a full validation. | ||
# BODY: This will be handeled once Hydra is implemented. | ||
if key in ["models", "additional_models"]: # Still handled separately | ||
continue | ||
|
||
if key in ["training", "validation", "inference"]: | ||
if not cfg_from_file[key]: | ||
if not cfg_from_external_source[key]: | ||
logger.info(f"key {key} missing in config.") | ||
continue | ||
|
||
if key in ["training", "validation"]: | ||
dataset_cfg_from_file = extract_names(cfg_from_file[key].datasets) | ||
dataset_cfg_from_file = extract_names(cfg_from_external_source[key].datasets) | ||
for idx, (dataset_name, dataset_config) in enumerate(dataset_cfg_from_file): | ||
cfg_from_file_new[key].datasets[idx] = dataset_config | ||
cfg[key].datasets.append(load_dataset_config(dataset_name)) # pylint: disable = E1136 | ||
else: | ||
dataset_name, dataset_config = extract_names(cfg_from_file[key].dataset) | ||
dataset_name, dataset_config = extract_names(cfg_from_external_source[key].dataset) | ||
cfg_from_file_new[key].dataset = dataset_config | ||
cfg[key].dataset = load_dataset_config(dataset_name) # pylint: disable = E1136 | ||
|
||
|
@@ -255,7 +267,7 @@ def setup_common_environment( | |
# Make configuration read only. | ||
# TODO(jt): Does not work when indexing config lists. | ||
# OmegaConf.set_readonly(cfg, True) | ||
setup_logging(machine_rank, experiment_dir, run_name, cfg_filename, cfg, debug) | ||
setup_logging(machine_rank, experiment_dir, run_name, cfg_pathname, cfg, debug) | ||
forward_operator, backward_operator = build_operators(cfg.physics) | ||
|
||
model, additional_models = initialize_models_from_config(cfg, models, forward_operator, backward_operator, device) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
# coding=utf-8 | ||
# Copyright (c) DIRECT Contributors |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I feel like
https://s3.aiforoncology.nl/direct-project/
could be a global variable for direct given that we store everything here.