This repository has been archived by the owner on Jul 1, 2024. It is now read-only.

Take a separate argument for loading checkpoints, take paths for pretrained checkpoints #379

Closed
7 changes: 3 additions & 4 deletions classy_train.py
@@ -75,14 +75,13 @@ def main(args, config):
 
     task = build_task(config)
 
-    # Load checkpoint, if available. This automatically resumes from an
-    # existing checkpoint, in case training is being restarted.
-    checkpoint = load_checkpoint(args.checkpoint_folder)
+    # Load checkpoint, if available.
+    checkpoint = load_checkpoint(args.checkpoint_load_path)
     task.set_checkpoint(checkpoint)
 
     # Load a checkpoint containing a pre-trained model. This is how we
     # implement fine-tuning of existing models.
-    pretrained_checkpoint = load_checkpoint(args.pretrained_checkpoint_folder)
+    pretrained_checkpoint = load_checkpoint(args.pretrained_checkpoint_path)
     if pretrained_checkpoint is not None:
         assert isinstance(
             task, FineTuningTask
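
Read together, the new flow in main() is: resume full training state from --checkpoint_load_path, then optionally layer pre-trained weights from --pretrained_checkpoint_path onto a fine-tuning task. A minimal sketch of that flow, assuming the import locations and the set_pretrained_checkpoint call that this diff does not show:

# Sketch only: build_task, load_checkpoint, set_checkpoint, and FineTuningTask
# appear in the diff; the imports and set_pretrained_checkpoint are assumptions.
from classy_vision.generic.util import load_checkpoint
from classy_vision.tasks import FineTuningTask, build_task

def build_task_with_checkpoints(args, config):
    task = build_task(config)

    # load_checkpoint returns None if nothing exists at the path, so a
    # fresh run simply starts from scratch.
    task.set_checkpoint(load_checkpoint(args.checkpoint_load_path))

    # Pre-trained weights only apply to fine-tuning tasks, and they do not
    # restore optimizer or epoch state.
    pretrained_checkpoint = load_checkpoint(args.pretrained_checkpoint_path)
    if pretrained_checkpoint is not None:
        assert isinstance(
            task, FineTuningTask
        ), "A pretrained checkpoint requires a FineTuningTask"
        task.set_pretrained_checkpoint(pretrained_checkpoint)
    return task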
22 changes: 15 additions & 7 deletions classy_vision/generic/opts.py
@@ -34,19 +34,27 @@ def add_generic_args(parser):
         "--checkpoint_folder",
         default="",
         type=str,
-        help="""folder to use for checkpoints:
+        help="""folder to use for saving checkpoints:
                         epochal checkpoints are stored as model_<epoch>.torch,
                         latest epoch checkpoint is at checkpoint.torch""",
     )
     parser.add_argument(
-        "--pretrained_checkpoint_folder",
+        "--checkpoint_load_path",
         default="",
         type=str,
-        help="""folder to use for pre-trained checkpoints:
-                        epochal checkpoints are stored as model_<epoch>.torch,
-                        latest epoch checkpoint is at checkpoint.torch,
-                        checkpoint is used for fine-tuning task, and it will
-                        not resume training from the checkpoint""",
+        help="""path to load a checkpoint from, which can be a file or a directory:
+                        If the path is a directory, the checkpoint file is assumed to be
+                        checkpoint.torch""",
+    )
+    parser.add_argument(
+        "--pretrained_checkpoint_path",
+        default="",
+        type=str,
+        help="""path to load a pre-trained checkpoint from, which can be a file or a
+                        directory:
+                        If the path is a directory, the checkpoint file is assumed to be
+                        checkpoint.torch. This checkpoint is only used for fine-tuning
+                        tasks, and training will not resume from this checkpoint.""",
     )
     parser.add_argument(
         "--checkpoint_period",
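
The file-or-directory rule in both new help strings is simple enough to pin down in code. A sketch of that resolution logic with a hypothetical helper name (the actual implementation lives elsewhere in the library):

import os

def resolve_checkpoint_path(path: str) -> str:
    # Hypothetical helper mirroring the documented rule: a directory
    # resolves to checkpoint.torch inside it; a file path is used as-is.
    if os.path.isdir(path):
        return os.path.join(path, "checkpoint.torch")
    return path

So --checkpoint_load_path ./output/checkpoints and --checkpoint_load_path ./output/checkpoints/checkpoint.torch load the same file.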
10 changes: 5 additions & 5 deletions classy_vision/hooks/checkpoint_hook.py
@@ -40,16 +40,16 @@ def __init__(
             checkpoint_folder: Folder to store checkpoints in
             input_args: Any arguments to save about the runtime setup. For example,
                 it is useful to store the config that was used to instantiate the model.
-            phase_types: If ``phase_types`` is specified, only checkpoint on those phase
-                types. Each item in ``phase_types`` must be either "train" or "test".
+            phase_types: If `phase_types` is specified, only checkpoint on those phase
+                types. Each item in `phase_types` must be either "train" or "test". If
+                not specified, it is set to checkpoint after "train" phases.
             checkpoint_period: Checkpoint at the end of every x phases (default 1)
 
         """
         super().__init__()
         self.checkpoint_folder: str = checkpoint_folder
         self.input_args: Any = input_args
         if phase_types is None:
-            phase_types = ["train", "test"]
+            phase_types = ["train"]
         assert len(phase_types) > 0 and all(
             phase_type in ["train", "test"] for phase_type in phase_types
         ), "phase_types should contain one or more of ['train', 'test']"
@@ -81,7 +81,7 @@ def _save_checkpoint(self, task, filename):
     def on_start(
         self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
     ) -> None:
-        if getattr(task, "test_only", False):
+        if not is_master() or getattr(task, "test_only", False):
             return
         if not PathManager.exists(self.checkpoint_folder):
             err_msg = "Checkpoint folder '{}' does not exist.".format(
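
Two behavior changes land in this hook: the default phase_types shrinks to ["train"], and on_start() now returns early on non-master workers (so only the master validates the checkpoint folder). A construction sketch under the new defaults; the constructor arguments mirror the documented signature, but the surrounding setup is assumed:

from classy_vision.hooks import CheckpointHook

# New default: checkpoint only after train phases.
hook = CheckpointHook(checkpoint_folder="./checkpoints", input_args={})

# The old behavior (checkpoint after test phases too) must now be
# requested explicitly.
hook_old_behavior = CheckpointHook(
    checkpoint_folder="./checkpoints",
    input_args={},
    phase_types=["train", "test"],
    checkpoint_period=1,
)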
3 changes: 2 additions & 1 deletion classy_vision/templates/synthetic/hydra_configs/args.yaml
@@ -4,7 +4,8 @@ checkpoint_folder: ""
 checkpoint_period: 1
 log_freq: 5
 num_workers: 4
-pretrained_checkpoint_folder: ""
+checkpoint_load_path: ""
+pretrained_checkpoint_path: ""
 profiler: False
 skip_tensorboard: False
 show_progress: False
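
These template keys feed the same generic args, so they must stay in sync with opts.py one-for-one. A quick sanity check of the renamed keys, assuming OmegaConf (which Hydra uses) is available; nothing here comes from the diff itself:

from omegaconf import OmegaConf

args = OmegaConf.load("classy_vision/templates/synthetic/hydra_configs/args.yaml")
# Empty strings match the argparse defaults: nothing is loaded.
assert args.checkpoint_load_path == ""
assert args.pretrained_checkpoint_path == ""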
12 changes: 5 additions & 7 deletions tutorials/getting_started.ipynb
@@ -235,7 +235,7 @@
    "source": [
     "## 4. Loading checkpoints\n",
     "\n",
-    "Now that we've run `classy_train.py`, let's see how to load the resulting model. At the end of execution, `classy_train.py` will print the checkpoint directory used for that run. Each run will output to a different directory, typically named `output_<timestamp>/checkpoints`."
+    "Now that we've run `classy_train.py`, let's see how to load the resulting model. At the end of execution, `classy_train.py` will print the checkpoint directory used for that run. Each run will output to a different directory, typically named `output_<timestamp>/checkpoints`. This can be configured by passing the `--checkpoint_folder` argument to `classy_train.py`."
    ]
   },
   {
@@ -269,18 +269,16 @@
     "\n",
     "## 5. Resuming from checkpoints\n",
     "\n",
-    "Resuming from a checkpoint is as simple as training: `classy_train.py` takes a `--checkpoint_folder` argument, which specifies the checkpoint to resume from:"
+    "Resuming from a checkpoint is as simple as training: `classy_train.py` takes a `--checkpoint_load_path` argument, which specifies the checkpoint path to resume from:"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {
-    "collapsed": true
-   },
+   "metadata": {},
    "outputs": [],
    "source": [
-    "! ./classy_train.py --config configs/template_config.json --checkpoint_folder ./output_<timestamp>/checkpoints"
+    "! ./classy_train.py --config configs/template_config.json --checkpoint_load_path ./output_<timestamp>/checkpoints"
    ]
   },
   {
@@ -554,7 +552,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.7.5+"
+   "version": "3.7.3"
   }
  },
  "nbformat": 4,
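
For completeness, loading the trained model back in Python (the subject of the tutorial's section 4) might look like the following. ClassyModel.from_checkpoint follows the tutorial's usual pattern, but treat the exact calls as assumptions rather than this PR's content:

from classy_vision.generic.util import load_checkpoint
from classy_vision.models import ClassyModel

# The directory printed at the end of a classy_train.py run; a directory
# path resolves to the checkpoint.torch inside it.
checkpoint_data = load_checkpoint("./output_<timestamp>/checkpoints")
model = ClassyModel.from_checkpoint(checkpoint_data)
model.eval()  # ready for inference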