-
Notifications
You must be signed in to change notification settings - Fork 42
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Download checkpoints #133
Download checkpoints #133
Conversation
* Update engine.py Normalize target for visualization. * Update README.md * Create recurrentvarnet.py * Create config.py * Create __init__.py * Update recurrentvarnet.py * Create recurrentvarnet_engine.py * Update recurrentvarnet.py * Update recurrentvarnet.py * Update recurrentvarnet_engine.py * Update documentation * Create base_recurrentvarnet.yaml * Update README.md * Update model_zoo.md * Option to choose between max/min * Update parse_metrics_log.py Co-authored-by: George Yiasemis <[email protected]>
- formatting updates
tools/parse_metrics_log.py
Outdated
print([i.name for i in args.metrics_path.glob("*.pt")]) | ||
with open(args.metrics_path / "metrics.json", "r") as f: | ||
data = f.readlines() | ||
data = [json.loads(_) for _ in data] | ||
|
||
x = np.asarray([(int(_["iteration"]), -_[args.key]) for _ in data if args.key in _]) | ||
out = x[np.where(x[:, 1] == x[:, 1].max())][0] | ||
|
||
x = np.asarray([(int(_["iteration"]), _[args.key]) for _ in data if args.key in _]) | ||
if args.max: | ||
out = x[np.where(x[:, 1] == x[:, 1].max())][0] | ||
else: | ||
out = x[np.where(x[:, 1] == x[:, 1].min())][0] | ||
print(f"{args.key} - {int(out[0])}: {out[1]}") | ||
print(x[np.where(x[:, 1] == 148520)][0]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Need to rebase again from main
@@ -281,6 +283,7 @@ def mask_func(self, shape, return_acs=False, seed=None): | |||
|
|||
|
|||
class CalgaryCampinasMaskFunc(BaseMaskFunc): | |||
BASE_URL = "https://s3.aiforoncology.nl/direct-project/calgary_campinas_masks/" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I feel like https://s3.aiforoncology.nl/direct-project/
could be a global variable for direct given that we store everything here.
@@ -319,12 +322,16 @@ def mask_func(self, shape, return_acs=False): | |||
return torch.from_numpy(mask[choice][np.newaxis, ..., np.newaxis]) | |||
|
|||
def __load_masks(self, acceleration): | |||
masks_path = pathlib.Path(pathlib.Path(__file__).resolve().parent / "calgary_campinas_masks") | |||
masks_path = DIRECT_CACHE_DIR / "calgary_campinas_masks" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So you download everything in cache memory. Does that mean that you have to download them everytime?
# Environmental variables | ||
DIRECT_ROOT_DIR = pathlib.Path(pathlib.Path(__file__).resolve().parent.parent) | ||
DIRECT_CACHE_DIR = pathlib.Path(os.environ.get("DIRECT_CACHE_DIR", str(DIRECT_ROOT_DIR))) | ||
DIRECT_MODEL_DOWNLOAD_DIR = ( | ||
pathlib.Path(os.environ.get("DIRECT_MODEL_DOWNLOAD_DIR", str(DIRECT_ROOT_DIR))) / "downloaded_models" | ||
) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Where are these stored?
- Fix tests - Allow path or URL in cfg settings
- Download Calgary Campinas masks when needed, remove from repo - Allow parsing of config from url - Allow loading checkpoint from url
bbdc4ca
to
99c93bf
Compare
- Download Calgary Campinas masks when needed, remove from repo - Allow parsing of config from url - Allow loading checkpoint from url
- Download Calgary Campinas masks when needed, remove from repo - Allow parsing of config from url - Allow loading checkpoint from url
- Download Calgary Campinas masks when needed, remove from repo - Allow parsing of config from url - Allow loading checkpoint from url
…into download-checkpoints
Failure of the tox test is slightly random. Can you review @georgeyiasemis @jonatanferm |
type=file_or_url, | ||
help="If this value is set to a proper checkpoint when training starts, " | ||
"the model will be initialized with the weights given. " | ||
"No other keys in the checkpoint will be loaded. " |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would add a bit more elaborate description about loading from directory or downloading from a url. (Maybe in the README?)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have adjusted it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good in general. We maybe need to add some instructions on how to train a model or run inference with the new changes.
This is a draft for a pull request to download models from URLs. We need some way to also initialize a model easily. Right now loading from the checkpoint as an URL should work.
Needs to check