Skip to content

Commit

Permalink
Revert "Design direct train CLI interface (#136)" (#138)
Browse files Browse the repository at this point in the history
This reverts commit 6508922.
  • Loading branch information
jonasteuwen authored Dec 13, 2021
1 parent 6508922 commit fe43f3a
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 89 deletions.
15 changes: 5 additions & 10 deletions direct/cli/__init__.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,25 @@
# coding=utf-8
# Copyright (c) DIRECT Contributors
"""DIRECT Command-line interface. This is the file which builds the main parser."""
"""DIRECT Command-line interface. This is the file which builds the main parser. Currently just a placeholder"""
import argparse
import sys


def main():
"""
Console script for direct.
Console script for dlup.
"""
# From https://stackoverflow.com/questions/17073688/how-to-use-argparse-subparsers-correctly
root_parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

root_subparsers = root_parser.add_subparsers(help="Direct CLI utils to run.")
root_subparsers = root_parser.add_subparsers(help="DIRECT utilities.")
root_subparsers.required = True
root_subparsers.dest = "subcommand"

# Prevent circular import
from direct.cli.train import register_parser as register_train_subcommand

# Training images related commands.
register_train_subcommand(root_subparsers)

args = root_parser.parse_args()
args.subcommand(args)
return 0


if __name__ == "__main__":
main()
sys.exit(main()) # pragma: no cover
67 changes: 0 additions & 67 deletions direct/cli/train.py

This file was deleted.

6 changes: 0 additions & 6 deletions direct/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# Copyright (c) DIRECT Contributors
import argparse
import pathlib
import sys

from direct.utils.io import check_is_valid_url

Expand All @@ -14,8 +13,3 @@ def file_or_url(path):
if path.is_file():
return path
raise argparse.ArgumentTypeError(f"{path} is not a valid file or url.")


# Check that an option value, when present, holds exactly two entries.
#
# Args:
#     key: The option value to check. None is accepted and skips the check;
#         any other value must support len().
#     name: Option name without the leading dashes; only used to build the
#         error message.
#
# Calls sys.exit() with an explanatory message when `key` is not None and
# len(key) != 2. Otherwise falls through and returns None.
def check_train_val(key, name):
if key is not None and len(key) != 2:
sys.exit(f"--{name} has to be of the form `train_folder, validation_folder` if a validation folder is set.")
4 changes: 2 additions & 2 deletions direct/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,12 +379,12 @@ class Args(argparse.ArgumentParser):
Defines global default arguments.
"""

def __init__(self, epilog=None, add_help=True, **overrides):
def __init__(self, epilog=None, **overrides):
"""
Args:
**overrides (dict, optional): Keyword arguments used to override default argument values
"""
super().__init__(epilog=epilog, formatter_class=argparse.RawDescriptionHelpFormatter, add_help=add_help)
super().__init__(epilog=epilog, formatter_class=argparse.RawDescriptionHelpFormatter)

self.add_argument(
"--device",
Expand Down
7 changes: 7 additions & 0 deletions tools/README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
# Direct tools

Scripts are provided:
- To train a model use `train_model.py`.
- To extract the best checkpoint based on `metrics.json`, use `parse_metrics_log.py`.


## Tips and tricks

- We are using a lot of experimental features in PyTorch. To reduce the resulting warnings, you can run
`export PYTHONWARNINGS="ignore"` in the shell before execution.
59 changes: 55 additions & 4 deletions direct/train.py → tools/train_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# coding=utf-8
# Copyright (c) DIRECT Contributors
import argparse
import functools
import logging
import os
Expand All @@ -13,12 +12,12 @@
import numpy as np
import torch

from direct.cli.utils import check_train_val
from direct.cli.utils import file_or_url
from direct.common.subsample import build_masking_function
from direct.data.datasets import build_dataset_from_input
from direct.data.lr_scheduler import WarmupMultiStepLR
from direct.data.mri_transforms import build_mri_transforms
from direct.environment import setup_training_environment
from direct.environment import Args, setup_training_environment
from direct.launch import launch
from direct.utils import remove_keys, set_all_seeds, str_to_class
from direct.utils.dataset import get_filenames_for_datasets
Expand Down Expand Up @@ -259,7 +258,12 @@ def setup_train(
)


def train_from_argparse(args: argparse.Namespace):
# Check that an option value, when present, holds exactly two entries.
#
# Args:
#     key: The option value to check. None is accepted and skips the check;
#         any other value must support len().
#     name: Option name without the leading dashes; only used to build the
#         error message.
#
# Calls sys.exit() with an explanatory message when `key` is not None and
# len(key) != 2. Otherwise falls through and returns None.
def check_train_val(key, name):
if key is not None and len(key) != 2:
sys.exit(f"--{name} has to be of the form `train_folder, validation_folder` if a validation folder is set.")


if __name__ == "__main__":
# This sets MKL threads to 1.
# DataLoader can otherwise bring a lot of difficulties when computing CPU FFTs in the transforms.
torch.set_num_threads(1)
Expand All @@ -268,6 +272,53 @@ def train_from_argparse(args: argparse.Namespace):
# Remove warnings from named tensors being experimental
os.environ["PYTHONWARNINGS"] = "ignore"

epilog = f"""
Examples:
Run on single machine:
$ {sys.argv[0]} training_set validation_set experiment_dir --num-gpus 8 --cfg cfg.yaml
Run on multiple machines:
(machine0)$ {sys.argv[0]} training_set validation_set experiment_dir --machine-rank 0 --num-machines 2 --dist-url <URL> [--other-flags]
(machine1)$ {sys.argv[0]} training_set validation_set experiment_dir --machine-rank 1 --num-machines 2 --dist-url <URL> [--other-flags]
"""

parser = Args(epilog=epilog)
parser.add_argument("training_root", type=pathlib.Path, help="Path to the training data.")
parser.add_argument("validation_root", type=pathlib.Path, help="Path to the validation data.")
parser.add_argument(
"experiment_dir",
type=pathlib.Path,
help="Path to the experiment directory.",
)
parser.add_argument(
"--cfg",
dest="cfg_file",
help="Config file for training. Can be either a local file or a remote URL.",
required=True,
type=file_or_url,
)
parser.add_argument(
"--initialization-checkpoint",
type=file_or_url,
help="If this value is set to a proper checkpoint when training starts, "
"the model will be initialized with the weights given. "
"No other keys in the checkpoint will be loaded. "
"When another checkpoint would be available and the --resume flag is used, "
"this flag is ignored. This can be a path to a file or an URL. "
"If a URL is given the checkpoint will first be downloaded to the environmental variable "
"`DIRECT_MODEL_DOWNLOAD_DIR` (default=current directory).",
)
parser.add_argument("--resume", help="Resume training if possible.", action="store_true")
parser.add_argument(
"--force-validation",
help="Start with a validation round, when recovering from a crash. "
"If you use this option, be aware that when combined with --resume, "
"each new run will start with a validation round.",
action="store_true",
)
parser.add_argument("--name", help="Run name.", required=False, type=str)

args = parser.parse_args()

if args.initialization_images is not None and args.initialization_kspace is not None:
sys.exit("--initialization-images and --initialization-kspace are mutually exclusive.")
check_train_val(args.initialization_images, "initialization-images")
Expand Down

0 comments on commit fe43f3a

Please sign in to comment.