-
Notifications
You must be signed in to change notification settings - Fork 43
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
direct predict
option for inference and minor fixes (#151)
* Replace predict script with `direct predict` * Check if sewar is installed when using calgary_campinas_vif metric
- Loading branch information
1 parent
7740308
commit 7585ab9
Showing
11 changed files
with
188 additions
and
267 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
# coding=utf-8 | ||
# Copyright (c) DIRECT Contributors | ||
|
||
import argparse | ||
import pathlib | ||
|
||
from direct.cli.utils import file_or_url | ||
from direct.environment import Args | ||
from direct.predict import predict_from_argparse | ||
|
||
|
||
def register_parser(parser: argparse._SubParsersAction): | ||
"""Register wsi commands to a root parser.""" | ||
|
||
epilog = f""" | ||
Examples: | ||
--------- | ||
Run on single machine: | ||
$ direct predict <data_root> <output_directory> <experiment_directory> --checkpoint <checkpoint> \ | ||
--num-gpus <num_gpus> [ --cfg <cfg_filename>.yaml --other-flags <other_flags>] | ||
Run on multiple machines: | ||
(machine0)$ direct predict <data_root> <output_directory> <experiment_dir> --checkpoint <checkpoint> \ | ||
--cfg <cfg_filename>.yaml --machine-rank 0 --num-machines 2 [--other-flags] | ||
(machine1)$ direct predict <data_root> <output_directory> <experiment_dir> --checkpoint <checkpoint> \ | ||
--cfg <cfg_filename>.yaml --machine-rank 1 --num-machines 2 [--other-flags] | ||
""" | ||
common_parser = Args(add_help=False) | ||
predict_parser = parser.add_parser( | ||
"predict", | ||
help="Run inference using direct.", | ||
parents=[common_parser], | ||
epilog=epilog, | ||
formatter_class=argparse.RawDescriptionHelpFormatter, | ||
) | ||
predict_parser.add_argument("data_root", type=pathlib.Path, help="Path to the inference data directory.") | ||
predict_parser.add_argument("output_directory", type=pathlib.Path, help="Path to the output directory.") | ||
predict_parser.add_argument( | ||
"experiment_directory", | ||
type=pathlib.Path, | ||
help="Path to the directory with checkpoints and config.", | ||
) | ||
predict_parser.add_argument( | ||
"--checkpoint", | ||
type=int, | ||
required=True, | ||
help="Number of an existing checkpoint in experiment directory.", | ||
) | ||
predict_parser.add_argument( | ||
"--filenames-filter", | ||
type=pathlib.Path, | ||
help="Path to list of filenames to parse.", | ||
) | ||
predict_parser.add_argument( | ||
"--name", | ||
dest="name", | ||
help="Run name if this is different than the experiment directory.", | ||
required=False, | ||
type=str, | ||
default="", | ||
) | ||
predict_parser.add_argument( | ||
"--cfg", | ||
dest="cfg_file", | ||
help="Config file for inference. Can be either a local file or a remote URL." | ||
"Only use it to overwrite the standard loading of the config in the project directory.", | ||
required=False, | ||
type=file_or_url, | ||
) | ||
|
||
predict_parser.set_defaults(subcommand=predict_from_argparse) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
# coding=utf-8 | ||
# Copyright (c) DIRECT Contributors | ||
import argparse | ||
import functools | ||
import logging | ||
import os | ||
|
||
import torch | ||
|
||
from direct.common.subsample import build_masking_function | ||
from direct.environment import Args | ||
from direct.inference import build_inference_transforms, setup_inference_save_to_h5 | ||
from direct.launch import launch | ||
from direct.utils import set_all_seeds | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def _get_transforms(env): | ||
dataset_cfg = env.cfg.inference.dataset | ||
mask_func = build_masking_function(**dataset_cfg.transforms.masking) | ||
transforms = build_inference_transforms(env, mask_func, dataset_cfg) | ||
return dataset_cfg, transforms | ||
|
||
|
||
setup_inference_save_to_h5 = functools.partial( | ||
setup_inference_save_to_h5, | ||
functools.partial(_get_transforms), | ||
) | ||
|
||
|
||
def predict_from_argparse(args: argparse.Namespace): | ||
# This sets MKL threads to 1. | ||
# DataLoader can otherwise bring a lot of difficulties when computing CPU FFTs in the transforms. | ||
torch.set_num_threads(1) | ||
os.environ["OMP_NUM_THREADS"] = "1" | ||
|
||
set_all_seeds(args.seed) | ||
|
||
launch( | ||
setup_inference_save_to_h5, | ||
args.num_machines, | ||
args.num_gpus, | ||
args.machine_rank, | ||
args.dist_url, | ||
args.name, | ||
args.data_root, | ||
args.experiment_directory, | ||
args.output_directory, | ||
args.filenames_filter, | ||
args.checkpoint, | ||
args.device, | ||
args.num_workers, | ||
args.machine_rank, | ||
args.cfg_file, | ||
args.mixed_precision, | ||
args.debug, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
# coding=utf-8 | ||
# Copyright (c) DIRECT Contributors | ||
"""General utilities for module imports""" | ||
|
||
from importlib.util import find_spec | ||
|
||
|
||
def _module_available(module_path: str) -> bool: | ||
""" | ||
Check if a path is available in your environment | ||
>>> _module_available('os') | ||
True | ||
>>> _module_available('bla.bla') | ||
False | ||
Adapted from: https://github.com/PyTorchLightning/pytorch-lightning/blob/ef7d41692ca04bb9877da5c743f80fceecc6a100/pytorch_lightning/utilities/imports.py#L27 | ||
Under Apache 2.0 license. | ||
""" | ||
try: | ||
return find_spec(module_path) is not None | ||
except ModuleNotFoundError: | ||
return False |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.