From 7585ab9a6df23494f1dd4cc674474df1ec8c0976 Mon Sep 17 00:00:00 2001 From: George Yiasemis Date: Sat, 18 Dec 2021 12:33:09 +0100 Subject: [PATCH] `direct predict` option for inference and minor fixes (#151) * Replace predict script with `direct predict` * Check if sewar is installed when using calgary_campinas_vif metric --- direct/cli/__init__.py | 6 +- direct/cli/predict.py | 71 ++++++++++ direct/functionals/challenges.py | 15 ++- direct/inference.py | 3 - direct/predict.py | 58 +++++++++ direct/utils/imports.py | 21 +++ projects/calgary_campinas/predict_test.py | 4 +- projects/calgary_campinas/predict_val.py | 116 ----------------- .../predict_val.py | 19 +-- .../spie_radial_subsampling/predict_test.py | 123 ------------------ tests/tests_utils/test_imports.py | 19 +++ 11 files changed, 188 insertions(+), 267 deletions(-) create mode 100644 direct/cli/predict.py create mode 100644 direct/predict.py create mode 100644 direct/utils/imports.py delete mode 100644 projects/calgary_campinas/predict_val.py rename projects/{spie_radial_subsampling => }/predict_val.py (83%) delete mode 100644 projects/spie_radial_subsampling/predict_test.py create mode 100644 tests/tests_utils/test_imports.py diff --git a/direct/cli/__init__.py b/direct/cli/__init__.py index 36518f18..d83a81fd 100644 --- a/direct/cli/__init__.py +++ b/direct/cli/__init__.py @@ -1,6 +1,7 @@ # coding=utf-8 # Copyright (c) DIRECT Contributors """DIRECT Command-line interface. This is the file which builds the main parser.""" + import argparse import sys @@ -16,11 +17,14 @@ def main(): root_subparsers.required = True root_subparsers.dest = "subcommand" - # Prevent circular import + # Prevent circular imports from direct.cli.train import register_parser as register_train_subcommand + from direct.cli.predict import register_parser as register_predict_subcommand # Training images related commands. register_train_subcommand(root_subparsers) + # Inference images related commands. + register_predict_subcommand(root_subparsers) args = root_parser.parse_args() args.subcommand(args) diff --git a/direct/cli/predict.py b/direct/cli/predict.py new file mode 100644 index 00000000..310dc621 --- /dev/null +++ b/direct/cli/predict.py @@ -0,0 +1,71 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors + +import argparse +import pathlib + +from direct.cli.utils import file_or_url +from direct.environment import Args +from direct.predict import predict_from_argparse + + +def register_parser(parser: argparse._SubParsersAction): + """Register wsi commands to a root parser.""" + + epilog = f""" + Examples: + --------- + Run on single machine: + $ direct predict --checkpoint \ + --num-gpus [ --cfg .yaml --other-flags ] + + Run on multiple machines: + (machine0)$ direct predict --checkpoint \ + --cfg .yaml --machine-rank 0 --num-machines 2 [--other-flags] + (machine1)$ direct predict --checkpoint \ + --cfg .yaml --machine-rank 1 --num-machines 2 [--other-flags] + """ + common_parser = Args(add_help=False) + predict_parser = parser.add_parser( + "predict", + help="Run inference using direct.", + parents=[common_parser], + epilog=epilog, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + predict_parser.add_argument("data_root", type=pathlib.Path, help="Path to the inference data directory.") + predict_parser.add_argument("output_directory", type=pathlib.Path, help="Path to the output directory.") + predict_parser.add_argument( + "experiment_directory", + type=pathlib.Path, + help="Path to the directory with checkpoints and config.", + ) + predict_parser.add_argument( + "--checkpoint", + type=int, + required=True, + help="Number of an existing checkpoint in experiment directory.", + ) + predict_parser.add_argument( + "--filenames-filter", + type=pathlib.Path, + help="Path to list of filenames to parse.", + ) + predict_parser.add_argument( + "--name", + dest="name", + help="Run name if this is different than the experiment directory.", + required=False, + type=str, + default="", + ) + predict_parser.add_argument( + "--cfg", + dest="cfg_file", + help="Config file for inference. Can be either a local file or a remote URL." + "Only use it to overwrite the standard loading of the config in the project directory.", + required=False, + type=file_or_url, + ) + + predict_parser.set_defaults(subcommand=predict_from_argparse) diff --git a/direct/functionals/challenges.py b/direct/functionals/challenges.py index df72d8ef..f1f953bc 100644 --- a/direct/functionals/challenges.py +++ b/direct/functionals/challenges.py @@ -83,8 +83,17 @@ def calgary_campinas_psnr(gt, pred): def calgary_campinas_vif(gt, pred): def vif_func(gt, target, data_range): # noqa - from sewar.full_ref import vifp - - return vifp(gt, target, sigma_nsq=0.4) + from direct.utils.imports import _module_available + + # Calgary Campinas VIF metric requires 'sewar' module. Check if it exists + if not _module_available("sewar"): + raise RuntimeError( + "'sewar' module required for calgary_campinas_vif metric, but not found. " + "Please use 'pip3 install sewar' and run again." + ) + else: + from sewar.full_ref import vifp + + return vifp(gt, target, sigma_nsq=0.4) return _calgary_campinas_metric(gt, pred, vif_func) diff --git a/direct/inference.py b/direct/inference.py index a0cab418..c6bc1d20 100644 --- a/direct/inference.py +++ b/direct/inference.py @@ -30,7 +30,6 @@ def setup_inference_save_to_h5( machine_rank: int, cfg_file=None, process_per_chunk: Optional[int] = None, - volume_processing_func: Callable = None, mixed_precision: bool = False, debug: bool = False, ): @@ -51,7 +50,6 @@ def setup_inference_save_to_h5( machine_rank : cfg_file : process_per_chunk : - volume_processing_func : mixed_precision : debug : @@ -94,7 +92,6 @@ def setup_inference_save_to_h5( write_output_to_h5( output, output_directory, - volume_processing_func, output_key="reconstruction", ) diff --git a/direct/predict.py b/direct/predict.py new file mode 100644 index 00000000..c0f632ef --- /dev/null +++ b/direct/predict.py @@ -0,0 +1,58 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors +import argparse +import functools +import logging +import os + +import torch + +from direct.common.subsample import build_masking_function +from direct.environment import Args +from direct.inference import build_inference_transforms, setup_inference_save_to_h5 +from direct.launch import launch +from direct.utils import set_all_seeds + +logger = logging.getLogger(__name__) + + +def _get_transforms(env): + dataset_cfg = env.cfg.inference.dataset + mask_func = build_masking_function(**dataset_cfg.transforms.masking) + transforms = build_inference_transforms(env, mask_func, dataset_cfg) + return dataset_cfg, transforms + + +setup_inference_save_to_h5 = functools.partial( + setup_inference_save_to_h5, + functools.partial(_get_transforms), +) + + +def predict_from_argparse(args: argparse.Namespace): + # This sets MKL threads to 1. + # DataLoader can otherwise bring a lot of difficulties when computing CPU FFTs in the transforms. + torch.set_num_threads(1) + os.environ["OMP_NUM_THREADS"] = "1" + + set_all_seeds(args.seed) + + launch( + setup_inference_save_to_h5, + args.num_machines, + args.num_gpus, + args.machine_rank, + args.dist_url, + args.name, + args.data_root, + args.experiment_directory, + args.output_directory, + args.filenames_filter, + args.checkpoint, + args.device, + args.num_workers, + args.machine_rank, + args.cfg_file, + args.mixed_precision, + args.debug, + ) diff --git a/direct/utils/imports.py b/direct/utils/imports.py new file mode 100644 index 00000000..7a107262 --- /dev/null +++ b/direct/utils/imports.py @@ -0,0 +1,21 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors +"""General utilities for module imports""" + +from importlib.util import find_spec + + +def _module_available(module_path: str) -> bool: + """ + Check if a path is available in your environment + >>> _module_available('os') + True + >>> _module_available('bla.bla') + False + Adapted from: https://github.com/PyTorchLightning/pytorch-lightning/blob/ef7d41692ca04bb9877da5c743f80fceecc6a100/pytorch_lightning/utilities/imports.py#L27 + Under Apache 2.0 license. + """ + try: + return find_spec(module_path) is not None + except ModuleNotFoundError: + return False diff --git a/projects/calgary_campinas/predict_test.py b/projects/calgary_campinas/predict_test.py index b8cb392a..57f42d5c 100644 --- a/projects/calgary_campinas/predict_test.py +++ b/projects/calgary_campinas/predict_test.py @@ -1,5 +1,6 @@ # coding=utf-8 # Copyright (c) DIRECT Contributors + import functools import logging import os @@ -16,8 +17,6 @@ from direct.inference import build_inference_transforms, setup_inference_save_to_h5 from direct.utils import set_all_seeds -from .utils import volume_post_processing_func - logger = logging.getLogger(__name__) @@ -150,7 +149,6 @@ def _get_transforms(masks_dict, env): args.device, args.num_workers, args.machine_rank, - volume_post_processing_func, args.mixed_precision, args.debug, ) diff --git a/projects/calgary_campinas/predict_val.py b/projects/calgary_campinas/predict_val.py deleted file mode 100644 index 0acb1d70..00000000 --- a/projects/calgary_campinas/predict_val.py +++ /dev/null @@ -1,116 +0,0 @@ -# coding=utf-8 -# Copyright (c) DIRECT Contributors -import functools -import logging -import os -import pathlib -import sys - -import torch - -import direct.launch -from direct.common.subsample import build_masking_function -from direct.environment import Args -from direct.inference import build_inference_transforms, setup_inference_save_to_h5 -from direct.utils import set_all_seeds - -from utils import volume_post_processing_func as calgary_campinas_post_processing_func - -logger = logging.getLogger(__name__) - - -def _get_transforms(validation_index, env): - dataset_cfg = env.cfg.validation.datasets[validation_index] - mask_func = build_masking_function(**dataset_cfg.transforms.masking) - transforms = build_inference_transforms(env, mask_func, dataset_cfg) - return dataset_cfg, transforms - - -if __name__ == "__main__": - # This sets MKL threads to 1. - # DataLoader can otherwise bring a lot of difficulties when computing CPU FFTs in the transforms. - torch.set_num_threads(1) - os.environ["OMP_NUM_THREADS"] = "1" - - epilog = f""" - Examples: - Run on single machine: - $ {sys.argv[0]} data_root output_directory --checkpoint --name [--other-flags] - Run on multiple machines: - (machine0)$ {sys.argv[0]} data_root output_directory --checkpoint --name --machine-rank 0 --num-machines 2 --dist-url [--other-flags] - (machine1)$ {sys.argv[0]} data_root output_directory --checkpoint --name --machine-rank 1 --num-machines 2 --dist-url [--other-flags] - """ - - parser = Args(epilog=epilog) - parser.add_argument("data_root", type=pathlib.Path, help="Path to the DoIterationOutput directory.") - parser.add_argument("output_directory", type=pathlib.Path, help="Path to the DoIterationOutput directory.") - parser.add_argument( - "experiment_directory", - type=pathlib.Path, - help="Path to the directory with checkpoints and config.", - ) - parser.add_argument( - "--checkpoint", - type=int, - required=True, - help="Number of an existing checkpoint.", - ) - parser.add_argument( - "--validation-index", - type=int, - required=True, - help="This is the index of the validation set in the config, e.g., 0 will select the first validation set.", - ) - parser.add_argument( - "--filenames-filter", - type=pathlib.Path, - help="Path to list of filenames to parse.", - ) - parser.add_argument("--name", help="Run name.", required=True, type=str) - parser.add_argument( - "--cfg", - dest="cfg_file", - help="Config file for inference. " - "Only use it to overwrite the standard loading of the config in the project directory.", - required=False, - type=pathlib.Path, - ) - parser.add_argument( - "--use-orthogonal-normalization", - dest="use_orthogonal_normalization", - help="If set, an orthogonal normalization (e.g. ortho in numpy.fft) will be used. " - "The Calgary-Campinas challenge does not use this, therefore the volumes will be" - " normalized to their expected outputs.", - default="store_true", - ) - - args = parser.parse_args() - set_all_seeds(args.seed) - - setup_inference_save_to_h5 = functools.partial( - setup_inference_save_to_h5, - functools.partial(_get_transforms, args.validation_index), - ) - volume_post_processing_func = None - if not args.use_orthogonal_normalization: - volume_post_processing_func = calgary_campinas_post_processing_func - - direct.launch.launch( - setup_inference_save_to_h5, - args.num_machines, - args.num_gpus, - args.machine_rank, - args.dist_url, - args.name, - args.data_root, - args.experiment_directory, - args.output_directory, - args.filenames_filter, - args.checkpoint, - args.device, - args.num_workers, - args.machine_rank, - volume_post_processing_func, - args.mixed_precision, - args.debug, - ) diff --git a/projects/spie_radial_subsampling/predict_val.py b/projects/predict_val.py similarity index 83% rename from projects/spie_radial_subsampling/predict_val.py rename to projects/predict_val.py index 209c63ce..d4399aaa 100644 --- a/projects/spie_radial_subsampling/predict_val.py +++ b/projects/predict_val.py @@ -1,12 +1,12 @@ # coding=utf-8 # Copyright (c) DIRECT Contributors + import functools import logging import os import pathlib import sys -import numpy as np import torch import direct.launch @@ -18,11 +18,6 @@ logger = logging.getLogger(__name__) -def _calgary_volume_post_processing_func(volume): - volume = volume / np.sqrt(np.prod(volume.shape[1:])) - return volume - - def _get_transforms(validation_index, env): dataset_cfg = env.cfg.validation.datasets[validation_index] mask_func = build_masking_function(**dataset_cfg.transforms.masking) @@ -86,14 +81,6 @@ def _get_transforms(validation_index, env): required=False, type=pathlib.Path, ) - parser.add_argument( - "--use-orthogonal-normalization", - dest="use_orthogonal_normalization", - help="If set, an orthogonal normalization (e.g. ortho in numpy.fft) will be used. " - "The Calgary-Campinas challenge does not use this, therefore the volumes will be" - " normalized to their expected outputs.", - default="store_true", - ) args = parser.parse_args() set_all_seeds(args.seed) @@ -102,9 +89,6 @@ def _get_transforms(validation_index, env): setup_inference_save_to_h5, functools.partial(_get_transforms, args.validation_index), ) - volume_post_processing_func = None - if not args.use_orthogonal_normalization: - volume_post_processing_func = _calgary_volume_post_processing_func direct.launch.launch( setup_inference_save_to_h5, @@ -122,7 +106,6 @@ def _get_transforms(validation_index, env): args.num_workers, args.machine_rank, args.cfg_file, - volume_post_processing_func, args.mixed_precision, args.debug, ) diff --git a/projects/spie_radial_subsampling/predict_test.py b/projects/spie_radial_subsampling/predict_test.py deleted file mode 100644 index 1a6a9d6f..00000000 --- a/projects/spie_radial_subsampling/predict_test.py +++ /dev/null @@ -1,123 +0,0 @@ -# coding=utf-8 -# Copyright (c) DIRECT Contributors -import functools -import logging -import os -import pathlib -import sys - -import numpy as np -import torch - -import direct.launch -from direct.common.subsample import build_masking_function -from direct.environment import Args -from direct.inference import build_inference_transforms, setup_inference_save_to_h5 -from direct.utils import set_all_seeds - -logger = logging.getLogger(__name__) - - -def _calgary_volume_post_processing_func(volume): - volume = volume / np.sqrt(np.prod(volume.shape[1:])) - return volume - - -def _get_transforms(env): - dataset_cfg = env.cfg.inference.dataset - mask_func = build_masking_function(**dataset_cfg.transforms.masking) - transforms = build_inference_transforms(env, mask_func, dataset_cfg) - return dataset_cfg, transforms - - -if __name__ == "__main__": - # This sets MKL threads to 1. - # DataLoader can otherwise bring a lot of difficulties when computing CPU FFTs in the transforms. - torch.set_num_threads(1) - os.environ["OMP_NUM_THREADS"] = "1" - - epilog = f""" - Examples: - Run on single machine: - $ {sys.argv[0]} data_root output_directory --checkpoint --name [--other-flags] - Run on multiple machines: - (machine0)$ {sys.argv[0]} data_root output_directory --checkpoint --name --machine-rank 0 --num-machines 2 --dist-url [--other-flags] - (machine1)$ {sys.argv[0]} data_root output_directory --checkpoint --name --machine-rank 1 --num-machines 2 --dist-url [--other-flags] - """ - - parser = Args(epilog=epilog) - parser.add_argument("data_root", type=pathlib.Path, help="Path to the data directory.") - parser.add_argument("output_directory", type=pathlib.Path, help="Path to the DoIterationOutput directory.") - parser.add_argument( - "experiment_directory", - type=pathlib.Path, - help="Path to the directory with checkpoints and config.", - ) - parser.add_argument( - "--checkpoint", - type=int, - required=True, - help="Number of an existing checkpoint.", - ) - parser.add_argument( - "--filenames-filter", - type=pathlib.Path, - help="Path to list of filenames to parse.", - ) - parser.add_argument( - "--name", - dest="name", - help="Run name if this is different experiment directory.", - required=False, - type=str, - default="", - ) - parser.add_argument( - "--cfg", - dest="cfg_file", - help="Config file for inference. " - "Only use it to overwrite the standard loading of the config in the project directory.", - required=False, - type=pathlib.Path, - ) - parser.add_argument( - "--use-orthogonal-normalization", - dest="use_orthogonal_normalization", - help="If set, an orthogonal normalization (e.g. ortho in numpy.fft) will be used. " - "The Calgary-Campinas challenge does not use this, therefore the volumes will be" - " normalized to their expected outputs.", - default="store_true", - ) - - args = parser.parse_args() - set_all_seeds(args.seed) - - setup_inference_save_to_h5 = functools.partial( - setup_inference_save_to_h5, - functools.partial(_get_transforms), - ) - - volume_post_processing_func = None - if not args.use_orthogonal_normalization: - volume_post_processing_func = _calgary_volume_post_processing_func - - direct.launch.launch( - setup_inference_save_to_h5, - args.num_machines, - args.num_gpus, - args.machine_rank, - args.dist_url, - args.name, - args.data_root, - args.experiment_directory, - args.output_directory, - args.filenames_filter, - args.checkpoint, - args.device, - args.num_workers, - args.machine_rank, - args.cfg_file, - volume_post_processing_func, - args.mixed_precision, - args.debug, - ) diff --git a/tests/tests_utils/test_imports.py b/tests/tests_utils/test_imports.py new file mode 100644 index 00000000..96c66d34 --- /dev/null +++ b/tests/tests_utils/test_imports.py @@ -0,0 +1,19 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors + +import pytest + +from direct.utils.imports import _module_available + + +@pytest.mark.parametrize( + ["module", "is_available"], + [ + ("torch", True), + ("numpy", True), + ("non-existent", False), + ], +) +def test_module_available(module, is_available): + + assert _module_available(module) == is_available