-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpredict.py
executable file
·69 lines (50 loc) · 2.23 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import os
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
from runtime.args import non_negative_int, float_0_1, ArgParser
from runtime.utils import set_warning_levels, set_memory_growth, set_tf_flags, set_amp, set_xla
from inference.main_inference import check_test_time_input, load_test_time_models, test_time_inference
def get_predict_args(argv=None):
    """Build the prediction command-line interface and parse it.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line. ``None`` (the default) parses ``sys.argv[1:]``,
            exactly as the zero-argument call did before.

    Returns:
        The parsed ``argparse.Namespace`` consumed by ``main()``.
    """
    p = ArgParser(formatter_class=ArgumentDefaultsHelpFormatter)

    # The four path arguments below have no meaningful default. Without
    # ``required=True`` an omitted flag yields ``None``, which main() would
    # pass straight into the inference helpers and fail there with an
    # unhelpful error. Requiring them makes argparse print a clear usage
    # error instead.
    p.arg("--models", type=str, required=True, help="Directory containing saved models")
    p.arg("--config", type=str, required=True,
          help="Path and name of config.json file from results of MIST pipeline")
    p.arg("--data", type=str, required=True, help="CSV or JSON file containing paths to data")
    p.arg("--output", type=str, required=True, help="Directory to save predictions")

    p.boolean_flag("--fast", default=False, help="Use only one model for prediction to speed up inference time")
    p.arg("--gpu", type=int, default=0, help="GPU id to run inference on, -1 --> run on CPU")
    p.boolean_flag("--amp", default=False, help="Use automatic mixed precision")
    p.boolean_flag("--xla", default=False, help="Use XLA")

    # Sliding-window inference settings.
    p.arg("--sw-overlap",
          type=float_0_1,
          default=0.5,
          help="Amount of overlap between scans during sliding window inference")
    p.arg("--blend-mode",
          type=str,
          choices=["gaussian", "constant"],
          default="gaussian",
          help="How to blend output of overlapping windows")
    p.boolean_flag("--tta", default=False, help="Use test time augmentation")

    # argv=None makes argparse fall back to sys.argv[1:].
    return p.parse_args(argv)
def main(args):
    """Configure the runtime environment and run test-time inference.

    Args:
        args: Namespace returned by ``get_predict_args()``. The attributes
            read here are ``gpu``, ``amp``, ``xla``, ``data``, ``output``,
            ``models``, ``fast``, ``config``, ``sw_overlap``, ``blend_mode``
            and ``tta``.
    """
    # Logging / framework configuration comes first.
    set_warning_levels()
    set_tf_flags()

    # Pin inference to the requested device. Per the --gpu help text, an id
    # of -1 means run on CPU.
    gpu_id = str(args.gpu)
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id
    set_memory_growth()

    # Opt-in performance features: each pair is (flag value, enabler).
    optional_features = (
        (args.amp, set_amp),
        (args.xla, set_xla),
    )
    for requested, enable_feature in optional_features:
        if requested:
            enable_feature()

    # Validate the input listing, load the model(s), then predict.
    test_df = check_test_time_input(args.data, args.output)
    loaded_models = load_test_time_models(args.models, args.fast)
    test_time_inference(
        test_df,
        args.output,
        args.config,
        loaded_models,
        args.sw_overlap,
        args.blend_mode,
        args.tta,
    )
if __name__ == "__main__":
    # Script entry point: parse the command line and hand the result to main().
    main(get_predict_args())