Skip to content

Commit

Permalink
[app][feat] gather pairs from data from dirs
Browse files Browse the repository at this point in the history
  • Loading branch information
M3ssman committed Sep 22, 2023
1 parent 78a126c commit 90ccf29
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 43 deletions.
14 changes: 9 additions & 5 deletions src/tesstrain/training_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,9 @@
DEFAULT_BINARIZE = False
DEFAULT_SANITIZE = True
DEFAULT_PADDING = 0
SUMMARY_SUFFIX = '_summary.gt.txt'
SUFFIX_SUMMARY = '_summary.gt.txt'
SUFFIX_GT_TXT_='.gt.txt'
SUFFIX_GT_IMG_TIF = '.tif'

# clear unwanted marks for single wordlike tokens
CLEAR_MARKS = [
Expand Down Expand Up @@ -343,7 +345,7 @@ def pair_prefix(self) -> str:
def pair_prefix(self, pair_prefix):
"""Set output dir explicitely"""

self._path_ocr_data = pair_prefix
self._pair_prefix = pair_prefix

def _resolve_image_path(self, path_xml_data):
self.path_image_data = resolve_image_path(path_xml_data)
Expand Down Expand Up @@ -393,12 +395,14 @@ def write_pair(self, text_line: TextLine,
"""Serialize training data pairs"""

_data_label = Path( self.path_ocr_data).stem
if _data_label.isnumeric():
_data_label = f'p{int(_data_label)}'
_dir_path = os.path.join(self.output_dir, self.pair_prefix)
if not os.path.isdir(_dir_path):
os.makedirs(_dir_path)
gt_txt_name = f'{_data_label}_{self.pair_prefix}_{text_line.element_id}.gt.txt'
gt_txt_name = f'{self.pair_prefix}_{_data_label}_{text_line.element_id}{SUFFIX_GT_TXT_}'
gt_txt_path = os.path.join(_dir_path, gt_txt_name)
img_name = f'{_data_label}_{self.pair_prefix}_{text_line.element_id}.gt.tif'
img_name = f'{self.pair_prefix}_{_data_label}_{text_line.element_id}{SUFFIX_GT_IMG_TIF}'
img_path = os.path.join(_dir_path, img_name)
content = text_line.get_textline_content()
img_frame = extract_rectangular_frame(image_handle, text_line)
Expand All @@ -425,7 +429,7 @@ def write_summary(self, training_datas: List):
"""Serialize training data pairs"""

contents = [d.get_textline_content() + '\n' for d in training_datas]
file_name = self.pair_prefix + SUMMARY_SUFFIX
file_name = self.pair_prefix + SUFFIX_SUMMARY
file_path = os.path.join(self.output_dir, self.pair_prefix, file_name)
with open(file_path, 'w', encoding="utf8") as fhdl:
fhdl.writelines(contents)
Expand Down
133 changes: 96 additions & 37 deletions src/tesstrain/training_sets_cli.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
# -*- coding: utf-8 -*-
"""Generate Sets of Training Data TextLine + Image Pairs"""

import argparse
import os

from argparse import (
ArgumentParser,
Namespace,
)
from pathlib import (
Path
)

from tesstrain.training_sets import (
TrainingSets,
DEFAULT_OUTDIR_PREFIX,
DEFAULT_MIN_CHARS,
DEFAULT_USE_SUMMARY,
Expand All @@ -15,14 +21,88 @@
DEFAULT_SANITIZE,
DEFAULT_BINARIZE,
DEFAULT_PADDING,
SUMMARY_SUFFIX
SUFFIX_SUMMARY,
TrainingSets,
)


def _run_single_page(args: Namespace):
    """Create training pairs from one OCR file and its matching image.

    Reads all settings from the parsed command line arguments, builds a
    ``TrainingSets`` instance for the OCR/image pair and triggers the
    creation of the text line / image pairs below ``args.output_dir``.

    Args:
        args: parsed CLI arguments; the attributes read here are ``data``,
            ``image``, ``output_dir``, ``minchars``, ``summary``,
            ``reorder``, ``binarize``, ``sanitize``, ``intrusion_ratio``,
            ``rotation_threshold``, ``padding`` and ``prefix_output``.

    Raises:
        ValueError: if ``args.intrusion_ratio`` (or one of its comma
            separated parts) cannot be converted to ``float``.
    """
    path_ocr = os.path.abspath(args.data)
    path_img = os.path.abspath(args.image)
    output_dir = os.path.abspath(args.output_dir)
    min_chars = args.minchars
    do_summary = args.summary
    do_reorder = args.reorder
    do_binarize = args.binarize
    do_opt = args.sanitize
    rotation_thresh = args.rotation_threshold
    padding = args.padding
    # a comma separated string is turned into a list of floats, any other
    # value into a single float (parsed once; it was duplicated before)
    intrusion_ratio = args.intrusion_ratio
    if isinstance(intrusion_ratio, str) and ',' in intrusion_ratio:
        intrusion_ratio = [float(n) for n in intrusion_ratio.split(',')]
    else:
        intrusion_ratio = float(intrusion_ratio)
    _t_sets = TrainingSets(path_ocr, path_img, output_dir=output_dir)
    prefix_output = args.prefix_output
    # only override the pair prefix if one was passed explicitly
    if prefix_output:
        _t_sets.pair_prefix = prefix_output
    res = _t_sets.create(
        min_chars=min_chars,
        summary=do_summary,
        reorder=do_reorder,
        intrusion_ratio=intrusion_ratio,
        rotation_threshold=rotation_thresh,
        binarize=do_binarize,
        sanitize=do_opt,
        padding=padding)
    print(f"[DONE ] got '{len(res)}' pairs from '{path_ocr}'"
          f" and '{path_img}' in '{output_dir}', better review")


def _run_dir(args):
    """Create training pairs for every OCR file of a directory.

    Lists the ``*.xml`` files directly inside ``args.data`` and looks up,
    for each of them, an image with the same file stem in ``args.image``.
    Each matching pair is handed over to ``_run_single_page``; OCR files
    without a matching image are counted and reported at the end.

    Note that ``args.data`` and ``args.image`` are overwritten with the
    paths of the pair that is currently processed.

    Args:
        args: parsed CLI arguments with ``data`` and ``image`` pointing to
            the OCR directory and the image directory.

    Raises:
        RuntimeError: propagated from the image lookup if more than one
            image matches a single OCR file.
    """
    path_ocr_dir = args.data
    path_img_dir = args.image
    _all_ocrs = [os.path.join(path_ocr_dir, _f)
                 for _f in os.listdir(path_ocr_dir)
                 if str(_f).endswith('.xml')]
    # NOTE(review): os.listdir is not recursive, although the message
    # mentions sub_dirs - message text kept unchanged
    print(f"[DEBUG] found total {len(_all_ocrs)} in {path_ocr_dir} and sub_dirs")
    # collect unmatched files separately: removing entries from the list
    # being iterated made the loop skip every element after a match
    _missed = []
    for _an_ocr in _all_ocrs:
        _img_match = __get_image(path_img_dir, Path(_an_ocr).stem)
        if not _img_match:
            _missed.append(_an_ocr)
            continue
        args.data = _an_ocr
        args.image = _img_match
        _run_single_page(args)
    print(f"[INFO] missed {len(_missed)} in {path_img_dir}")


def __get_image(path_image_dir, label):
    """Find the image in path_image_dir whose file stem equals label.

    Only directory entries with a known image extension are considered.

    Returns:
        The joined path of the single matching image, or ``None`` if no
        entry matches.

    Raises:
        RuntimeError: if more than one image matches the label.
    """
    candidates = []
    for entry in os.listdir(path_image_dir):
        if __has_image_ext(entry) and Path(entry).stem == label:
            candidates.append(os.path.join(path_image_dir, entry))
    if len(candidates) > 1:
        raise RuntimeError(f"Invalid image match {candidates} for {label}")
    return candidates[0] if candidates else None


def __has_image_ext(file_name: str) -> bool:
    """Tell whether file_name carries a supported image extension.

    The check is case-sensitive and only looks at the last suffix;
    accepted are '.jpg', '.tif' and '.png'.
    """
    return Path(file_name).suffix in ('.jpg', '.tif', '.png')



########
# MAIN #
########
def main():
PARSER = argparse.ArgumentParser(description="generate pairs of textlines and image frames from existing OCR and image data")
PARSER: ArgumentParser = ArgumentParser(description="generate pairs of textlines and image frames from existing OCR and image data")
PARSER.add_argument(
"data",
type=str,
Expand All @@ -34,9 +114,13 @@ def main():
help="path to local image file tif|jpg|png corresponding to ocr. (default: read from OCR-Data)")
PARSER.add_argument(
"-o",
"--output_dir",
default=DEFAULT_OUTDIR_PREFIX,
help=f"output directory, re-created if already exists. (default: <script-dir>/<{DEFAULT_OUTDIR_PREFIX}>)")
PARSER.add_argument(
"--prefix-output",
required=False,
help=f"optional: output directory, re-created if already exists. (default: <script-dir>/<{DEFAULT_OUTDIR_PREFIX}>)")
help="optional: prefix each pair using this arg. (default: '')")
PARSER.add_argument(
"-m",
"--minchars",
Expand All @@ -50,7 +134,7 @@ def main():
required=False,
action='store_true',
default=DEFAULT_USE_SUMMARY,
help=f"optional: print all lines in additional file (default: {DEFAULT_USE_SUMMARY}, pattern: <default-output-dir>{SUMMARY_SUFFIX})")
help=f"optional: print all lines in additional file (default: {DEFAULT_USE_SUMMARY}, pattern: <default-output-dir>{SUFFIX_SUMMARY})")
PARSER.add_argument(
"-r",
"--reorder",
Expand Down Expand Up @@ -90,42 +174,17 @@ def main():
default=DEFAULT_PADDING,
help=f"optional: additional padding for existing textline image (default: {DEFAULT_PADDING})")

