forked from jpuigcerver/PyLaia
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathpylaia-htr-decode-ctc
executable file
·114 lines (107 loc) · 3.76 KB
/
pylaia-htr-decode-ctc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
#!/usr/bin/env python
from __future__ import absolute_import
import argparse
import os
import torch
import laia.common.logging as log
from laia.common.arguments import add_argument, args, add_defaults
from laia.common.arguments_types import str2bool
from laia.common.loader import ModelLoader, CheckpointLoader
from laia.data import ImageDataLoader, ImageFromListDataset
from laia.decoders import CTCGreedyDecoder
from laia.engine.feeders import ImageFeeder, ItemFeeder
from laia.experiments import Experiment
from laia.utils import SymbolsTable, ImageToTensor
def format_hypothesis(out, syms, use_letters=False, space=None, join_str=None):
    """Post-process one decoded hypothesis before it is printed.

    Args:
        out: Sequence of symbols produced by the decoder for one image.
        syms: Object indexable by the elements of ``out`` (``syms[val]``);
            only consulted when ``use_letters`` is true.
        use_letters: If true, replace every element of ``out`` by
            ``str(syms[val])``.
        space: If truthy, every element equal to the literal ``"<space>"`` is
            replaced by this string. An empty string or ``None`` disables the
            replacement.
        join_str: If not ``None``, the elements are converted with ``str`` and
            joined with this string, so a ``str`` is returned instead of a
            list.

    Returns:
        A ``str`` when ``join_str`` is given; otherwise a list (or ``out``
        itself, untouched, when no option applies).
    """
    if use_letters:
        out = [str(syms[val]) for val in out]
    if space:
        # Elements that are not the "<space>" token are passed through as-is.
        out = [space if sym == "<space>" else sym for sym in out]
    if join_str is not None:
        out = join_str.join(str(x) for x in out)
    return out


# Script entry point: load a trained model plus a checkpoint, run it over the
# listed images and print one decoded hypothesis per image.
if __name__ == "__main__":
    add_defaults("batch_size", "gpu", "train_path", logging_level="WARNING")
    add_argument(
        "syms",
        type=argparse.FileType("r"),
        help="Symbols table mapping from strings to integers",
    )
    add_argument(
        "img_dirs", type=str, nargs="+", help="Directory containing word images"
    )
    add_argument(
        "img_list",
        type=argparse.FileType("r"),
        help="File or list containing images to decode",
    )
    add_argument(
        "--model_filename", type=str, default="model", help="File name of the model"
    )
    add_argument(
        "--checkpoint",
        type=str,
        default="experiment.ckpt.lowest-valid-cer*",
        help="Name of the model checkpoint to use, can be a glob pattern",
    )
    add_argument(
        "--source",
        type=str,
        default="experiment",
        choices=["experiment", "model"],
        help="Type of class which generated the checkpoint",
    )
    add_argument(
        "--print_img_ids",
        type=str2bool,
        nargs="?",
        const=True,
        default=True,
        help="Print output with the associated image id",
    )
    add_argument(
        "--separator",
        type=str,
        default=" ",
        help="Use this string as the separator between the ids and the output",
    )
    add_argument("--join_str", type=str, help="Join the output using this")
    add_argument(
        "--use_letters", action="store_true", help="Print the output with letters"
    )
    add_argument(
        "--space", type=str, help="Replace <space> with this. Used with --use_letters"
    )
    # Stored under a new name so the imported ``args`` function is not
    # overwritten by its own result.
    opts = args()

    syms = SymbolsTable(opts.syms)
    # ``--gpu`` is 1-based here: a falsy value (0/None) selects the CPU,
    # otherwise GPU ``n`` maps to device ``cuda:n-1``.
    device = torch.device("cuda:{}".format(opts.gpu - 1) if opts.gpu else "cpu")

    model = ModelLoader(
        opts.train_path, filename=opts.model_filename, device=device
    ).load()
    if model is None:
        log.error("Could not find the model")
        # ``exit`` is a helper injected by ``site`` for interactive use;
        # raising SystemExit works in every interpreter configuration.
        raise SystemExit(1)

    state = CheckpointLoader(device=device).load_by(
        os.path.join(opts.train_path, opts.checkpoint)
    )
    # NOTE(review): presumably ``load_by`` yields None when the glob pattern
    # matches no file -- confirm against CheckpointLoader. Either way a None
    # state cannot be loaded, so fail with a clear message instead.
    if state is None:
        log.error("Could not find the checkpoint")
        raise SystemExit(1)
    # Checkpoints written by an experiment wrap the model weights; checkpoints
    # written for a bare model are the state dict itself.
    model.load_state_dict(
        state if opts.source == "model" else Experiment.get_model_state_dict(state)
    )
    model = model.to(device)
    model.eval()

    dataset = ImageFromListDataset(
        opts.img_list, img_dirs=opts.img_dirs, img_transform=ImageToTensor()
    )
    dataset_loader = ImageDataLoader(
        dataset=dataset, image_channels=1, batch_size=opts.batch_size, num_workers=8
    )
    batch_input_fn = ImageFeeder(device=device, parent_feeder=ItemFeeder("img"))
    decoder = CTCGreedyDecoder()

    # Pure inference: disable autograd so no graph is recorded per batch.
    with torch.no_grad():
        for batch in dataset_loader:
            batch_input = batch_input_fn(batch)
            batch_output = model(batch_input)
            batch_decode = decoder(batch_output)
            for img_id, out in zip(batch["id"], batch_decode):
                out = format_hypothesis(
                    out,
                    syms,
                    use_letters=opts.use_letters,
                    space=opts.space,
                    join_str=opts.join_str,
                )
                print(
                    "{}{}{}".format(img_id, opts.separator, out)
                    if opts.print_img_ids
                    else out
                )