Download checkpoints #133

Merged
merged 33 commits into from
Dec 10, 2021

Conversation

@jonasteuwen jonasteuwen commented Nov 29, 2021

This is a draft pull request for downloading models from URLs. We also need an easy way to initialize a model. Right now, loading a checkpoint from a URL should work.

To check:

  • The Calgary-Campinas masks need to be downloaded too, so do they need to be uploaded first?
  • Add a test for the downloads
  • Does initializing the model from the URL work properly now?

jonasteuwen and others added 3 commits November 17, 2021 18:01
* Update engine.py

Normalize target for visualization.

* Update README.md
* Create recurrentvarnet.py
* Create config.py
* Create __init__.py
* Update recurrentvarnet.py
* Create recurrentvarnet_engine.py
* Update recurrentvarnet.py
* Update recurrentvarnet.py
* Update recurrentvarnet_engine.py
* Update documentation
* Create base_recurrentvarnet.yaml
* Update README.md
* Update model_zoo.md
* Option to choose between max/min
* Update parse_metrics_log.py
Co-authored-by: George Yiasemis
@jonasteuwen jonasteuwen marked this pull request as draft November 29, 2021 20:53
Comment on lines 27 to 38
print([i.name for i in args.metrics_path.glob("*.pt")])
with open(args.metrics_path / "metrics.json", "r") as f:
    data = f.readlines()
    data = [json.loads(_) for _ in data]

x = np.asarray([(int(_["iteration"]), -_[args.key]) for _ in data if args.key in _])
out = x[np.where(x[:, 1] == x[:, 1].max())][0]

x = np.asarray([(int(_["iteration"]), _[args.key]) for _ in data if args.key in _])
if args.max:
    out = x[np.where(x[:, 1] == x[:, 1].max())][0]
else:
    out = x[np.where(x[:, 1] == x[:, 1].min())][0]
print(f"{args.key} - {int(out[0])}: {out[1]}")
print(x[np.where(x[:, 1] == 148520)][0])
Contributor

Need to rebase again from main
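
For reference, the snippet under review shows two variants of the selection: one negates the metric and always takes the maximum, the other keeps the raw value and chooses between maximum and minimum via `args.max`. A minimal sketch of the second variant follows, written with the built-in `min`/`max` instead of the NumPy indexing used in the diff. The function name `best_iteration` and its parameters are hypothetical and not part of the PR.

```python
import json


def best_iteration(lines, key, use_max=False):
    """Return (iteration, value) of the best logged entry for `key`.

    `lines` is an iterable of JSON strings, one record per line, as in a
    metrics.json log. Records that do not contain `key` are skipped.
    Hypothetical helper, not taken from the PR.
    """
    records = (json.loads(line) for line in lines)
    pairs = [(int(r["iteration"]), r[key]) for r in records if key in r]
    if not pairs:
        raise ValueError(f"no entries contain key {key!r}")
    # min/max return the first matching entry on ties, like `[...][0]` in the diff.
    pick = max if use_max else min
    return pick(pairs, key=lambda pair: pair[1])
```

Called with `use_max=False` it returns the iteration with the lowest value (for example a loss), and with `use_max=True` the highest (for example SSIM).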

@@ -281,6 +283,7 @@ def mask_func(self, shape, return_acs=False, seed=None):


class CalgaryCampinasMaskFunc(BaseMaskFunc):
    BASE_URL = "https://s3.aiforoncology.nl/direct-project/calgary_campinas_masks/"
Contributor

I feel like https://s3.aiforoncology.nl/direct-project/ could be a global variable for DIRECT, given that we store everything there.
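
One way this suggestion could look is sketched below. The names `DIRECT_BASE_URL`, `build_url`, and `CALGARY_CAMPINAS_MASKS_URL` are assumptions for illustration and do not appear in the PR; only the two URLs come from the diff and the comment above.

```python
from urllib.parse import urljoin

# Hypothetical package-level constant holding the shared storage prefix.
DIRECT_BASE_URL = "https://s3.aiforoncology.nl/direct-project/"


def build_url(relative_path: str) -> str:
    """Join a path relative to the shared prefix (hypothetical helper)."""
    # Strip a leading slash so urljoin does not reset to the host root.
    return urljoin(DIRECT_BASE_URL, relative_path.lstrip("/"))


# The class-level BASE_URL could then be derived instead of hard-coded.
CALGARY_CAMPINAS_MASKS_URL = build_url("calgary_campinas_masks/")
```

With this layout, moving the storage bucket would only require changing the single constant.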

@@ -319,12 +322,16 @@ def mask_func(self, shape, return_acs=False):
        return torch.from_numpy(mask[choice][np.newaxis, ..., np.newaxis])

    def __load_masks(self, acceleration):
        masks_path = pathlib.Path(pathlib.Path(__file__).resolve().parent / "calgary_campinas_masks")
        masks_path = DIRECT_CACHE_DIR / "calgary_campinas_masks"
Contributor

So you download everything into cache memory. Does that mean you have to download the masks every time?
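
The thread does not answer this question, so the following is only a sketch of the usual pattern for an on-disk cache directory: download a file once and reuse the local copy on later calls. The helper `fetch_cached` and its parameters are hypothetical and are not claimed to match what the PR implements.

```python
import pathlib
import urllib.request


def fetch_cached(url, cache_dir, filename=None):
    """Download `url` into `cache_dir` unless a local copy already exists.

    Hypothetical helper illustrating a download-once cache on disk.
    Returns the path of the cached file.
    """
    cache_dir = pathlib.Path(cache_dir)
    cache_dir.mkdir(parents=True, exist_ok=True)
    target = cache_dir / (filename or url.rstrip("/").split("/")[-1])
    if not target.exists():
        # Download to a temporary name first so an interrupted transfer
        # never leaves a truncated file under the final name.
        partial = target.with_name(target.name + ".part")
        urllib.request.urlretrieve(url, partial)
        partial.replace(target)
    return target
```

Under this pattern the files persist between runs as long as the cache directory is kept, so a repeated call only touches the network when the file is missing.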

Comment on lines +24 to +30
# Environment variables
DIRECT_ROOT_DIR = pathlib.Path(pathlib.Path(__file__).resolve().parent.parent)
DIRECT_CACHE_DIR = pathlib.Path(os.environ.get("DIRECT_CACHE_DIR", str(DIRECT_ROOT_DIR)))
DIRECT_MODEL_DOWNLOAD_DIR = (
    pathlib.Path(os.environ.get("DIRECT_MODEL_DOWNLOAD_DIR", str(DIRECT_ROOT_DIR))) / "downloaded_models"
)

Contributor

Where are these stored?
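
Reading the snippet above, both directories fall back to `DIRECT_ROOT_DIR` when the corresponding environment variable is unset, and the model directory always gets a `downloaded_models` subfolder appended. The sketch below mirrors that logic in a function so the resolution can be checked in isolation. The function `resolve_dirs` and the example paths are hypothetical.

```python
import pathlib


def resolve_dirs(environ, root_dir):
    """Mirror the fallback logic of the snippet above (hypothetical helper).

    `environ` is any mapping, for example os.environ.
    Returns (cache_dir, model_download_dir).
    """
    root_dir = pathlib.Path(root_dir)
    cache_dir = pathlib.Path(environ.get("DIRECT_CACHE_DIR", str(root_dir)))
    model_dir = (
        pathlib.Path(environ.get("DIRECT_MODEL_DOWNLOAD_DIR", str(root_dir)))
        / "downloaded_models"
    )
    return cache_dir, model_dir


# Without overrides, everything lands under the root directory.
default_cache, default_models = resolve_dirs({}, "/opt/direct")
# With overrides, each variable redirects its own directory independently.
custom_cache, custom_models = resolve_dirs(
    {"DIRECT_CACHE_DIR": "/data/cache", "DIRECT_MODEL_DOWNLOAD_DIR": "/data"},
    "/opt/direct",
)
```

Note that `DIRECT_MODEL_DOWNLOAD_DIR` names the parent directory here, since `downloaded_models` is appended even when the variable is set.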

@github-actions github-actions bot added the python label Dec 1, 2021
Jonas Teuwen and others added 6 commits December 10, 2021 12:59
- Download Calgary Campinas masks when needed, remove from repo
- Allow parsing of config from url
- Allow loading checkpoint from url
@jonasteuwen jonasteuwen marked this pull request as ready for review December 10, 2021 15:47
@jonasteuwen
Contributor Author

The tox test failure is somewhat random. Can you review, @georgeyiasemis @jonatanferm?

Comment on lines +301 to 304
type=file_or_url,
help="If this value is set to a proper checkpoint when training starts, "
"the model will be initialized with the weights given. "
"No other keys in the checkpoint will be loaded. "
Contributor

I would add a slightly more elaborate description of loading from a directory versus downloading from a URL (maybe in the README?).

Contributor Author

I have adjusted it.
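
The diff only shows that the argument uses `type=file_or_url`; the body of that function is not part of this thread. As an illustration of what such an argparse type could do, here is a minimal sketch. The implementation, the `--checkpoint` flag name, and the example values are all assumptions.

```python
import argparse
import pathlib
from urllib.parse import urlparse


def file_or_url(value):
    """Hypothetical argparse type: keep http(s) URLs as strings, else return a Path."""
    parsed = urlparse(value)
    if parsed.scheme in ("http", "https") and parsed.netloc:
        return value
    return pathlib.Path(value)


parser = argparse.ArgumentParser()
# The flag name is illustrative only; the real option name is not shown in the diff.
parser.add_argument("--checkpoint", type=file_or_url, default=None)

from_url = parser.parse_args(["--checkpoint", "https://example.com/model.pt"])
from_disk = parser.parse_args(["--checkpoint", "checkpoints/model.pt"])
```

Downstream code can then branch on `isinstance(args.checkpoint, pathlib.Path)` to decide whether to read from disk or download first.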

@georgeyiasemis georgeyiasemis left a comment

Looks good in general. We may need to add some instructions on how to train a model or run inference with the new changes.

@jonasteuwen jonasteuwen merged commit 2de00ed into main Dec 10, 2021
@jonasteuwen jonasteuwen deleted the download-checkpoints branch December 10, 2021 17:14
2 participants