-
Notifications
You must be signed in to change notification settings - Fork 705
/
Copy pathtorch_inference.py
105 lines (83 loc) · 3.37 KB
/
torch_inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
"""Anomalib Torch Inferencer Script.
This script performs torch inference by reading model weights
from command line, and show the visualization results.
"""
# Copyright (C) 2022-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import logging
from argparse import ArgumentParser, Namespace
from pathlib import Path
import torch
from anomalib.data.utils import generate_output_image_filename, get_image_filenames, read_image
from anomalib.data.utils.image import save_image, show_image
from anomalib.deploy.inferencers.torch_inferencer import TorchInferencer
from anomalib.utils.visualization import ImageResult, ImageVisualizer
logger = logging.getLogger(__name__)
def get_parser() -> ArgumentParser:
    """Build the command-line parser for the torch inferencer.

    Returns:
        ArgumentParser: The parser object.
    """
    parser = ArgumentParser()
    add = parser.add_argument  # shorthand; the parser is configured below

    # Required inputs.
    add("--weights", type=Path, required=True, help="Path to model weights")
    add("--input", type=Path, required=True, help="Path to an image to infer.")

    # Optional outputs and knobs.
    add("--output", type=Path, required=False, help="Path to save the output image.")
    add(
        "--device",
        type=str,
        required=False,
        default="auto",
        choices=["auto", "cpu", "gpu", "cuda"],  # cuda and gpu are the same but provided for convenience
        help="Device to use for inference. Defaults to auto.",
    )
    add(
        "--task",
        type=str,
        required=False,
        default="classification",
        choices=["classification", "detection", "segmentation"],
        help="Task type.",
    )
    add(
        "--visualization_mode",
        type=str,
        required=False,
        default="simple",
        choices=["full", "simple"],
        help="Visualization mode.",
    )
    add(
        "--show",
        action="store_true",
        required=False,
        help="Show the visualized predictions on the screen.",
    )
    return parser
def infer(args: Namespace) -> None:
    """Infer predictions.

    Show/save the output if path is to an image. If the path is a directory, go over each image in the directory.

    Args:
        args (Namespace): The arguments from the command line.
    """
    # Inference only; no gradients are needed.
    torch.set_grad_enabled(mode=False)

    # Create the inferencer and visualizer.
    inferencer = TorchInferencer(path=args.weights, device=args.device)
    visualizer = ImageVisualizer(mode=args.visualization_mode, task=args.task)

    # Fix: warn once up front instead of once per image inside the loop.
    if args.output is None and args.show is False:
        msg = "Neither output path is provided nor show flag is set. Inferencer will run but return nothing."
        logger.warning(msg)

    filenames = get_image_filenames(path=args.input)
    for filename in filenames:
        image = read_image(filename, as_tensor=True)
        predictions = inferencer.predict(image=image)
        # Visualize the first (and only) item of the prediction batch.
        image_result = ImageResult.from_dataset_item(predictions.items[0])
        output = visualizer.visualize_image(image_result)

        if args.output:
            file_path = generate_output_image_filename(input_path=filename, output_path=args.output)
            save_image(filename=file_path, image=output)

        # Show the image in case the flag is set by the user.
        if args.show:
            show_image(title="Output Image", image=output)
if __name__ == "__main__":
args = get_parser().parse_args()
infer(args=args)