From fe43f3a65b31e5e811a7f400ec96fad05ee3a2eb Mon Sep 17 00:00:00 2001 From: Jonas Teuwen <2347927+jonasteuwen@users.noreply.github.com> Date: Mon, 13 Dec 2021 16:45:40 +0100 Subject: [PATCH] Revert "Design `direct train` CLI interface (#136)" (#138) This reverts commit 65089224d0f15b2da2bec62769b32909c7785a0b. --- direct/cli/__init__.py | 15 ++---- direct/cli/train.py | 67 ------------------------- direct/cli/utils.py | 6 --- direct/environment.py | 4 +- tools/README.md | 7 +++ direct/train.py => tools/train_model.py | 59 ++++++++++++++++++++-- 6 files changed, 69 insertions(+), 89 deletions(-) delete mode 100644 direct/cli/train.py rename direct/train.py => tools/train_model.py (79%) diff --git a/direct/cli/__init__.py b/direct/cli/__init__.py index 36518f18f..7354c1cbb 100644 --- a/direct/cli/__init__.py +++ b/direct/cli/__init__.py @@ -1,30 +1,25 @@ # coding=utf-8 # Copyright (c) DIRECT Contributors -"""DIRECT Command-line interface. This is the file which builds the main parser.""" +"""DIRECT Command-line interface. This is the file which builds the main parser. Currently just a placeholder""" import argparse import sys def main(): """ - Console script for direct. + Console script for dlup. """ # From https://stackoverflow.com/questions/17073688/how-to-use-argparse-subparsers-correctly root_parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) - root_subparsers = root_parser.add_subparsers(help="Direct CLI utils to run.") + root_subparsers = root_parser.add_subparsers(help="DIRECT utilities.") root_subparsers.required = True root_subparsers.dest = "subcommand" - # Prevent circular import - from direct.cli.train import register_parser as register_train_subcommand - - # Training images related commands. - register_train_subcommand(root_subparsers) - args = root_parser.parse_args() args.subcommand(args) + return 0 if __name__ == "__main__": - main() + sys.exit(main()) # pragma: no cover diff --git a/direct/cli/train.py b/direct/cli/train.py deleted file mode 100644 index 8e34966e0..000000000 --- a/direct/cli/train.py +++ /dev/null @@ -1,67 +0,0 @@ -# coding=utf-8 -# Copyright (c) DIRECT Contributors -import argparse -import pathlib - -from direct.cli.utils import file_or_url -from direct.environment import Args -from direct.train import train_from_argparse - - -def register_parser(parser: argparse._SubParsersAction): - """Register wsi commands to a root parser.""" - - epilog = f""" - Examples: - --------- - Run on single machine: - $ direct train training_set validation_set experiment_dir --num-gpus 8 --cfg cfg.yaml - Run on multiple machines: - (machine0)$ direct train training_set validation_set experiment_dir --machine-rank 0 --num-machines 2 --dist-url [--other-flags] - (machine1)$ direct train training_set validation_set experiment_dir --machine-rank 1 --num-machines 2 --dist-url [--other-flags] - """ - common_parser = Args(add_help=False) - train_parser = parser.add_parser( - "train", - help="Train models using direct.", - parents=[common_parser], - epilog=epilog, - formatter_class=argparse.RawDescriptionHelpFormatter, - ) - - train_parser.add_argument("training_root", type=pathlib.Path, help="Path to the training data.") - train_parser.add_argument("validation_root", type=pathlib.Path, help="Path to the validation data.") - train_parser.add_argument( - "experiment_dir", - type=pathlib.Path, - help="Path to the experiment directory.", - ) - train_parser.add_argument( - "--cfg", - dest="cfg_file", - help="Config file for training. Can be either a local file or a remote URL.", - required=True, - type=file_or_url, - ) - train_parser.add_argument( - "--initialization-checkpoint", - type=file_or_url, - help="If this value is set to a proper checkpoint when training starts, " - "the model will be initialized with the weights given. " - "No other keys in the checkpoint will be loaded. " - "When another checkpoint would be available and the --resume flag is used, " - "this flag is ignored. This can be a path to a file or an URL. " - "If a URL is given the checkpoint will first be downloaded to the environmental variable " - "`DIRECT_MODEL_DOWNLOAD_DIR` (default=current directory).", - ) - train_parser.add_argument("--resume", help="Resume training if possible.", action="store_true") - train_parser.add_argument( - "--force-validation", - help="Start with a validation round, when recovering from a crash. " - "If you use this option, be aware that when combined with --resume, " - "each new run will start with a validation round.", - action="store_true", - ) - train_parser.add_argument("--name", help="Run name.", required=False, type=str) - - train_parser.set_defaults(subcommand=train_from_argparse) diff --git a/direct/cli/utils.py b/direct/cli/utils.py index 07a187ef5..acadea905 100644 --- a/direct/cli/utils.py +++ b/direct/cli/utils.py @@ -2,7 +2,6 @@ # Copyright (c) DIRECT Contributors import argparse import pathlib -import sys from direct.utils.io import check_is_valid_url @@ -14,8 +13,3 @@ def file_or_url(path): if path.is_file(): return path raise argparse.ArgumentTypeError(f"{path} is not a valid file or url.") - - -def check_train_val(key, name): - if key is not None and len(key) != 2: - sys.exit(f"--{name} has to be of the form `train_folder, validation_folder` if a validation folder is set.") diff --git a/direct/environment.py b/direct/environment.py index 57ef62271..28b485686 100644 --- a/direct/environment.py +++ b/direct/environment.py @@ -379,12 +379,12 @@ class Args(argparse.ArgumentParser): Defines global default arguments. """ - def __init__(self, epilog=None, add_help=True, **overrides): + def __init__(self, epilog=None, **overrides): """ Args: **overrides (dict, optional): Keyword arguments used to override default argument values """ - super().__init__(epilog=epilog, formatter_class=argparse.RawDescriptionHelpFormatter, add_help=add_help) + super().__init__(epilog=epilog, formatter_class=argparse.RawDescriptionHelpFormatter) self.add_argument( "--device", diff --git a/tools/README.md b/tools/README.md index 6fffa56cb..312efd456 100644 --- a/tools/README.md +++ b/tools/README.md @@ -1,4 +1,11 @@ # Direct tools Scripts are provided: +- To train a model use `train_model.py`. - To extract the best checkpoint based on `metrics.json`, use `parse_metrics_log.py`. + + +## Tips and tricks + +- We are using a lot of experimental features in pytorch, to reduce such warnings you can use + `export PYTHONWARNINGS="ignore"` in the shell before execution. diff --git a/direct/train.py b/tools/train_model.py similarity index 79% rename from direct/train.py rename to tools/train_model.py index 9b3d2f5be..22a84f149 100644 --- a/direct/train.py +++ b/tools/train_model.py @@ -1,6 +1,5 @@ # coding=utf-8 # Copyright (c) DIRECT Contributors -import argparse import functools import logging import os @@ -13,12 +12,12 @@ import numpy as np import torch -from direct.cli.utils import check_train_val +from direct.cli.utils import file_or_url from direct.common.subsample import build_masking_function from direct.data.datasets import build_dataset_from_input from direct.data.lr_scheduler import WarmupMultiStepLR from direct.data.mri_transforms import build_mri_transforms -from direct.environment import setup_training_environment +from direct.environment import Args, setup_training_environment from direct.launch import launch from direct.utils import remove_keys, set_all_seeds, str_to_class from direct.utils.dataset import get_filenames_for_datasets @@ -259,7 +258,12 @@ def setup_train( ) -def train_from_argparse(args: argparse.Namespace): +def check_train_val(key, name): + if key is not None and len(key) != 2: + sys.exit(f"--{name} has to be of the form `train_folder, validation_folder` if a validation folder is set.") + + +if __name__ == "__main__": # This sets MKL threads to 1. # DataLoader can otherwise bring a l ot of difficulties when computing CPU FFTs in the transforms. torch.set_num_threads(1) @@ -268,6 +272,53 @@ def train_from_argparse(args: argparse.Namespace): # Remove warnings from named tensors being experimental os.environ["PYTHONWARNINGS"] = "ignore" + epilog = f""" + Examples: + Run on single machine: + $ {sys.argv[0]} training_set validation_set experiment_dir --num-gpus 8 --cfg cfg.yaml + Run on multiple machines: + (machine0)$ {sys.argv[0]} training_set validation_set experiment_dir --machine-rank 0 --num-machines 2 --dist-url [--other-flags] + (machine1)$ {sys.argv[0]} training_set validation_set experiment_dir --machine-rank 1 --num-machines 2 --dist-url [--other-flags] + """ + + parser = Args(epilog=epilog) + parser.add_argument("training_root", type=pathlib.Path, help="Path to the training data.") + parser.add_argument("validation_root", type=pathlib.Path, help="Path to the validation data.") + parser.add_argument( + "experiment_dir", + type=pathlib.Path, + help="Path to the experiment directory.", + ) + parser.add_argument( + "--cfg", + dest="cfg_file", + help="Config file for training. Can be either a local file or a remote URL.", + required=True, + type=file_or_url, + ) + parser.add_argument( + "--initialization-checkpoint", + type=file_or_url, + help="If this value is set to a proper checkpoint when training starts, " + "the model will be initialized with the weights given. " + "No other keys in the checkpoint will be loaded. " + "When another checkpoint would be available and the --resume flag is used, " + "this flag is ignored. This can be a path to a file or an URL. " + "If a URL is given the checkpoint will first be downloaded to the environmental variable " + "`DIRECT_MODEL_DOWNLOAD_DIR` (default=current directory).", + ) + parser.add_argument("--resume", help="Resume training if possible.", action="store_true") + parser.add_argument( + "--force-validation", + help="Start with a validation round, when recovering from a crash. " + "If you use this option, be aware that when combined with --resume, " + "each new run will start with a validation round.", + action="store_true", + ) + parser.add_argument("--name", help="Run name.", required=False, type=str) + + args = parser.parse_args() + if args.initialization_images is not None and args.initialization_kspace is not None: sys.exit("--initialization-images and --initialization-kspace are mutually exclusive.") check_train_val(args.initialization_images, "initialization-images")