ARGS = PARSER.parse_args()
ARGS: Namespace = PARSER.parse_args()
print(f"[DEBUG] {os.path.basename(__file__)} using args: {ARGS}")
PATH_OCR = ARGS.data
PATH_IMG = ARGS.image
OUTPUT_PREFIX = ARGS.prefix_output
MIN_CHARS = ARGS.minchars
SUMMARY = ARGS.summary
REORDER = ARGS.reorder
BINARIZE = ARGS.binarize
SANITIZE = ARGS.sanitize
INTR_RATIO = ARGS.intrusion_ratio
if isinstance(INTR_RATIO, str) and ',' in INTR_RATIO:
INTR_RATIO = [float(n) for n in INTR_RATIO.split(',')]
else:
INTR_RATIO = float(INTR_RATIO)
ROTA_THRESH = ARGS.rotation_threshold
PADDING = ARGS.padding

if os.path.isfile(PATH_OCR) and os.path.isfile(PATH_IMG):
print(f"[INFO ] generate trainingsets from single file '{PATH_OCR}'")
print(f"[DEBUG] args: {ARGS}")
TRAINING_DATA = TrainingSets(PATH_OCR, PATH_IMG)
RESULT = TRAINING_DATA.create(
output_prefix=OUTPUT_PREFIX,
min_chars=MIN_CHARS,
summary=SUMMARY,
reorder=REORDER,
intrusion_ratio=INTR_RATIO,
rotation_threshold=ROTA_THRESH,
binarize=BINARIZE,
sanitize=SANITIZE,
padding=PADDING)
print(f"[DONE ] got '{len(RESULT)}' pairs from '{PATH_OCR}'"
f" and '{PATH_IMG}' in '{TRAINING_DATA.label}', please review")
# if os.path.isdir(PATH_OCR) and os.path.isdir(PATH_IMG):
# TODO handle lists of inputs
# print(f"[INFO ] inspect OCR-dir '{PATH_OCR}' and image dir '{PATH_IMG}")
_run_single_page(ARGS)
elif os.path.isdir(PATH_OCR) and os.path.isdir(PATH_IMG):
_run_dir(ARGS)
print(f"[INFO ] inspect OCR-dir '{PATH_OCR}' and image dir '{PATH_IMG}")
else:
print(f"[ERROR ] invalid OCR '{PATH_OCR}' or Image '{PATH_IMG}'!")

Expand Down
4 changes: 3 additions & 1 deletion tests/test_generate_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,8 @@ def test_create_sets_from_alto_and_tif(fixture_newspaper_p512):
path_items = os.listdir(_output_dir)
tifs = [tif for tif in path_items if str(tif).endswith(".tif")]
assert len(tifs) == 225
# please no *.gt.tif !!
assert not [tif for tif in path_items if str(tif).endswith(".gt.tif")]
lines = [txt for txt in path_items if str(txt).endswith(GT_SUFFIX)]

# one more txt because of the summary file
Expand Down Expand Up @@ -199,7 +201,7 @@ def test_create_sets_from_page2013_and_jpg(fixture_page2013_jpg):

# assert
assert len(data) == 32
_output_dir = os.path.join(path_input_dir, f'page{OCR_TRANSK}')
_output_dir = os.path.join(path_input_dir, f'{OCR_TRANSK}')
path_items = os.listdir(_output_dir)
assert len([tif for tif in path_items if str(tif).endswith(".tif")]) == 32
txt_files = sorted(
Expand Down

0 comments on commit 90ccf29

Please sign in to comment.