Skip to content

Commit

Permalink
direct predict option for inference and minor fixes (#151)
Browse files Browse the repository at this point in the history
* Replace predict script with `direct predict`
* Check if sewar is installed when using calgary_campinas_vif metric
  • Loading branch information
georgeyiasemis authored Dec 18, 2021
1 parent 7740308 commit 7585ab9
Show file tree
Hide file tree
Showing 11 changed files with 188 additions and 267 deletions.
6 changes: 5 additions & 1 deletion direct/cli/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# coding=utf-8
# Copyright (c) DIRECT Contributors
"""DIRECT Command-line interface. This is the file which builds the main parser."""

import argparse
import sys

Expand All @@ -16,11 +17,14 @@ def main():
root_subparsers.required = True
root_subparsers.dest = "subcommand"

# Prevent circular import
# Prevent circular imports
from direct.cli.train import register_parser as register_train_subcommand
from direct.cli.predict import register_parser as register_predict_subcommand

# Training images related commands.
register_train_subcommand(root_subparsers)
# Inference images related commands.
register_predict_subcommand(root_subparsers)

args = root_parser.parse_args()
args.subcommand(args)
Expand Down
71 changes: 71 additions & 0 deletions direct/cli/predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# coding=utf-8
# Copyright (c) DIRECT Contributors

import argparse
import pathlib

from direct.cli.utils import file_or_url
from direct.environment import Args
from direct.predict import predict_from_argparse


def register_parser(parser: argparse._SubParsersAction):
"""Register wsi commands to a root parser."""

epilog = f"""
Examples:
---------
Run on single machine:
$ direct predict <data_root> <output_directory> <experiment_directory> --checkpoint <checkpoint> \
--num-gpus <num_gpus> [ --cfg <cfg_filename>.yaml --other-flags <other_flags>]
Run on multiple machines:
(machine0)$ direct predict <data_root> <output_directory> <experiment_dir> --checkpoint <checkpoint> \
--cfg <cfg_filename>.yaml --machine-rank 0 --num-machines 2 [--other-flags]
(machine1)$ direct predict <data_root> <output_directory> <experiment_dir> --checkpoint <checkpoint> \
--cfg <cfg_filename>.yaml --machine-rank 1 --num-machines 2 [--other-flags]
"""
common_parser = Args(add_help=False)
predict_parser = parser.add_parser(
"predict",
help="Run inference using direct.",
parents=[common_parser],
epilog=epilog,
formatter_class=argparse.RawDescriptionHelpFormatter,
)
predict_parser.add_argument("data_root", type=pathlib.Path, help="Path to the inference data directory.")
predict_parser.add_argument("output_directory", type=pathlib.Path, help="Path to the output directory.")
predict_parser.add_argument(
"experiment_directory",
type=pathlib.Path,
help="Path to the directory with checkpoints and config.",
)
predict_parser.add_argument(
"--checkpoint",
type=int,
required=True,
help="Number of an existing checkpoint in experiment directory.",
)
predict_parser.add_argument(
"--filenames-filter",
type=pathlib.Path,
help="Path to list of filenames to parse.",
)
predict_parser.add_argument(
"--name",
dest="name",
help="Run name if this is different than the experiment directory.",
required=False,
type=str,
default="",
)
predict_parser.add_argument(
"--cfg",
dest="cfg_file",
help="Config file for inference. Can be either a local file or a remote URL."
"Only use it to overwrite the standard loading of the config in the project directory.",
required=False,
type=file_or_url,
)

predict_parser.set_defaults(subcommand=predict_from_argparse)
15 changes: 12 additions & 3 deletions direct/functionals/challenges.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,17 @@ def calgary_campinas_psnr(gt, pred):

def calgary_campinas_vif(gt, pred):
def vif_func(gt, target, data_range): # noqa
from sewar.full_ref import vifp

return vifp(gt, target, sigma_nsq=0.4)
from direct.utils.imports import _module_available

# Calgary Campinas VIF metric requires 'sewar' module. Check if it exists
if not _module_available("sewar"):
raise RuntimeError(
"'sewar' module required for calgary_campinas_vif metric, but not found. "
"Please use 'pip3 install sewar' and run again."
)
else:
from sewar.full_ref import vifp

return vifp(gt, target, sigma_nsq=0.4)

return _calgary_campinas_metric(gt, pred, vif_func)
3 changes: 0 additions & 3 deletions direct/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ def setup_inference_save_to_h5(
machine_rank: int,
cfg_file=None,
process_per_chunk: Optional[int] = None,
volume_processing_func: Callable = None,
mixed_precision: bool = False,
debug: bool = False,
):
Expand All @@ -51,7 +50,6 @@ def setup_inference_save_to_h5(
machine_rank :
cfg_file :
process_per_chunk :
volume_processing_func :
mixed_precision :
debug :
Expand Down Expand Up @@ -94,7 +92,6 @@ def setup_inference_save_to_h5(
write_output_to_h5(
output,
output_directory,
volume_processing_func,
output_key="reconstruction",
)

Expand Down
58 changes: 58 additions & 0 deletions direct/predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# coding=utf-8
# Copyright (c) DIRECT Contributors
import argparse
import functools
import logging
import os

import torch

from direct.common.subsample import build_masking_function
from direct.environment import Args
from direct.inference import build_inference_transforms, setup_inference_save_to_h5
from direct.launch import launch
from direct.utils import set_all_seeds

logger = logging.getLogger(__name__)


def _get_transforms(env):
dataset_cfg = env.cfg.inference.dataset
mask_func = build_masking_function(**dataset_cfg.transforms.masking)
transforms = build_inference_transforms(env, mask_func, dataset_cfg)
return dataset_cfg, transforms


setup_inference_save_to_h5 = functools.partial(
setup_inference_save_to_h5,
functools.partial(_get_transforms),
)


def predict_from_argparse(args: argparse.Namespace):
# This sets MKL threads to 1.
# DataLoader can otherwise bring a lot of difficulties when computing CPU FFTs in the transforms.
torch.set_num_threads(1)
os.environ["OMP_NUM_THREADS"] = "1"

set_all_seeds(args.seed)

launch(
setup_inference_save_to_h5,
args.num_machines,
args.num_gpus,
args.machine_rank,
args.dist_url,
args.name,
args.data_root,
args.experiment_directory,
args.output_directory,
args.filenames_filter,
args.checkpoint,
args.device,
args.num_workers,
args.machine_rank,
args.cfg_file,
args.mixed_precision,
args.debug,
)
21 changes: 21 additions & 0 deletions direct/utils/imports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# coding=utf-8
# Copyright (c) DIRECT Contributors
"""General utilities for module imports"""

from importlib.util import find_spec


def _module_available(module_path: str) -> bool:
"""
Check if a path is available in your environment
>>> _module_available('os')
True
>>> _module_available('bla.bla')
False
Adapted from: https://github.com/PyTorchLightning/pytorch-lightning/blob/ef7d41692ca04bb9877da5c743f80fceecc6a100/pytorch_lightning/utilities/imports.py#L27
Under Apache 2.0 license.
"""
try:
return find_spec(module_path) is not None
except ModuleNotFoundError:
return False
4 changes: 1 addition & 3 deletions projects/calgary_campinas/predict_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# coding=utf-8
# Copyright (c) DIRECT Contributors

import functools
import logging
import os
Expand All @@ -16,8 +17,6 @@
from direct.inference import build_inference_transforms, setup_inference_save_to_h5
from direct.utils import set_all_seeds

from .utils import volume_post_processing_func

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -150,7 +149,6 @@ def _get_transforms(masks_dict, env):
args.device,
args.num_workers,
args.machine_rank,
volume_post_processing_func,
args.mixed_precision,
args.debug,
)
116 changes: 0 additions & 116 deletions projects/calgary_campinas/predict_val.py

This file was deleted.

Loading

0 comments on commit 7585ab9

Please sign in to comment.