Skip to content

Commit

Permalink
aggregate all lines instead of per region to better utilise batched predictor
Browse files Browse the repository at this point in the history
  • Loading branch information
bertsky committed Sep 17, 2024
1 parent 7aae9bc commit 9611e2c
Showing 1 changed file with 178 additions and 178 deletions.
356 changes: 178 additions & 178 deletions ocrd_calamari/recognize.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def setup(self):
"""
resolved = self.resolve_resource(self.parameter["checkpoint_dir"])
checkpoints = glob("%s/*.ckpt.json" % resolved)
self.predictor = MultiPredictor(checkpoints=checkpoints)
self.predictor = MultiPredictor(checkpoints=checkpoints, batch_size=BATCH_SIZE)

self.network_input_channels = self.predictor.predictors[
0
Expand Down Expand Up @@ -98,6 +98,7 @@ def process_page_pcgts(self, *input_pcgts: Optional[OcrdPage], page_id: Optional
page, page_id, feature_selector=self.features
)

lines = []
for region in page.get_AllRegions(classes=["Text"]):
region_image, region_coords = self.workspace.image_from_segment(
region, page_image, page_coords, feature_selector=self.features
Expand All @@ -109,8 +110,6 @@ def process_page_pcgts(self, *input_pcgts: Optional[OcrdPage], page_id: Optional
len(textlines),
region.id,
)
line_images_np = []
line_coordss = []
for line in textlines:
self.logger.debug(
"Recognizing line '%s' in region '%s'", line.id, region.id
Expand Down Expand Up @@ -150,195 +149,196 @@ def process_page_pcgts(self, *input_pcgts: Optional[OcrdPage], page_id: Optional
line.id,
region.id,
)
line_image_np = np.array([[0]], dtype=np.uint8)
else:
line_image_np = np.array(line_image, dtype=np.uint8)
line_images_np.append(line_image_np)
line_coordss.append(line_coords)

# avoid too large a batch size (causing OOM on CPU or GPU)
fun = lambda x: self.predictor.predict_raw(x, progress_bar=False)
raw_results_all = itertools.chain.from_iterable(
map(fun, itertools.batched(line_images_np, BATCH_SIZE)))

for line, line_coords, raw_results in zip(
textlines, line_coordss, raw_results_all
):
for i, p in enumerate(raw_results):
p.prediction.id = "fold_{}".format(i)

prediction = self.voter.vote_prediction_result(raw_results)
prediction.id = "voted"

# Build line text on our own
#
# Calamari does whitespace post-processing on prediction.sentence,
# while it does not do the same on prediction.positions. Do it on
# our own to have consistency.
#
# XXX Check Calamari's built-in post-processing on
# prediction.sentence

def _sort_chars(p):
    """Filter and sort chars of prediction p, best (most probable) first."""
    cutoff = self.parameter["glyph_conf_cutoff"]
    # Keep only alternatives with a non-empty char and a probability at
    # or above the configured cutoff.
    # XXX Note that omission probabilities are not normalized?!
    usable = [c for c in p.chars if c.char and c.probability >= cutoff]
    return sorted(usable, key=lambda c: c.probability, reverse=True)

def _drop_leading_spaces(positions):
return list(
itertools.dropwhile(
lambda p: _sort_chars(p)[0].char == " ", positions
)
continue
lines.append((line, line_coords, np.array(line_image, dtype=np.uint8)))

if not len(lines):
self.logger.warning("No text lines on page '%s'", page_id)
return OcrdPageResult(pcgts)

lines, coords, images = zip(*lines)
# not exposed in MultiPredictor yet, cf. calamari#361:
# results = self.predictor.predict_raw(images, progress_bar=False, batch_size=BATCH_SIZE)
# avoid too large a batch size (causing OOM on CPU or GPU)
fun = lambda x: self.predictor.predict_raw(x, progress_bar=False)
results = itertools.chain.from_iterable(
map(fun, itertools.batched(images, BATCH_SIZE)))
for line, line_coords, raw_results in zip(lines, coords, results):
for i, p in enumerate(raw_results):
p.prediction.id = "fold_{}".format(i)

prediction = self.voter.vote_prediction_result(raw_results)
prediction.id = "voted"

# Build line text on our own
#
# Calamari does whitespace post-processing on prediction.sentence,
# while it does not do the same on prediction.positions. Do it on
# our own to have consistency.
#
# XXX Check Calamari's built-in post-processing on
# prediction.sentence

def _sort_chars(p):
    """Filter and sort chars of prediction p, most probable first."""
    # Drop empty-char alternatives first.
    # XXX Note that omission probabilities are not normalized?!
    candidates = [c for c in p.chars if c.char]
    # Then drop alternatives below the configured confidence cutoff.
    threshold = self.parameter["glyph_conf_cutoff"]
    candidates = [c for c in candidates if c.probability >= threshold]
    candidates.sort(key=lambda c: c.probability, reverse=True)
    return candidates

def _drop_leading_spaces(positions):
    """Strip leading space positions from the line.

    A position counts as a space when its best surviving char (after
    _sort_chars filtering) is " ". Previously, a position whose char
    alternatives were ALL filtered out (empty _sort_chars result) made
    `_sort_chars(p)[0]` raise IndexError; treat such positions as
    non-space so that dropping stops safely instead of crashing.
    """
    def _is_space(p):
        chars = _sort_chars(p)
        return bool(chars) and chars[0].char == " "

    return list(itertools.dropwhile(_is_space, positions))

def _drop_trailing_spaces(positions):
    """Strip trailing spaces: reverse, drop leading, reverse back."""
    trimmed = _drop_leading_spaces(list(reversed(positions)))
    trimmed.reverse()
    return trimmed

def _drop_double_spaces(positions):
def _drop_double_spaces_generator(positions):
last_was_space = False
for p in positions:
if p.chars[0].char == " ":
if not last_was_space:
yield p
last_was_space = True
else:
yield p
last_was_space = False

return list(_drop_double_spaces_generator(positions))

positions = prediction.positions
positions = _drop_leading_spaces(positions)
positions = _drop_trailing_spaces(positions)
positions = _drop_double_spaces(positions)
positions = list(positions)

line_text = "".join(_sort_chars(p)[0].char for p in positions)
if line_text != prediction.sentence:
self.logger.warning(
f"Our own line text is not the same as Calamari's:"
f"'{line_text}' != '{prediction.sentence}'"
)
def _drop_trailing_spaces(positions):
    """Strip trailing space positions (mirror of _drop_leading_spaces)."""
    reversed_positions = reversed(positions)
    kept = _drop_leading_spaces(reversed_positions)
    return list(reversed(kept))

# Delete existing results
if line.get_TextEquiv():
self.logger.warning("Line '%s' already contained text results", line.id)
line.set_TextEquiv([])
if line.get_Word():
self.logger.warning(
"Line '%s' already contained word segmentation", line.id
)
line.set_Word([])
def _drop_double_spaces(positions):
def _drop_double_spaces_generator(positions):
last_was_space = False
for p in positions:
if p.chars[0].char == " ":
if not last_was_space:
yield p
last_was_space = True
else:
yield p
last_was_space = False

return list(_drop_double_spaces_generator(positions))

positions = prediction.positions
positions = _drop_leading_spaces(positions)
positions = _drop_trailing_spaces(positions)
positions = _drop_double_spaces(positions)
positions = list(positions)

line_text = "".join(_sort_chars(p)[0].char for p in positions)
if line_text != prediction.sentence:
self.logger.warning(
f"Our own line text is not the same as Calamari's:"
f"'{line_text}' != '{prediction.sentence}'"
)

# Save line results
line_conf = prediction.avg_char_probability
line.set_TextEquiv(
[TextEquivType(Unicode=line_text, conf=line_conf)]
# Delete existing results
if line.get_TextEquiv():
self.logger.warning("Line '%s' already contained text results", line.id)
line.set_TextEquiv([])
if line.get_Word():
self.logger.warning(
"Line '%s' already contained word segmentation", line.id
)
line.set_Word([])

# Save word results
#
# Calamari OCR does not provide word positions, so we infer word
# positions from a. text segmentation and b. the glyph positions.
# This is necessary because the PAGE XML format enforces a strict
# hierarchy of lines > words > glyphs.

def _words(s):
"""Split words based on spaces and include spaces as 'words'"""
spaces = None
word = ""
for c in s:
if c == " " and spaces is True:
word += c
elif c != " " and spaces is False:
word += c
else:
if word:
yield word
word = c
spaces = c == " "
yield word

if self.parameter["textequiv_level"] in ["word", "glyph"]:
word_no = 0
i = 0

for word_text in _words(line_text):
word_length = len(word_text)
if not all(c == " " for c in word_text):
word_positions = positions[i : i + word_length]
word_start = word_positions[0].global_start
word_end = word_positions[-1].global_end

polygon = polygon_from_x0y0x1y1(
[word_start, 0, word_end, line_image.height]
)
points = points_from_polygon(
coordinates_for_segment(polygon, None, line_coords)
)
# XXX Crop to line polygon?
# Save line results
line_conf = prediction.avg_char_probability
line.set_TextEquiv(
[TextEquivType(Unicode=line_text, conf=line_conf)]
)

word = WordType(
id="%s_word%04d" % (line.id, word_no),
Coords=CoordsType(points),
)
word.add_TextEquiv(TextEquivType(Unicode=word_text))

if self.parameter["textequiv_level"] == "glyph":
for glyph_no, p in enumerate(word_positions):
glyph_start = p.global_start
glyph_end = p.global_end

polygon = polygon_from_x0y0x1y1(
[
glyph_start,
0,
glyph_end,
line_image.height,
]
)
points = points_from_polygon(
coordinates_for_segment(
polygon, None, line_coords
)
)
# Save word results
#
# Calamari OCR does not provide word positions, so we infer word
# positions from a. text segmentation and b. the glyph positions.
# This is necessary because the PAGE XML format enforces a strict
# hierarchy of lines > words > glyphs.

def _words(s):
"""Split words based on spaces and include spaces as 'words'"""
spaces = None
word = ""
for c in s:
if c == " " and spaces is True:
word += c
elif c != " " and spaces is False:
word += c
else:
if word:
yield word
word = c
spaces = c == " "
yield word

if self.parameter["textequiv_level"] in ["word", "glyph"]:
word_no = 0
i = 0

for word_text in _words(line_text):
word_length = len(word_text)
if not all(c == " " for c in word_text):
word_positions = positions[i : i + word_length]
word_start = word_positions[0].global_start
word_end = word_positions[-1].global_end

polygon = polygon_from_x0y0x1y1(
[word_start, 0, word_end, line_image.height]
)
points = points_from_polygon(
coordinates_for_segment(polygon, None, line_coords)
)
# XXX Crop to line polygon?

glyph = GlyphType(
id="%s_glyph%04d" % (word.id, glyph_no),
Coords=CoordsType(points),
word = WordType(
id="%s_word%04d" % (line.id, word_no),
Coords=CoordsType(points),
)
word.add_TextEquiv(TextEquivType(Unicode=word_text))

if self.parameter["textequiv_level"] == "glyph":
for glyph_no, p in enumerate(word_positions):
glyph_start = p.global_start
glyph_end = p.global_end

polygon = polygon_from_x0y0x1y1(
[
glyph_start,
0,
glyph_end,
line_image.height,
]
)
points = points_from_polygon(
coordinates_for_segment(
polygon, None, line_coords
)
)

# Add predictions (= TextEquivs)
char_index_start = 1
# Index must start with 1, see
# https://ocr-d.github.io/page#multiple-textequivs
for char_index, char in enumerate(
_sort_chars(p), start=char_index_start
):
glyph.add_TextEquiv(
TextEquivType(
Unicode=char.char,
index=char_index,
conf=char.probability,
)
glyph = GlyphType(
id="%s_glyph%04d" % (word.id, glyph_no),
Coords=CoordsType(points),
)

# Add predictions (= TextEquivs)
char_index_start = 1
# Index must start with 1, see
# https://ocr-d.github.io/page#multiple-textequivs
for char_index, char in enumerate(
_sort_chars(p), start=char_index_start
):
glyph.add_TextEquiv(
TextEquivType(
Unicode=char.char,
index=char_index,
conf=char.probability,
)
)

word.add_Glyph(glyph)
word.add_Glyph(glyph)

line.add_Word(word)
word_no += 1
line.add_Word(word)
word_no += 1

i += word_length
i += word_length

_page_update_higher_textequiv_levels("line", pcgts)
return OcrdPageResult(pcgts)
Expand Down

0 comments on commit 9611e2c

Please sign in to comment.