Enable to download checkpoint from url for inference #127

Closed · wants to merge 6 commits
direct/checkpointer.py (2 additions, 2 deletions)
@@ -68,8 +68,8 @@ def load(
         iteration: Union[int, str, None],
         checkpointable_objects: Optional[Dict[str, nn.Module]] = None,
     ) -> Dict:
-        if iteration is not None and not isinstance(iteration, int) and iteration != "latest":
-            raise ValueError("Value `iteration` is expected to be either None, an integer or `latest`.")
+        if iteration is not None and not isinstance(iteration, int) and iteration not in ["latest", "download"]:
+            raise ValueError("Value `iteration` is expected to be either None, an integer, `latest` or `download`.")
 
         if iteration is None:
             return {}
direct/config/defaults.py (6 additions, 0 deletions)
@@ -108,6 +108,11 @@ class PhysicsConfig(BaseConfig):
     noise_matrix_scaling: Optional[float] = 1.0
 
 
+@dataclass
+class CheckpointConfig(BaseConfig):
+    checkpoint_url: Optional[str] = None
+
+
 @dataclass
 class DefaultConfig(BaseConfig):
     model: ModelConfig = MISSING
@@ -118,5 +123,6 @@ class DefaultConfig(BaseConfig):
     training: TrainingConfig = TrainingConfig()  # This should be optional.
     validation: ValidationConfig = ValidationConfig()  # This should be optional.
     inference: Optional[InferenceConfig] = None
+    checkpoint: CheckpointConfig = CheckpointConfig()
 
     logging: LoggingConfig = LoggingConfig()
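
As an illustration (not part of the diff), the new checkpoint_url field could be set from a user config roughly as below. This is a minimal sketch assuming the dataclass defaults above are merged with a YAML config via OmegaConf; the URL is a placeholder.

# Hypothetical usage sketch for the new CheckpointConfig.checkpoint_url field.
from omegaconf import OmegaConf

yaml_cfg = """
checkpoint:
  checkpoint_url: https://example.org/weights/model.pt
"""

user_cfg = OmegaConf.create(yaml_cfg)

# checkpoint_url defaults to None, so inference only attempts a download
# when the field is explicitly set, as above.
print(user_cfg.checkpoint.checkpoint_url)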
direct/inference.py (34 additions, 0 deletions)
@@ -59,6 +59,33 @@ def setup_inference_save_to_h5(
     """
     env = setup_inference_environment(run_name, base_directory, device, machine_rank, mixed_precision, debug=debug)
 
+    if env.cfg.checkpoint.checkpoint_url is not None:
+        try:
+            if checkpoint is not None:
+                logger.warning(f"'checkpoint_url' is not null. This will ignore checkpoint value {checkpoint}.")
+
+            checkpoint = "download"
+            logger.info(
+                f"Attempting to download checkpoint from {env.cfg.checkpoint.checkpoint_url} "
+                f"as model_{checkpoint}.pt."
+            )
+            torch.hub.load_state_dict_from_url(
+                env.cfg.checkpoint.checkpoint_url,
+                model_dir=base_directory,
+                file_name=f"model_{checkpoint}.pt",
+                progress=False,
+            )
+            logger.info(
+                f"Successfully downloaded checkpoint from {env.cfg.checkpoint.checkpoint_url}. "
+                f"Saved temporarily to {base_directory}."
+            )
+        except Exception as exc:
+            logger.info(
+                f"Could not download checkpoint from {env.cfg.checkpoint.checkpoint_url}. Make sure that "
+                f"the url contains a valid torch state_dict. Exiting with error message: {exc}"
+            )
+            sys.exit(-1)
+
     dataset_cfg, transforms = get_inference_settings(env)
 
     # Trigger cudnn benchmark when the number of different input masks_dict is small.
@@ -85,6 +112,13 @@ def setup_inference_save_to_h5(
         filenames_filter=curr_filenames_filter,
     )
 
+    if env.cfg.checkpoint.checkpoint_url is not None:
+        import os
+
+        # Delete downloaded checkpoint.
+        logger.info(f"Removing model_{checkpoint}.pt from {base_directory}...")
+        os.remove(base_directory / f"model_{checkpoint}.pt")
+
     # Perhaps aggregation to the main process would be most optimal here before writing.
     # The current way this writes the volumes for each process.
     write_output_to_h5(
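
For reference, a self-contained sketch of the download step used above: torch.hub.load_state_dict_from_url caches the file under model_dir with the given file_name and returns the deserialized state_dict. The directory and URL below are placeholders, not values from this PR.

# Minimal sketch of the download performed in setup_inference_save_to_h5 (placeholder path and URL).
import pathlib

import torch

base_directory = pathlib.Path("/tmp/inference_run")  # placeholder
checkpoint_url = "https://example.org/weights/model.pt"  # placeholder

# Downloads (and caches) the checkpoint as <base_directory>/model_download.pt
# and returns the state_dict it contains.
state_dict = torch.hub.load_state_dict_from_url(
    checkpoint_url,
    model_dir=str(base_directory),
    file_name="model_download.pt",
    progress=False,
)

# The checkpointer can then resolve iteration="download" to this file; after
# inference the file is removed again, as in the second hunk above.
print(type(state_dict))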
projects/calgary_campinas/predict_test.py (4 additions, 1 deletion)
@@ -87,6 +87,9 @@ def _get_transforms(masks_dict, env):
     Run on multiple machines:
         (machine0)$ {sys.argv[0]} data_root output_directory --checkpoint <checkpoint_num> --name <name> --masks <path to masks> --machine-rank 0 --num-machines 2 --dist-url <URL> [--other-flags]
         (machine1)$ {sys.argv[0]} data_root output_directory --checkpoint <checkpoint_num> --name <name> --masks <path to masks> --machine-rank 1 --num-machines 2 --dist-url <URL> [--other-flags]
+    Download checkpoint from URL (checkpoint_url must be specified in the config):
+        $ {sys.argv[0]} data_root output_directory --name <name> [--other-flags]
+        If "--checkpoint <checkpoint_num>" is passed, it will be ignored.
     """
 
     parser = Args(epilog=epilog)
@@ -100,7 +103,7 @@ def _get_transforms(masks_dict, env):
     parser.add_argument(
         "--checkpoint",
         type=int,
-        required=True,
+        required=False,
         help="Number of an existing checkpoint.",
     )
     parser.add_argument(
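
To make the precedence explicit: when checkpoint_url is set in the config, any --checkpoint value passed on the command line is ignored and the special "download" iteration is used instead. A hypothetical sketch (the helper name is illustrative, not from the PR):

# Illustration of the precedence implemented in setup_inference_save_to_h5.
from typing import Optional, Union


def effective_checkpoint(cli_checkpoint: Optional[int], checkpoint_url: Optional[str]) -> Union[int, str, None]:
    # A configured checkpoint_url overrides any --checkpoint value from the CLI.
    if checkpoint_url is not None:
        return "download"
    return cli_checkpoint


print(effective_checkpoint(5000, "https://example.org/weights/model.pt"))  # download
print(effective_checkpoint(5000, None))  # 5000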
projects/calgary_campinas/predict_val.py (5 additions, 2 deletions)
@@ -42,11 +42,14 @@ def _get_transforms(validation_index, env):
     Run on multiple machines:
         (machine0)$ {sys.argv[0]} data_root output_directory --checkpoint <checkpoint_num> --name <name> --machine-rank 0 --num-machines 2 --dist-url <URL> [--other-flags]
         (machine1)$ {sys.argv[0]} data_root output_directory --checkpoint <checkpoint_num> --name <name> --machine-rank 1 --num-machines 2 --dist-url <URL> [--other-flags]
+    Download checkpoint from URL (checkpoint_url must be specified in the config):
+        $ {sys.argv[0]} data_root output_directory --name <name> [--other-flags]
+        If "--checkpoint <checkpoint_num>" is passed, it will be ignored.
     """
 
     parser = Args(epilog=epilog)
     parser.add_argument("data_root", type=pathlib.Path, help="Path to the DoIterationOutput directory.")
-    parser.add_argument("output_directory", type=pathlib.Path, help="Path to the DoIterationOutput directory.")
+    parser.add_argument("output_directory", type=pathlib.Path, help="Path to the output directory.")
     parser.add_argument(
         "experiment_directory",
         type=pathlib.Path,
@@ -55,7 +58,7 @@ def _get_transforms(validation_index, env):
     parser.add_argument(
         "--checkpoint",
         type=int,
-        required=True,
+        required=False,
         help="Number of an existing checkpoint.",
     )
     parser.add_argument(