diff --git a/conda/meta.yaml b/conda/meta.yaml
index f68acffae..7380314bd 100644
--- a/conda/meta.yaml
+++ b/conda/meta.yaml
@@ -32,6 +32,7 @@ requirements:
     - pyarrow
    - pytorch-lightning~=2.0
    - torchmetrics>=0.10.0
+    - conda-forge::threadpoolctl~=3.2.0
    - albumentations
    - rich
 about:
diff --git a/docs/api_docs.rst b/docs/api_docs.rst
index 46379f2b8..cb85ff91f 100644
--- a/docs/api_docs.rst
+++ b/docs/api_docs.rst
@@ -2,8 +2,11 @@
 API Reference
 *************
 
+Segmentation
+============
+
 kraken.blla module
-==================
+------------------
 
 .. note::
 
@@ -14,7 +17,7 @@ kraken.blla module
 .. autoapifunction:: kraken.blla.segment
 
 kraken.pageseg module
-=====================
+---------------------
 
 .. note::
 
@@ -24,22 +27,22 @@ kraken.pageseg module
 
 .. autoapifunction:: kraken.pageseg.segment
 
-kraken.rpred module
-===================
+Recognition
+===========
 
-.. autoapifunction:: kraken.rpred.bidi_record
+kraken.rpred module
+-------------------
 
 .. autoapiclass:: kraken.rpred.mm_rpred
    :members:
 
-.. autoapiclass:: kraken.rpred.ocr_record
-    :members:
-
 .. autoapifunction:: kraken.rpred.rpred
 
+Serialization
+=============
 
 kraken.serialization module
-===========================
+---------------------------
 
 .. autoapifunction:: kraken.serialization.render_report
 
@@ -47,127 +50,118 @@ kraken.serialization module
 
 .. autoapifunction:: kraken.serialization.serialize_segmentation
 
-kraken.lib.models module
-========================
+Default templates
+-----------------
 
-.. autoapiclass:: kraken.lib.models.TorchSeqRecognizer
-    :members:
+ALTO 4.4
+^^^^^^^^
 
-.. autoapifunction:: kraken.lib.models.load_any
+.. literalinclude:: ../../templates/alto
+    :language: xml+jinja
 
-kraken.lib.vgsl module
-======================
+PageXML
+^^^^^^^
 
-.. autoapiclass:: kraken.lib.vgsl.TorchVGSLModel
-    :members:
+.. literalinclude:: ../../templates/pagexml
+    :language: xml+jinja
 
-kraken.lib.xml module
-=====================
+hOCR
+^^^^
 
-.. autoapifunction:: kraken.lib.xml.parse_xml
+.. literalinclude:: ../../templates/hocr
+    :language: xml+jinja
 
-.. autoapifunction:: kraken.lib.xml.parse_page
+ABBYY XML
+^^^^^^^^^
 
-.. autoapifunction:: kraken.lib.xml.parse_alto
+.. literalinclude:: ../../templates/abbyyxml
+    :language: xml+jinja
+
+Containers and Helpers
+======================
 
 kraken.lib.codec module
-=======================
+-----------------------
 
 .. autoapiclass:: kraken.lib.codec.PytorchCodec
    :members:
 
-kraken.lib.train module
-=======================
+kraken.containers module
+------------------------
 
-Training Schedulers
--------------------
+.. autoapiclass:: kraken.containers.Segmentation
+    :members:
 
-.. autoapiclass:: kraken.lib.train.TrainScheduler
-    :members:
+.. autoapiclass:: kraken.containers.BaselineLine
+    :members:
 
-.. autoapiclass:: kraken.lib.train.annealing_step
-    :members:
+.. autoapiclass:: kraken.containers.BBoxLine
+    :members:
 
-.. autoapiclass:: kraken.lib.train.annealing_const
-    :members:
+.. autoapiclass:: kraken.containers.Region
+    :members:
 
-.. autoapiclass:: kraken.lib.train.annealing_exponential
-    :members:
+.. autoapiclass:: kraken.containers.ocr_record
+    :members:
 
-.. autoapiclass:: kraken.lib.train.annealing_reduceonplateau
-    :members:
+.. autoapiclass:: kraken.containers.BaselineOCRRecord
+    :members:
 
-.. autoapiclass:: kraken.lib.train.annealing_cosine
-    :members:
+.. autoapiclass:: kraken.containers.BBoxOCRRecord
+    :members:
 
-.. autoapiclass:: kraken.lib.train.annealing_onecycle
-    :members:
+.. autoapiclass:: kraken.containers.ProcessingStep
+    :members:
 
-Training Stoppers
------------------
+kraken.lib.ctc_decoder
+----------------------
+.. autoapifunction:: kraken.lib.ctc_decoder.beam_decoder
 
-.. autoapiclass:: kraken.lib.train.TrainStopper
-    :members:
+.. autoapifunction:: kraken.lib.ctc_decoder.greedy_decoder
 
-.. autoapiclass:: kraken.lib.train.EarlyStopping
-    :members:
+.. autoapifunction:: kraken.lib.ctc_decoder.blank_threshold_decoder
 
-.. autoapiclass:: kraken.lib.train.EpochStopping
-    :members:
+kraken.lib.exceptions
+---------------------
 
-.. autoapiclass:: kraken.lib.train.NoStopping
+.. autoapiclass:: kraken.lib.exceptions.KrakenCodecException
    :members:
 
-Loss and Evaluation Functions
------------------------------
-
-.. autoapifunction:: kraken.lib.train.recognition_loss_fn
-
-.. autoapifunction:: kraken.lib.train.baseline_label_loss_fn
-
-.. autoapifunction:: kraken.lib.train.recognition_evaluator_fn
-
-.. autoapifunction:: kraken.lib.train.baseline_label_evaluator_fn
-
-Trainer
--------
-
-.. autoapiclass:: kraken.lib.train.KrakenTrainer
+.. autoapiclass:: kraken.lib.exceptions.KrakenStopTrainingException
    :members:
 
+.. autoapiclass:: kraken.lib.exceptions.KrakenEncodeException
+    :members:
 
-kraken.lib.dataset module
-=========================
-
-Datasets
---------
+.. autoapiclass:: kraken.lib.exceptions.KrakenRecordException
+    :members:
 
-.. autoapiclass:: kraken.lib.dataset.BaselineSet
+.. autoapiclass:: kraken.lib.exceptions.KrakenInvalidModelException
    :members:
 
-.. autoapiclass:: kraken.lib.dataset.PolygonGTDataset
+.. autoapiclass:: kraken.lib.exceptions.KrakenInputException
    :members:
 
-.. autoapiclass:: kraken.lib.dataset.GroundTruthDataset
+.. autoapiclass:: kraken.lib.exceptions.KrakenRepoException
    :members:
 
-Helpers
--------
+.. autoapiclass:: kraken.lib.exceptions.KrakenCairoSurfaceException
+    :members:
 
-.. autoapifunction:: kraken.lib.dataset.compute_error
+kraken.lib.models module
+------------------------
 
-.. autoapifunction:: kraken.lib.dataset.preparse_xml_data
+.. autoapiclass:: kraken.lib.models.TorchSeqRecognizer
+    :members:
 
-.. autoapifunction:: kraken.lib.dataset.generate_input_transforms
+.. autoapifunction:: kraken.lib.models.load_any
 
 kraken.lib.segmentation module
 ------------------------------
 
 .. autoapifunction:: kraken.lib.segmentation.reading_order
 
-.. autoapifunction:: kraken.lib.segmentation.polygonal_reading_order
+.. autoapifunction:: kraken.lib.segmentation.neural_reading_order
 
-.. autoapifunction:: kraken.lib.segmentation.denoising_hysteresis_thresh
+.. autoapifunction:: kraken.lib.segmentation.polygonal_reading_order
 
 .. autoapifunction:: kraken.lib.segmentation.vectorize_lines
 
@@ -181,43 +175,82 @@ kraken.lib.segmentation module
 
 .. autoapifunction:: kraken.lib.segmentation.extract_polygons
 
+kraken.lib.vgsl module
+----------------------
 
-kraken.lib.ctc_decoder
-======================
+.. autoapiclass:: kraken.lib.vgsl.TorchVGSLModel
+    :members:
 
-.. autoapifunction:: kraken.lib.ctc_decoder.beam_decoder
+kraken.lib.xml module
+---------------------
 
-.. autoapifunction:: kraken.lib.ctc_decoder.greedy_decoder
+.. autoapiclass:: kraken.lib.xml.XMLPage
 
-.. autoapifunction:: kraken.lib.ctc_decoder.blank_threshold_decoder
+Training
+========
 
-kraken.lib.exceptions
-=====================
+kraken.lib.train module
+-----------------------
 
-.. autoapiclass:: kraken.lib.exceptions.KrakenCodecException
-    :members:
+Loss and Evaluation Functions
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-.. autoapiclass:: kraken.lib.exceptions.KrakenStopTrainingException
+.. autoapifunction:: kraken.lib.train.recognition_loss_fn
+
+.. autoapifunction:: kraken.lib.train.baseline_label_loss_fn
+
+.. autoapifunction:: kraken.lib.train.recognition_evaluator_fn
+
+.. autoapifunction:: kraken.lib.train.baseline_label_evaluator_fn
+
+Trainer
+^^^^^^^
+
+.. autoapiclass:: kraken.lib.train.KrakenTrainer
    :members:
 
-.. autoapiclass:: kraken.lib.exceptions.KrakenEncodeException
+
+kraken.lib.dataset module
+-------------------------
+
+Recognition datasets
+^^^^^^^^^^^^^^^^^^^^
+
+.. autoapiclass:: kraken.lib.dataset.ArrowIPCRecognitionDataset
    :members:
 
-.. autoapiclass:: kraken.lib.exceptions.KrakenRecordException
+.. autoapiclass:: kraken.lib.dataset.PolygonGTDataset
    :members:
 
-.. autoapiclass:: kraken.lib.exceptions.KrakenInvalidModelException
+.. autoapiclass:: kraken.lib.dataset.GroundTruthDataset
    :members:
 
-.. autoapiclass:: kraken.lib.exceptions.KrakenInputException
+Segmentation datasets
+^^^^^^^^^^^^^^^^^^^^^
+
+.. autoapiclass:: kraken.lib.dataset.BaselineSet
    :members:
 
-.. autoapiclass:: kraken.lib.exceptions.KrakenRepoException
+Reading order datasets
+^^^^^^^^^^^^^^^^^^^^^^
+
+.. autoapiclass:: kraken.lib.dataset.PairWiseROSet
    :members:
 
-.. autoapiclass:: kraken.lib.exceptions.KrakenCairoSurfaceException
+.. autoapiclass:: kraken.lib.dataset.PageWiseROSet
    :members:
 
+Helpers
+^^^^^^^
+
+.. autoapiclass:: kraken.lib.dataset.ImageInputTransforms
+    :members:
+
+.. autoapifunction:: kraken.lib.dataset.collate_sequences
+
+.. autoapifunction:: kraken.lib.dataset.global_align
+
+.. autoapifunction:: kraken.lib.dataset.compute_confusions
 
 Legacy modules
 ==============
diff --git a/docs/ketos.rst b/docs/ketos.rst
index c3bd2926a..c96481390 100644
--- a/docs/ketos.rst
+++ b/docs/ketos.rst
@@ -142,7 +142,7 @@ option action
 -F, \--savefreq       Model save frequency in epochs during training
 -q, \--quit           Stop condition for training. Set to `early`
-                      for early stopping (default) or `dumb` for fixed
+                      for early stopping (default) or `fixed` for fixed
                       number of epochs.
 -N, \--epochs         Number of epochs to train for.
 \--min-epochs         Minimum number of epochs to train for when using early stopping.
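The `early`/`fixed` renaming above pairs with the threadpoolctl dependency pinned in the conda recipe and in the environment files below; later hunks wrap training runs with it. A minimal sketch of the intended pattern, assuming only the documented threadpoolctl API (the trainer object is a stand-in):

.. code-block:: python

    from threadpoolctl import threadpool_limits

    def fit_with_thread_cap(trainer, model, threads: int = 1):
        # cap the OpenMP/BLAS thread pools only while training runs;
        # the previous limits are restored when the block exits
        with threadpool_limits(limits=threads):
            trainer.fit(model)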
diff --git a/environment.yml b/environment.yml
index 9344af7a6..05112f876 100644
--- a/environment.yml
+++ b/environment.yml
@@ -24,6 +24,7 @@ dependencies:
  - pyarrow
  - conda-forge::pytorch-lightning~=2.0.0
  - conda-forge::torchmetrics>=0.10.0
+  - conda-forge::threadpoolctl~=3.2
  - pip
  - albumentations
  - rich
diff --git a/environment_cuda.yml b/environment_cuda.yml
index 49c1faa70..8464004b6 100644
--- a/environment_cuda.yml
+++ b/environment_cuda.yml
@@ -25,6 +25,7 @@ dependencies:
  - pyarrow
  - conda-forge::pytorch-lightning~=2.0.0
  - conda-forge::torchmetrics>=0.10.0
+  - conda-forge::threadpoolctl~=3.2
  - pip
  - albumentations
  - rich
diff --git a/kraken/align.py b/kraken/align.py
index 961585221..ff76a1d43 100644
--- a/kraken/align.py
+++ b/kraken/align.py
@@ -23,6 +23,7 @@
 """
 import torch
 import logging
+import dataclasses
 import numpy as np
 
 from PIL import Image
@@ -32,7 +33,9 @@
 from typing import List, Dict, Any, Optional, Literal
 
 from kraken import rpred
+from kraken.containers import Segmentation, BaselineOCRRecord
 from kraken.lib.codec import PytorchCodec
+from kraken.lib.xml import XMLPage
 from kraken.lib.models import TorchSeqRecognizer
 from kraken.lib.exceptions import KrakenInputException, KrakenEncodeException
 from kraken.lib.segmentation import compute_polygon_section
@@ -40,7 +43,7 @@
 logger = logging.getLogger('kraken')
 
 
-def forced_align(doc: Dict[str, Any], model: TorchSeqRecognizer, base_dir: Optional[Literal['L', 'R']] = None) -> List[rpred.ocr_record]:
+def forced_align(doc: Segmentation, model: TorchSeqRecognizer, base_dir: Optional[Literal['L', 'R']] = None) -> Segmentation:
    """
    Performs a forced character alignment of text with recognition model
    output activations.
@@ -50,28 +53,26 @@
        model: Recognition model to use for alignment.
 
    Returns:
-        A list of kraken.rpred.ocr_record.
+        A Segmentation object where the records contain the aligned text.
    """
-    im = Image.open(doc['image'])
+    im = Image.open(doc.imagename)
    predictor = rpred.rpred(model, im, doc)
-    if 'type' in predictor.bounds and predictor.bounds['type'] == 'baselines':
-        rec_class = rpred.BaselineOCRRecord
    records = []
    # enable training mode in last layer to get log_softmax output
    model.nn.nn[-1].training = True
-    for idx, line in enumerate(doc['lines']):
+    for idx, line in enumerate(doc.lines):
        # convert text to display order
-        do_text = get_display(line['text'], base_dir=base_dir)
+        do_text = get_display(line.text, base_dir=base_dir)
        # encode into labels, ignoring unencodable sequences
        labels = model.codec.encode(do_text).long()
        next(predictor)
        if model.outputs.shape[2] < 2*len(labels):
            logger.warning(f'Could not align line {idx}. 
Output sequence length {model.outputs.shape[2]} < ' - f'{2*len(labels)} (length of "{line["text"]}" after encoding).') - records.append(rpred.BaselineOCRRecord('', [], [], line)) + f'{2*len(labels)} (length of "{line.text}" after encoding).') + records.append(BaselineOCRRecord('', [], [], line)) continue emission = torch.tensor(model.outputs).squeeze().T trellis = get_trellis(emission, labels) @@ -85,8 +86,9 @@ def forced_align(doc: Dict[str, Any], model: TorchSeqRecognizer, base_dir: Optio pos.append((predictor._scale_val(seg.start, 0, predictor.box.size[0]), predictor._scale_val(seg.end, 0, predictor.box.size[0]))) conf.append(seg.score) - records.append(rpred.BaselineOCRRecord(pred, pos, conf, line, display_order=True)) - return records + records.append(BaselineOCRRecord(pred, pos, conf, line, display_order=True)) + return dataclasses.replace(doc, lines=records) + """ Copied from the forced alignment with Wav2Vec2 tutorial of pytorch available diff --git a/kraken/blla.py b/kraken/blla.py index 84d7fafca..050a2abb7 100644 --- a/kraken/blla.py +++ b/kraken/blla.py @@ -21,6 +21,7 @@ """ import PIL +import uuid import torch import logging import numpy as np @@ -29,18 +30,21 @@ import torch.nn.functional as F import torchvision.transforms as tf -from typing import Optional, Dict, Callable, Union, List, Any, Tuple +from typing import Optional, Dict, Callable, Union, List, Any, Tuple, Literal from scipy.ndimage import gaussian_filter from skimage.filters import sobel from kraken.lib import vgsl, dataset +from kraken.containers import Region, Segmentation, BaselineLine from kraken.lib.util import is_bitonal, get_im_str from kraken.lib.exceptions import KrakenInputException, KrakenInvalidModelException from kraken.lib.segmentation import (polygonal_reading_order, + neural_reading_order, vectorize_lines, vectorize_regions, scale_polygonal_lines, calculate_polygonal_environment, + is_in_region, scale_regions) __all__ = ['segment'] @@ -73,8 +77,6 @@ def compute_segmentation_map(im: PIL.Image.Image, Raises: KrakenInputException: When given an invalid mask. """ - im_str = get_im_str(im) - logger.info(f'Segmenting {im_str}') if model.input[1] == 1 and model.one_channel_mode == '1' and not is_bitonal(im): logger.warning('Running binary model on non-binary input image ' @@ -134,7 +136,7 @@ def compute_segmentation_map(im: PIL.Image.Image, 'scal_im': scal_im} -def vec_regions(heatmap: torch.Tensor, cls_map: Dict, scale: float, **kwargs) -> Dict[str, List[List[Tuple[int, int]]]]: +def vec_regions(heatmap: torch.Tensor, cls_map: Dict, scale: float, **kwargs) -> Dict[str, List[Region]]: """ Computes regions from a stack of heatmaps, a class mapping, and scaling factor. 
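With forced_align() now consuming and returning a Segmentation, the intended round trip looks roughly like the sketch below (file names are placeholders; XMLPage.to_container() is the parser entry point introduced later in this patch):

.. code-block:: python

    from kraken.align import forced_align
    from kraken.lib.models import load_any
    from kraken.lib.xml import XMLPage

    model = load_any('model.mlmodel')           # hypothetical model file
    doc = XMLPage('page.xml').to_container()    # Segmentation with lines and text
    aligned = forced_align(doc, model)          # same Segmentation, aligned records
    for record in aligned.lines:
        # BaselineOCRRecord: per-code-point cuts and confidences
        print(record.prediction, record.confidences)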
@@ -154,8 +156,8 @@ def vec_regions(heatmap: torch.Tensor, cls_map: Dict, scale: float, **kwargs) -> for region_type, idx in cls_map['regions'].items(): logger.debug(f'Vectorizing regions of type {region_type}') regions[region_type] = vectorize_regions(heatmap[idx]) - for reg_id, regs in regions.items(): - regions[reg_id] = scale_regions(regs, scale) + for reg_type, regs in regions.items(): + regions[reg_type] = [Region(id=str(uuid.uuid4()), boundary=x, tags={'type': reg_type}) for x in scale_regions(regs, scale)] return regions @@ -163,7 +165,6 @@ def vec_lines(heatmap: torch.Tensor, cls_map: Dict[str, Dict[str, int]], scale: float, text_direction: str = 'horizontal-lr', - reading_order_fn: Callable = polygonal_reading_order, regions: List[np.ndarray] = None, scal_im: np.ndarray = None, suppl_obj: List[np.ndarray] = None, @@ -181,7 +182,6 @@ def vec_lines(heatmap: torch.Tensor, scale: Scaling factor between heatmap and unscaled input image. text_direction: Text directions used as hints in the reading order algorithm. - reading_order_fn: Reading order calculation function. regions: Regions to be used as boundaries during polygonization and atomic blocks during reading order determination for lines contained within. @@ -206,6 +206,7 @@ def vec_lines(heatmap: torch.Tensor, ... ] """ + st_sep = cls_map['aux']['_start_separator'] end_sep = cls_map['aux']['_end_separator'] @@ -222,39 +223,34 @@ def vec_lines(heatmap: torch.Tensor, reg_pols = [geom.Polygon(x) for x in regions] for bl_idx in range(len(baselines)): bl = baselines[bl_idx] - mid_point = geom.LineString(bl[1]).interpolate(0.5, normalized=True) - + bl_ls = geom.LineString(bl[1]) suppl_obj = [x[1] for x in baselines[:bl_idx] + baselines[bl_idx+1:]] for reg_idx, reg_pol in enumerate(reg_pols): - if reg_pol.contains(mid_point): + if is_in_region(bl_ls, reg_pol): suppl_obj.append(regions[reg_idx]) - - pol = calculate_polygonal_environment( - baselines=[bl[1]], - im_feats=im_feats, - suppl_obj=suppl_obj, - topline=topline, - raise_on_error=raise_on_error - ) + pol = calculate_polygonal_environment(baselines=[bl[1]], + im_feats=im_feats, + suppl_obj=suppl_obj, + topline=topline, + raise_on_error=raise_on_error) if pol[0] is not None: lines.append((bl[0], bl[1], pol[0])) logger.debug('Scaling vectorized lines') sc = scale_polygonal_lines([x[1:] for x in lines], scale) + lines = list(zip([x[0] for x in lines], [x[0] for x in sc], [x[1] for x in sc])) - logger.debug('Reordering baselines') - lines = reading_order_fn(lines=lines, regions=regions, text_direction=text_direction[-2:]) return [{'tags': {'type': bl_type}, 'baseline': bl, 'boundary': pl} for bl_type, bl, pl in lines] def segment(im: PIL.Image.Image, - text_direction: str = 'horizontal-lr', + text_direction: Literal['horizontal-lr', 'horizontal-rl', 'vertical-lr', 'vertical-rl'] = 'horizontal-lr', mask: Optional[np.ndarray] = None, reading_order_fn: Callable = polygonal_reading_order, model: Union[List[vgsl.TorchVGSLModel], vgsl.TorchVGSLModel] = None, device: str = 'cpu', raise_on_error: bool = False, - autocast: bool = False) -> Dict[str, Any]: + autocast: bool = False) -> Segmentation: r""" Segments a page into text lines using the baseline segmenter. @@ -280,31 +276,37 @@ def segment(im: PIL.Image.Image, autocast: Runs the model with automatic mixed precision Returns: - A dictionary containing the text direction and under the key 'lines' a - list of reading order sorted baselines (polylines) and their respective - polygonal boundaries. 
The last and first point of each boundary polygon
-        are connected.
+        A :class:`kraken.containers.Segmentation` class containing reading order
+        sorted baselines (polylines) and their respective polygonal boundaries
+        as :class:`kraken.containers.BaselineLine` records.
+        The format of the line and region records is shown below. The last and
+        first point of each boundary polygon are connected.
 
        .. code-block::
           :force:
 
-            {'text_direction': '$dir',
-             'type': 'baseline',
-             'lines': [
-                {'baseline': [[x0, y0], [x1, y1], ..., [x_n, y_n]], 'boundary': [[x0, y0, x1, y1], ... [x_m, y_m]]},
-                {'baseline': [[x0, ...]], 'boundary': [[x0, ...]]}
-              ]
-             'regions': [
-                {'region': [[x0, y0], [x1, y1], ..., [x_n, y_n]], 'type': 'image'},
-                {'region': [[x0, ...]], 'type': 'text'}
-             ]
-            }
+            'lines': [
+                {'baseline': [[x0, y0], [x1, y1], ..., [x_n, y_n]], 'boundary': [[x0, y0, x1, y1], ... [x_m, y_m]]},
+                {'baseline': [[x0, ...]], 'boundary': [[x0, ...]]}
+            ]
+            'regions': [
+                {'region': [[x0, y0], [x1, y1], ..., [x_n, y_n]], 'type': 'image'},
+                {'region': [[x0, ...]], 'type': 'text'}
+            ]
 
    Raises:
        KrakenInvalidModelException: if the given model is not a valid
                                     segmentation model.
        KrakenInputException: if the mask is not bitonal or does not match the
                              image size.
+
+    Notes:
+        Multi-model operation is most useful for combining one or more region
+        detection models and one text line model. Detected lines from all
+        models are simply combined without any merging or duplicate detection
+        so the chance of the same line appearing multiple times in the output
+        is high. In addition, neural reading order determination is disabled
+        when more than one model outputs lines.
    """
    if model is None:
        logger.info('No segmentation model given. Loading default model.')
@@ -322,6 +324,12 @@
    im_str = get_im_str(im)
    logger.info(f'Segmenting {im_str}')
 
+    lines = []
+    order = None
+    regions = {}
+    multi_lines = False
+    # flag to indicate that multiple models produced line output -> disable
+    # neural reading order
    for net in model:
        if 'topline' in net.user_metadata:
            loc = {None: 'center',
@@ -329,32 +337,82 @@
                   True: 'top',
                   False: 'bottom'}[net.user_metadata['topline']]
            logger.debug(f'Baseline location: {loc}')
        rets = compute_segmentation_map(im, mask, net, device, autocast=autocast)
-        regions = vec_regions(**rets)
+        _regions = vec_regions(**rets)
+        for reg_key, reg_val in _regions.items():
+            if reg_key not in regions:
+                regions[reg_key] = []
+            regions[reg_key].extend(reg_val)
+
        # flatten regions for line ordering/fetch bounding regions
        line_regs = []
        suppl_obj = []
-        for cls, regs in regions.items():
+        for cls, regs in _regions.items():
            line_regs.extend(regs)
            if rets['bounding_regions'] is not None and cls in rets['bounding_regions']:
                suppl_obj.extend(regs)
        # convert back to net scale
-        suppl_obj = scale_regions(suppl_obj, 1/rets['scale'])
-        line_regs = scale_regions(line_regs, 1/rets['scale'])
-        lines = vec_lines(**rets,
-                          regions=line_regs,
-                          reading_order_fn=reading_order_fn,
-                          text_direction=text_direction,
-                          suppl_obj=suppl_obj,
-                          topline=net.user_metadata['topline'] if 'topline' in net.user_metadata else False,
-                          raise_on_error=raise_on_error)
+        suppl_obj = scale_regions([x.boundary for x in suppl_obj], 1/rets['scale'])
+        line_regs = scale_regions([x.boundary for x in line_regs], 1/rets['scale'])
+
+        _lines = vec_lines(**rets,
+                           regions=line_regs,
+                           text_direction=text_direction,
+                           suppl_obj=suppl_obj,
+                           topline=net.user_metadata['topline'] if 'topline' in net.user_metadata else False,
+                           raise_on_error=raise_on_error)
+
+        if 'ro_model' in net.aux_layers:
+            logger.info(f'Using reading order model found in segmentation model {net}.')
+            _order = neural_reading_order(lines=_lines,
+                                          regions=_regions,
+                                          text_direction=text_direction[-2:],
+                                          model=net.aux_layers['ro_model'],
+                                          im_size=im.size,
+                                          class_mapping=net.user_metadata['ro_class_mapping'])
+        else:
+            _order = None
+
+        if _lines and lines or multi_lines:
+            multi_lines = True
+            order = None
+            logger.warning('Multiple models produced line output. This is '
+                           'likely unintended. Suppressing neural reading '
+                           'order.')
+        else:
+            order = _order
+
+        lines.extend(_lines)
 
    if len(rets['cls_map']['baselines']) > 1:
        script_detection = True
    else:
        script_detection = False
 
-    return {'text_direction': text_direction,
-            'type': 'baselines',
-            'lines': lines,
-            'regions': regions,
-            'script_detection': script_detection}
+    # create objects and assign IDs
+    blls = []
+    _shp_regs = {}
+    for reg_type, rgs in regions.items():
+        for reg in rgs:
+            _shp_regs[reg.id] = geom.Polygon(reg.boundary)
+
+    # reorder lines
+    logger.debug(f'Reordering baselines with main RO function {reading_order_fn}.')
+    basic_lo = reading_order_fn(lines=lines, regions=_shp_regs.values(), text_direction=text_direction[-2:])
+    lines = [lines[idx] for idx in basic_lo]
+
+    for line in lines:
+        line_regs = []
+        for reg_id, reg in _shp_regs.items():
+            line_ls = geom.LineString(line['baseline'])
+            if is_in_region(line_ls, reg):
+                line_regs.append(reg_id)
+        blls.append(BaselineLine(id=str(uuid.uuid4()), baseline=line['baseline'], boundary=line['boundary'], tags=line['tags'], regions=line_regs))
+
+    return Segmentation(text_direction=text_direction,
+                        imagename=getattr(im, 'filename', None),
+                        type='baselines',
+                        lines=blls,
+                        regions=regions,
+                        script_detection=script_detection,
+                        line_orders=[order] if order is not None else [])
diff --git a/kraken/containers.py b/kraken/containers.py
new file mode 100644
index 000000000..c58ef9c10
--- /dev/null
+++ b/kraken/containers.py
@@ -0,0 +1,588 @@
+
+import PIL.Image
+import numpy as np
+import bidi.algorithm as bd
+
+from os import PathLike
+from typing import Literal, List, Dict, Union, Optional, Tuple
+from dataclasses import dataclass, asdict
+from abc import ABC, abstractmethod
+
+from kraken.lib.segmentation import compute_polygon_section
+
+__all__ = ['BaselineLine',
+           'BBoxLine',
+           'Region',
+           'Segmentation',
+           'ocr_record',
+           'BaselineOCRRecord',
+           'BBoxOCRRecord',
+           'ProcessingStep']
+
+
+@dataclass
+class ProcessingStep:
+    """
+    A processing step in the recognition pipeline.
+
+    Attributes:
+        id: Unique identifier
+        category: Category of processing step that has been performed.
+        description: Natural-language description of the process.
+        settings: Dict describing the parameters of the processing step.
+    """
+    id: str
+    category: Literal['preprocessing', 'processing', 'postprocessing']
+    description: str
+    settings: Dict[str, Union[Dict, str, float, int, bool]]
+
+
+@dataclass
+class BaselineLine:
+    """
+    Baseline-type line record.
+
+    A container class for a single line in baseline + bounding polygon format,
+    optionally containing a transcription, tags, or associated regions.
+
+    Attributes:
+        id: Unique identifier
+        baseline: List of tuples `(x_n, y_n)` defining the baseline.
+        boundary: List of tuples `(x_n, y_n)` defining the bounding polygon of
+            the line. The first and last points should be identical.
+        text: Transcription of this line. 
+        base_dir: An optional string defining the base direction (also called
+            paragraph direction) for the BiDi algorithm. Valid values are
+            'L' or 'R'. If None is given the default auto-resolution will
+            be used.
+        imagename: Path to the image associated with the line.
+        tags: A dict mapping types to values.
+        split: Defines whether this line is in the `train`, `validation`, or
+            `test` set during training.
+        regions: A list of identifiers of regions the line is associated with.
+    """
+    id: str
+    baseline: List[Tuple[int, int]]
+    boundary: List[Tuple[int, int]]
+    text: Optional[str] = None
+    base_dir: Optional[Literal['L', 'R']] = None
+    type: str = 'baselines'
+    imagename: Optional[Union[str, PathLike]] = None
+    tags: Optional[Dict[str, str]] = None
+    split: Optional[Literal['train', 'validation', 'test']] = None
+    regions: Optional[List[str]] = None
+
+
+@dataclass
+class BBoxLine:
+    """
+    Bounding box-type line record.
+
+    A container class for a single line in axis-aligned bounding box format,
+    optionally containing a transcription, tags, or associated regions.
+
+    Attributes:
+        id: Unique identifier
+        bbox: Tuple in form `((x0, y0), (x1, y0), (x1, y1), (x0, y1))` defining
+            the bounding box.
+        text: Transcription of this line.
+        base_dir: An optional string defining the base direction (also called
+            paragraph direction) for the BiDi algorithm. Valid values are
+            'L' or 'R'. If None is given the default auto-resolution will
+            be used.
+        imagename: Path to the image associated with the line.
+        tags: A dict mapping types to values.
+        split: Defines whether this line is in the `train`, `validation`, or
+            `test` set during training.
+        regions: A list of identifiers of regions the line is associated with.
+        text_direction: Sets the principal orientation (of the line) and
+            reading direction (of the document).
+    """
+    id: str
+    bbox: Tuple[Tuple[int, int],
+                Tuple[int, int],
+                Tuple[int, int],
+                Tuple[int, int]]
+    text: Optional[str] = None
+    base_dir: Optional[Literal['L', 'R']] = None
+    type: str = 'bbox'
+    imagename: Optional[Union[str, PathLike]] = None
+    tags: Optional[Dict[str, str]] = None
+    split: Optional[Literal['train', 'validation', 'test']] = None
+    regions: Optional[List[str]] = None
+    text_direction: Literal['horizontal-lr', 'horizontal-rl', 'vertical-lr', 'vertical-rl'] = 'horizontal-lr'
+
+
+@dataclass
+class Region:
+    """
+    Container class of a single polygonal region.
+
+    Attributes:
+        id: Unique identifier
+        boundary: List of tuples `(x_n, y_n)` defining the bounding polygon of
+            the region. The first and last points should be identical.
+        imagename: Path to the image associated with the region.
+        tags: A dict mapping types to values.
+    """
+    id: str
+    boundary: List[Tuple[int, int]]
+    imagename: Optional[Union[str, PathLike]] = None
+    tags: Optional[Dict[str, str]] = None
+
+
+@dataclass
+class Segmentation:
+    """
+    A container class for segmentation or recognition results.
+
+    In order to allow easy JSON de-/serialization, nested classes for lines
+    (BaselineLine/BBoxLine) and regions (Region) are reinstantiated from their
+    dictionaries.
+
+    Attributes:
+        type: Field indicating if baselines
+            (:class:`kraken.containers.BaselineLine`) or bbox
+            (:class:`kraken.containers.BBoxLine`) line records are in the
+            segmentation.
+        imagename: Path to the image associated with the segmentation.
+        text_direction: Sets the principal orientation (of the line), i.e.
+            horizontal/vertical, and reading direction (of the
+            document), i.e. lr/rl. 
+ script_detection: Flag indicating if the line records have tags. + lines: List of line records. Records are expected to be in a valid + reading order. + regions: Dict mapping types to lists of regions. + line_orders: List of alternative reading orders for the segmentation. + Each reading order is a list of line indices. + """ + type: Literal['baselines', 'bbox'] + imagename: Union[str, PathLike] + text_direction: Literal['horizontal-lr', 'horizontal-rl', 'vertical-lr', 'vertical-rl'] + script_detection: bool + lines: List[Union[BaselineLine, BBoxLine]] + regions: Dict[str, List[Region]] + line_orders: Optional[List[List[int]]] = None + + def __post_init__(self): + if len(self.lines) and not isinstance(self.lines[0], BBoxLine) and not isinstance(self.lines[0], BaselineLine): + line_cls = BBoxLine if self.type == 'bbox' else BaselineLine + self.lines = [line_cls(**line) for line in self.lines] + if len(self.regions) and not isinstance(next(iter(self.regions.values()))[0], Region): + regs = {} + for k, v in self.regions.items(): + regs[k] = [Region(**reg) for reg in v] + self.regions = regs + + +class ocr_record(ABC): + """ + A record object containing the recognition result of a single line + """ + base_dir = None + + def __init__(self, + prediction: str, + cuts: List[Union[Tuple[int, int], Tuple[Tuple[int, int], + Tuple[int, int], + Tuple[int, int], + Tuple[int, int]]]], + confidences: List[float], + display_order: bool = True) -> None: + self._prediction = prediction + self._cuts = cuts + self._confidences = confidences + self._display_order = display_order + + @property + @abstractmethod + def type(self): + pass + + def __len__(self) -> int: + return len(self._prediction) + + def __str__(self) -> str: + return self._prediction + + @property + def prediction(self) -> str: + return self._prediction + + @property + def cuts(self) -> List: + return self._cuts + + @property + def confidences(self) -> List[float]: + return self._confidences + + def __iter__(self): + self.idx = -1 + return self + + @abstractmethod + def __next__(self) -> Tuple[str, + Union[List[Tuple[int, int]], + Tuple[Tuple[int, int], + Tuple[int, int], + Tuple[int, int], + Tuple[int, int]]], + float]: + pass + + @abstractmethod + def __getitem__(self, key: Union[int, slice]): + pass + + @abstractmethod + def display_order(self, base_dir) -> 'ocr_record': + pass + + @abstractmethod + def logical_order(self, base_dir) -> 'ocr_record': + pass + + +class BaselineOCRRecord(ocr_record, BaselineLine): + """ + A record object containing the recognition result of a single line in + baseline format. + + Attributes: + type: 'baselines' to indicate a baseline record + prediction: The text predicted by the network as one continuous string. + cuts: The absolute bounding polygons for each code point in prediction + as a list of tuples [(x0, y0), (x1, y2), ...]. + confidences: A list of floats indicating the confidence value of each + code point. + base_dir: An optional string defining the base direction (also called + paragraph direction) for the BiDi algorithm. Valid values are + 'L' or 'R'. If None is given the default auto-resolution will + be used. + display_order: Flag indicating the order of the code points in the + prediction. In display order (`True`) the n-th code + point in the string corresponds to the n-th leftmost + code point, in logical order (`False`) the n-th code + point corresponds to the n-th read code point. See [UAX + #9](https://unicode.org/reports/tr9) for more details. 
+
+    Notes:
+        When slicing the record the behavior of the cuts is changed from
+        earlier versions of kraken. Instead of returning per-character bounding
+        polygons a single polygon section of the line bounding polygon
+        starting at the first and extending to the last code point emitted by
+        the network is returned. This aids numerical stability when computing
+        aggregated bounding polygons such as for words. Individual code point
+        bounding polygons are still accessible through the `cuts` attribute or
+        by iterating over the record code point by code point.
+    """
+    type = 'baselines'
+
+    def __init__(self, prediction: str,
+                 cuts: List[Tuple[int, int]],
+                 confidences: List[float],
+                 line: BaselineLine,
+                 base_dir: Optional[Literal['L', 'R']] = None,
+                 display_order: bool = True) -> None:
+        if line.type != 'baselines':
+            raise TypeError('Invalid argument type (non-baseline line)')
+        BaselineLine.__init__(self, **asdict(line))
+        self._line_base_dir = self.base_dir
+        self.base_dir = base_dir
+        ocr_record.__init__(self, prediction, cuts, confidences, display_order)
+
+    def __repr__(self) -> str:
+        return f'pred: {self.prediction} baseline: {self.baseline} boundary: {self.boundary} confidences: {self.confidences}'
+
+    def __next__(self) -> Tuple[str, int, float]:
+        if self.idx + 1 < len(self):
+            self.idx += 1
+            return (self.prediction[self.idx],
+                    compute_polygon_section(self.baseline,
+                                            self.boundary,
+                                            self.cuts[self.idx][0],
+                                            self.cuts[self.idx][1]),
+                    self.confidences[self.idx])
+        else:
+            raise StopIteration
+
+    def _get_raw_item(self, key: int):
+        if key < 0:
+            key += len(self)
+        if key >= len(self):
+            raise IndexError('Index (%d) is out of range' % key)
+        return (self.prediction[key],
+                self._cuts[key],
+                self.confidences[key])
+
+    def __getitem__(self, key: Union[int, slice]):
+        if isinstance(key, slice):
+            recs = [self._get_raw_item(i) for i in range(*key.indices(len(self)))]
+            prediction = ''.join([x[0] for x in recs])
+            flat_offsets = sum((tuple(x[1]) for x in recs), ())
+            cut = compute_polygon_section(self.baseline,
+                                          self.boundary,
+                                          min(flat_offsets),
+                                          max(flat_offsets))
+            confidence = np.mean([x[2] for x in recs])
+            return (prediction, cut, confidence)
+        elif isinstance(key, int):
+            pred, cut, confidence = self._get_raw_item(key)
+            return (pred,
+                    compute_polygon_section(self.baseline, self.boundary, cut[0], cut[1]),
+                    confidence)
+        else:
+            raise TypeError('Invalid argument type')
+
+    @property
+    def cuts(self) -> List[Tuple[int, int]]:
+        return tuple([compute_polygon_section(self.baseline, self.boundary, cut[0], cut[1]) for cut in self._cuts])
+
+    def logical_order(self, base_dir: Optional[Literal['L', 'R']] = None) -> 'BaselineOCRRecord':
+        """
+        Returns the OCR record in Unicode logical order, i.e. in the order the
+        characters in the line would be read by a human.
+
+        Args:
+            base_dir: An optional string defining the base direction (also
+                      called paragraph direction) for the BiDi algorithm. Valid
+                      values are 'L' or 'R'. If None is given the default
+                      auto-resolution will be used.
+        """
+        if self._display_order:
+            return self._reorder(base_dir)
+        else:
+            return self
+
+    def display_order(self, base_dir: Optional[Literal['L', 'R']] = None) -> 'BaselineOCRRecord':
+        """
+        Returns the OCR record in Unicode display order, i.e. ordered from left
+        to right inside the line.
+
+        Args:
+            base_dir: An optional string defining the base direction (also
+                      called paragraph direction) for the BiDi algorithm. Valid
+                      values are 'L' or 'R'. If None is given the default
+                      auto-resolution will be used.
+        """
+        if self._display_order:
+            return self
+        else:
+            return self._reorder(base_dir)
+
+    def _reorder(self, base_dir: Optional[Literal['L', 'R']] = None) -> 'BaselineOCRRecord':
+        """
+        Reorder the record using the BiDi algorithm.
+        """
+        storage = bd.get_empty_storage()
+
+        if base_dir not in ('L', 'R'):
+            base_level = bd.get_base_level(self._prediction)
+        else:
+            base_level = {'L': 0, 'R': 1}[base_dir]
+
+        storage['base_level'] = base_level
+        storage['base_dir'] = ('L', 'R')[base_level]
+
+        bd.get_embedding_levels(self._prediction, storage)
+        bd.explicit_embed_and_overrides(storage)
+        bd.resolve_weak_types(storage)
+        bd.resolve_neutral_types(storage, False)
+        bd.resolve_implicit_levels(storage, False)
+        for i, j in enumerate(zip(self._prediction, self._cuts, self._confidences)):
+            storage['chars'][i]['record'] = j
+        bd.reorder_resolved_levels(storage, False)
+        bd.apply_mirroring(storage, False)
+        prediction = ''
+        cuts = []
+        confidences = []
+        for ch in storage['chars']:
+            # code point may have been mirrored
+            prediction = prediction + ch['ch']
+            cuts.append(ch['record'][1])
+            confidences.append(ch['record'][2])
+        line = BaselineLine(id=self.id,
+                            baseline=self.baseline,
+                            boundary=self.boundary,
+                            text=self.text,
+                            base_dir=self._line_base_dir,
+                            imagename=self.imagename,
+                            tags=self.tags,
+                            split=self.split,
+                            regions=self.regions)
+        rec = BaselineOCRRecord(prediction=prediction,
+                                cuts=cuts,
+                                confidences=confidences,
+                                line=line,
+                                base_dir=base_dir,
+                                display_order=not self._display_order)
+        return rec
+
+
+class BBoxOCRRecord(ocr_record, BBoxLine):
+    """
+    A record object containing the recognition result of a single line in
+    bbox format.
+
+    Attributes:
+        type: 'bbox' to indicate a bounding box record
+        prediction: The text predicted by the network as one continuous string.
+        cuts: The absolute bounding polygons for each code point in prediction
+            as a list of 4-tuples `((x0, y0), (x1, y0), (x1, y1), (x0, y1))`.
+        confidences: A list of floats indicating the confidence value of each
+            code point.
+        base_dir: An optional string defining the base direction (also called
+            paragraph direction) for the BiDi algorithm. Valid values are
+            'L' or 'R'. If None is given the default auto-resolution will
+            be used.
+        display_order: Flag indicating the order of the code points in the
+            prediction. In display order (`True`) the n-th code
+            point in the string corresponds to the n-th leftmost
+            code point, in logical order (`False`) the n-th code
+            point corresponds to the n-th read code point. See [UAX
+            #9](https://unicode.org/reports/tr9) for more details.
+
+    Notes:
+        When slicing the record the behavior of the cuts is changed from
+        earlier versions of kraken. Instead of returning per-character bounding
+        polygons a single polygon section of the line bounding polygon
+        starting at the first and extending to the last code point emitted by
+        the network is returned. This aids numerical stability when computing
+        aggregated bounding polygons such as for words. Individual code point
+        bounding polygons are still accessible through the `cuts` attribute or
+        by iterating over the record code point by code point.
+ """ + type = 'bbox' + + def __init__(self, + prediction: str, + cuts: List[Tuple[Tuple[int, int], + Tuple[int, int], + Tuple[int, int], + Tuple[int, int]]], + confidences: List[float], + line: BBoxLine, + base_dir: Optional[Literal['L', 'R']], + display_order: bool = True) -> None: + if line.type != 'bbox': + raise TypeError('Invalid argument type (non-bbox line)') + BBoxLine.__init__(self, **asdict(line)) + self._line_base_dir = self.base_dir + self.base_dir = base_dir + ocr_record.__init__(self, prediction, cuts, confidences, display_order) + + def __repr__(self) -> str: + return f'pred: {self.prediction} line: {self.line} confidences: {self.confidences}' + + def __next__(self) -> Tuple[str, int, float]: + if self.idx + 1 < len(self): + self.idx += 1 + return (self.prediction[self.idx], + self.cuts[self.idx], + self.confidences[self.idx]) + else: + raise StopIteration + + def _get_raw_item(self, key: int): + if key < 0: + key += len(self) + if key >= len(self): + raise IndexError('Index (%d) is out of range' % key) + return (self.prediction[key], + self.cuts[key], + self.confidences[key]) + + def __getitem__(self, key: Union[int, slice]): + if isinstance(key, slice): + recs = [self._get_raw_item(i) for i in range(*key.indices(len(self)))] + prediction = ''.join([x[0] for x in recs]) + box = [x[1] for x in recs] + flat_box = [point for pol in box for point in pol] + flat_box = [x for point in flat_box for x in point] + min_x, max_x = min(flat_box[::2]), max(flat_box[::2]) + min_y, max_y = min(flat_box[1::2]), max(flat_box[1::2]) + cut = ((min_x, min_y), (max_x, min_y), (max_x, max_y), (min_x, max_y)) + confidence = np.mean([x[2] for x in recs]) + return (prediction, cut, confidence) + elif isinstance(key, int): + return self._get_raw_item(key) + else: + raise TypeError('Invalid argument type') + + def logical_order(self, base_dir: Optional[Literal['L', 'R']] = None) -> 'BBoxOCRRecord': + """ + Returns the OCR record in Unicode logical order, i.e. in the order the + characters in the line would be read by a human. + + Args: + base_dir: An optional string defining the base direction (also + called paragraph direction) for the BiDi algorithm. Valid + values are 'L' or 'R'. If None is given the default + auto-resolution will be used. + """ + if self._display_order: + return self._reorder(base_dir) + else: + return self + + def display_order(self, base_dir: Optional[Literal['L', 'R']] = None) -> 'BBoxOCRRecord': + """ + Returns the OCR record in Unicode display order, i.e. ordered from left + to right inside the line. + + Args: + base_dir: An optional string defining the base direction (also + called paragraph direction) for the BiDi algorithm. Valid + values are 'L' or 'R'. If None is given the default + auto-resolution will be used. 
+ """ + if self._display_order: + return self + else: + return self._reorder(base_dir) + + def _reorder(self, base_dir: Optional[Literal['L', 'R']] = None) -> 'BBoxOCRRecord': + storage = bd.get_empty_storage() + + if base_dir not in ('L', 'R'): + base_level = bd.get_base_level(self.prediction) + else: + base_level = {'L': 0, 'R': 1}[base_dir] + + storage['base_level'] = base_level + storage['base_dir'] = ('L', 'R')[base_level] + + bd.get_embedding_levels(self.prediction, storage) + bd.explicit_embed_and_overrides(storage) + bd.resolve_weak_types(storage) + bd.resolve_neutral_types(storage, False) + bd.resolve_implicit_levels(storage, False) + for i, j in enumerate(zip(self.prediction, self.cuts, self.confidences)): + storage['chars'][i]['record'] = j + bd.reorder_resolved_levels(storage, False) + bd.apply_mirroring(storage, False) + prediction = '' + cuts = [] + confidences = [] + for ch in storage['chars']: + # code point may have been mirrored + prediction = prediction + ch['ch'] + cuts.append(ch['record'][1]) + confidences.append(ch['record'][2]) + # carry over whole line information + line = BBoxLine(id=self.id, + bbox=self.bbox, + text=self.text, + base_dir=self._line_base_dir, + image=self.image, + tags=self.tags, + split=self.split, + regions=self.regions) + rec = BBoxOCRRecord(prediction=prediction, + cuts=cuts, + confidences=confidences, + line=line, + base_dir=base_dir, + display_order=not self._display_order) + return rec + + diff --git a/kraken/contrib/extract_lines.py b/kraken/contrib/extract_lines.py index b00db3a43..e2a41b8ba 100755 --- a/kraken/contrib/extract_lines.py +++ b/kraken/contrib/extract_lines.py @@ -1,8 +1,6 @@ #! /usr/bin/env python - import click - @click.command() @click.option('-f', '--format-type', type=click.Choice(['xml', 'alto', 'page', 'binary']), default='xml', help='Sets the input document format. In ALTO and PageXML mode all ' @@ -10,17 +8,8 @@ 'link to source images.') @click.option('-i', '--model', default=None, show_default=True, type=click.Path(exists=True), help='Baseline detection model to use. Overrides format type and expects image files as input.') -@click.option('--repolygonize/--no-repolygonize', show_default=True, - default=False, help='Repolygonizes line data in ALTO/PageXML ' - 'files. This ensures that the trained model is compatible with the ' - 'segmenter in kraken even if the original image files either do ' - 'not contain anything but transcriptions and baseline information ' - 'or the polygon data was created using a different method. Will ' - 'be ignored in `path` mode. Note, that this option will be slow ' - 'and will not scale input images to the same size as the segmenter ' - 'does.') @click.argument('files', nargs=-1) -def cli(format_type, model, repolygonize, files): +def cli(format_type, model, files): """ A small script extracting rectified line polygons as defined in either ALTO or PageXML files or run a model to do the same. 
@@ -42,14 +31,14 @@
    for doc in files:
        click.echo(f'Processing {doc} ', nl=False)
        if format_type != 'binary':
-            data = xml.preparse_xml_data([doc], format_type, repolygonize=repolygonize)
-            if len(data) > 0:
-                bounds = {'type': 'baselines', 'lines': [{'boundary': t['boundary'], 'baseline': t['baseline'], 'text': t['text']} for t in data]}
-                for idx, (im, box) in enumerate(segmentation.extract_polygons(Image.open(data[0]['image']), bounds)):
+            data = xml.XMLPage(doc, format_type)
+            if len(data.lines) > 0:
+                bounds = data.to_container()
+                for idx, (im, box) in enumerate(segmentation.extract_polygons(Image.open(bounds.imagename), bounds)):
                    click.echo('.', nl=False)
-                    im.save('{}.{}.jpg'.format(splitext(data[0]['image'])[0], idx))
-                    with open('{}.{}.gt.txt'.format(splitext(data[0]['image'])[0], idx), 'w') as fp:
-                        fp.write(box['text'])
+                    im.save('{}.{}.jpg'.format(splitext(bounds.imagename)[0], idx))
+                    with open('{}.{}.gt.txt'.format(splitext(bounds.imagename)[0], idx), 'w') as fp:
+                        fp.write(box.text)
        else:
            with pa.memory_map(doc, 'rb') as source:
                ds_table = pa.ipc.open_file(source).read_all()
diff --git a/kraken/contrib/forced_alignment_overlay.py b/kraken/contrib/forced_alignment_overlay.py
index 5e9f19eff..1dd700c1b 100755
--- a/kraken/contrib/forced_alignment_overlay.py
+++ b/kraken/contrib/forced_alignment_overlay.py
@@ -35,7 +35,7 @@ def _repl_alto(fname, cuts):
        doc = etree.parse(fp)
        lines = doc.findall('.//{*}TextLine')
        char_idx = 0
-        for line, line_cuts in zip(lines, cuts):
+        for line, line_cuts in zip(lines, cuts.lines):
            idx = 0
            for el in line:
                if el.tag.endswith('Shape'):
@@ -65,7 +65,7 @@ def _repl_page(fname, cuts):
    with open(fname, 'rb') as fp:
        doc = etree.parse(fp)
        lines = doc.findall('.//{*}TextLine')
-        for line, line_cuts in zip(lines, cuts):
+        for line, line_cuts in zip(lines, cuts.lines):
            glyphs = line.findall('../{*}Glyph/{*}Coords')
            for glyph, cut in zip(glyphs, line_cuts):
                glyph.attrib['points'] = ' '.join([','.join([str(x) for x in pt]) for pt in cut])
@@ -96,34 +96,33 @@ def cli(format_type, model, output, files):
 
    from PIL import Image, ImageDraw
 
-    from kraken.lib import models, xml
+    from kraken.lib.xml import XMLPage
+    from kraken.lib import models
    from kraken import align
 
    if format_type == 'alto':
-        fn = xml.parse_alto
        repl_fn = _repl_alto
    else:
-        fn = xml.parse_page
        repl_fn = _repl_page
 
    click.echo(f'Loading model {model}')
    net = models.load_any(model)
 
    for doc in files:
        click.echo(f'Processing {doc} ', nl=False)
-        data = fn(doc)
-        im = Image.open(data['image']).convert('RGBA')
-        records = align.forced_align(data, net)
+        data = XMLPage(doc)
+        im = Image.open(data.imagename).convert('RGBA')
+        result = align.forced_align(data.to_container(), net)
        if output == 'overlay':
            tmp = Image.new('RGBA', im.size, (0, 0, 0, 0))
            draw = ImageDraw.Draw(tmp)
-            for record in records:
+            for record in result.lines:
                for pol in record.cuts:
                    c = next(cmap)
                    draw.polygon([tuple(x) for x in pol], fill=c, outline=c[:3])
            base_image = Image.alpha_composite(im, tmp)
            base_image.save(f'high_{os.path.basename(doc)}_algn.png')
        else:
-            repl_fn(doc, records)
+            repl_fn(doc, result)
        click.secho('\u2713', fg='green')
diff --git a/kraken/contrib/heatmap_overlay.py b/kraken/contrib/heatmap_overlay.py
index 8e2c7236b..0b1c22197 100755
--- a/kraken/contrib/heatmap_overlay.py
+++ b/kraken/contrib/heatmap_overlay.py
@@ -4,7 +4,6 @@
 """
 import click
-
 @click.command()
 @click.option('-i', '--model', default=None, show_default=True, type=click.Path(exists=True), 
help='Baseline detection model to use.') diff --git a/kraken/contrib/repolygonize.py b/kraken/contrib/repolygonize.py index db40b8987..d3f498c57 100755 --- a/kraken/contrib/repolygonize.py +++ b/kraken/contrib/repolygonize.py @@ -85,10 +85,8 @@ def _repl_page(fname, polygons): doc.write(fp, encoding='UTF-8', xml_declaration=True) if format_type == 'page': - parse_fn = xml.parse_page repl_fn = _repl_page else: - parse_fn = xml.parse_alto repl_fn = _repl_alto topline = {'topline': True, @@ -97,11 +95,11 @@ def _repl_page(fname, polygons): for doc in files: click.echo(f'Processing {doc} ') - seg = parse_fn(doc) - im = Image.open(seg['image']).convert('L') + seg = xml.XMLPage(doc).to_container() + im = Image.open(seg.imagename).convert('L') baselines = [] - for x in seg['lines']: - bl = x['baseline'] if x['baseline'] is not None else [0, 0] + for x in seg.lines: + bl = x.baseline if x.baseline is not None else [0, 0] baselines.append(bl) o = calculate_polygonal_environment(im, baselines, scale=(1800, 0), topline=topline) repl_fn(doc, o) diff --git a/kraken/contrib/segmentation_overlay.py b/kraken/contrib/segmentation_overlay.py index 229862fae..a9ff235ff 100755 --- a/kraken/contrib/segmentation_overlay.py +++ b/kraken/contrib/segmentation_overlay.py @@ -7,6 +7,7 @@ import os import click import unicodedata +import dataclasses from itertools import cycle from collections import defaultdict @@ -27,10 +28,6 @@ def slugify(value): return value @click.command() -@click.option('-f', '--format-type', type=click.Choice(['xml', 'alto', 'page']), default='xml', - help='Sets the input document format. In ALTO and PageXML mode all ' - 'data is extracted from xml files containing both baselines, polygons, and a ' - 'link to source images.') @click.option('-i', '--model', default=None, show_default=True, type=click.Path(exists=True), help='Baseline detection model to use. Overrides format type and expects image files as input.') @click.option('-d', '--text-direction', default='horizontal-lr', @@ -48,7 +45,7 @@ def slugify(value): 'and will not scale input images to the same size as the segmenter ' 'does.') @click.argument('files', nargs=-1) -def cli(format_type, model, text_direction, repolygonize, files): +def cli(model, text_direction, repolygonize, files): """ A script producing overlays of lines and regions from either ALTO or PageXML files or run a model to do the same. 
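The XMLPage-to-container pattern shared by these contrib scripts reduces to a few lines; a sketch mirroring the repolygonize hunk above (the input path is a placeholder):

.. code-block:: python

    from PIL import Image
    from kraken.lib import xml
    from kraken.lib.segmentation import calculate_polygonal_environment

    seg = xml.XMLPage('page.xml').to_container()   # parse ALTO/PageXML
    im = Image.open(seg.imagename).convert('L')
    baselines = [line.baseline for line in seg.lines]
    # recompute the bounding polygons from the baselines alone
    polygons = calculate_polygonal_environment(im, baselines, scale=(1800, 0))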
@@ -64,47 +61,38 @@ def cli(format_type, model, text_direction, repolygonize, files): from kraken import blla if model is None: - if format_type == 'xml': - fn = xml.parse_xml - elif format_type == 'alto': - fn = xml.parse_alto - else: - fn = xml.parse_page for doc in files: click.echo(f'Processing {doc} ', nl=False) - data = fn(doc) + data = xml.XMLPage(doc) if repolygonize: - im = Image.open(data['image']).convert('L') - lines = data['lines'] - polygons = segmentation.calculate_polygonal_environment(im, [x['baseline'] for x in lines], scale=(1200, 0)) - data['lines'] = [{'boundary': polygon, - 'baseline': orig['baseline'], - 'text': orig['text'], - 'tags': orig['tags']} for orig, polygon in zip(lines, polygons)] + im = Image.open(data.imagename).convert('L') + lines = data.lines + polygons = segmentation.calculate_polygonal_environment(im, [x.baseline for x in lines], scale=(1200, 0)) + data.lines = [dataclasses.replace(orig, boundary=polygon) for orig, polygon in zip(lines, polygons)] # reorder lines by type lines = defaultdict(list) - for line in data['lines']: - lines[line['tags']['type']].append(line) - im = Image.open(data['image']).convert('RGBA') + for line in data.lines: + lines[line.tags['type']].append(line) + im = Image.open(data.imagename).convert('RGBA') for t, ls in lines.items(): tmp = Image.new('RGBA', im.size, (0, 0, 0, 0)) draw = ImageDraw.Draw(tmp) for idx, line in enumerate(ls): c = next(cmap) - if line['boundary']: - draw.polygon([tuple(x) for x in line['boundary']], fill=c, outline=c[:3]) - if line['baseline']: - draw.line([tuple(x) for x in line['baseline']], fill=bmap, width=2, joint='curve') - draw.text(line['baseline'][0], str(idx), fill=(0, 0, 0, 255)) + if line.boundary: + draw.polygon([tuple(x) for x in line.boundary], fill=c, outline=c[:3]) + if line.baseline: + draw.line([tuple(x) for x in line.baseline], fill=bmap, width=2, joint='curve') + draw.text(line.baseline[0], str(idx), fill=(0, 0, 0, 255)) base_image = Image.alpha_composite(im, tmp) base_image.save(f'high_{os.path.basename(doc)}_lines_{slugify(t)}.png') - for t, regs in data['regions'].items(): + for t, regs in data.regions.items(): tmp = Image.new('RGBA', im.size, (0, 0, 0, 0)) draw = ImageDraw.Draw(tmp) for reg in regs: c = next(cmap) try: - draw.polygon(reg, fill=c, outline=c[:3]) + draw.polygon(reg.boundary, fill=c, outline=c[:3]) except Exception: pass base_image = Image.alpha_composite(im, tmp) @@ -118,26 +106,26 @@ def cli(format_type, model, text_direction, repolygonize, files): res = blla.segment(im, model=net, text_direction=text_direction) # reorder lines by type lines = defaultdict(list) - for line in res['lines']: - lines[line['tags']['type']].append(line) + for line in res.lines: + lines[line.tags['type']].append(line) im = im.convert('RGBA') for t, ls in lines.items(): tmp = Image.new('RGBA', im.size, (0, 0, 0, 0)) draw = ImageDraw.Draw(tmp) for idx, line in enumerate(ls): c = next(cmap) - draw.polygon([tuple(x) for x in line['boundary']], fill=c, outline=c[:3]) - draw.line([tuple(x) for x in line['baseline']], fill=bmap, width=2, joint='curve') - draw.text(line['baseline'][0], str(idx), fill=(0, 0, 0, 255)) + draw.polygon([tuple(x) for x in line.boundary], fill=c, outline=c[:3]) + draw.line([tuple(x) for x in line.baseline], fill=bmap, width=2, joint='curve') + draw.text(line.baseline[0], str(idx), fill=(0, 0, 0, 255)) base_image = Image.alpha_composite(im, tmp) base_image.save(f'high_{os.path.basename(doc)}_lines_{slugify(t)}.png') - for t, regs in res['regions'].items(): + for t, 
regs in res.regions.items():
        tmp = Image.new('RGBA', im.size, (0, 0, 0, 0))
        draw = ImageDraw.Draw(tmp)
        for reg in regs:
            c = next(cmap)
            try:
-                draw.polygon([tuple(x) for x in reg], fill=c, outline=c[:3])
+                draw.polygon([tuple(x) for x in reg.boundary], fill=c, outline=c[:3])
            except Exception:
                pass
 
diff --git a/kraken/ketos/__init__.py b/kraken/ketos/__init__.py
index 83e56e82c..a8ddfe9a2 100644
--- a/kraken/ketos/__init__.py
+++ b/kraken/ketos/__init__.py
@@ -34,6 +34,7 @@
 from .repo import publish
 from .segmentation import segtrain, segtest
 from .transcription import extract, transcription
+from .ro import rotrain, roadd
 
 APP_NAME = 'kraken'
 
@@ -76,6 +77,8 @@ def cli(ctx, verbose, seed, deterministic):
 cli.add_command(segtrain)
 cli.add_command(segtest)
 cli.add_command(publish)
+cli.add_command(rotrain)
+cli.add_command(roadd)
 
 # deprecated commands
 cli.add_command(line_generator)
diff --git a/kraken/ketos/dataset.py b/kraken/ketos/dataset.py
index 8d83938c3..3996a2a52 100644
--- a/kraken/ketos/dataset.py
+++ b/kraken/ketos/dataset.py
@@ -23,7 +23,7 @@
 @click.command('compile')
 @click.pass_context
-@click.option('-o', '--output', show_default=True, type=click.Path(), default='model', help='Output model file')
+@click.option('-o', '--output', show_default=True, type=click.Path(), default='dataset.arrow', help='Output dataset file')
 @click.option('--workers', show_default=True, default=1, help='Number of parallel workers for text line extraction.')
 @click.option('-f', '--format-type', type=click.Choice(['path', 'xml', 'alto', 'page']), default='xml', show_default=True,
              help='Sets the training data format. In ALTO and PageXML mode all '
diff --git a/kraken/ketos/pretrain.py b/kraken/ketos/pretrain.py
index 7be3cc2f0..512de415d 100644
--- a/kraken/ketos/pretrain.py
+++ b/kraken/ketos/pretrain.py
@@ -56,8 +56,8 @@
              show_default=True,
              default=RECOGNITION_PRETRAIN_HYPER_PARAMS['quit'],
              type=click.Choice(['early',
-                                 'dumb']),
-              help='Stop condition for training. Set to `early` for early stooping or `dumb` for fixed number of epochs')
+                                 'fixed']),
+              help='Stop condition for training. Set to `early` for early stopping or `fixed` for fixed number of epochs')
 @click.option('-N',
              '--epochs',
              show_default=True,
@@ -133,7 +133,8 @@
 @click.option('-e', '--evaluation-files', show_default=True, default=None, multiple=True,
              callback=_validate_manifests, type=click.File(mode='r', lazy=True),
              help='File(s) with paths to evaluation data. 
Overrides the `-p` parameter')
-@click.option('--workers', show_default=True, default=1, help='Number of OpenMP threads and workers when running on CPU.')
+@click.option('--workers', show_default=True, default=1, type=click.IntRange(1), help='Number of worker processes.')
+@click.option('--threads', show_default=True, default=1, type=click.IntRange(1), help='Maximum size of OpenMP/BLAS thread pool.')
 @click.option('--load-hyper-parameters/--no-load-hyper-parameters', show_default=True, default=False,
              help='When loading an existing model, retrieve hyperparameters from the model')
 @click.option('--repolygonize/--no-repolygonize', show_default=True,
@@ -182,8 +183,8 @@
 def pretrain(ctx, batch_size, pad, output, spec, load, freq, quit, epochs,
             min_epochs, lag, min_delta, device, precision, optimizer, lrate,
             momentum, weight_decay, warmup, schedule, gamma, step_size,
             sched_patience, cos_max, partition, fixed_splits, training_files,
-             evaluation_files, workers, load_hyper_parameters, repolygonize,
-             force_binarization, format_type, augment,
+             evaluation_files, workers, threads, load_hyper_parameters, repolygonize,
+             force_binarization, format_type, augment,
             mask_probability, mask_width, num_negatives, logit_temp,
             ground_truth):
    """
@@ -199,6 +200,7 @@
        raise click.BadOptionUsage('augment', 'augmentation needs the `albumentations` package installed.')
 
    import shutil
+    from threadpoolctl import threadpool_limits
 
    from kraken.lib.train import KrakenTrainer
    from kraken.lib.pretrain import PretrainDataModule, RecognitionPretrainModel
@@ -275,13 +277,13 @@
    trainer = KrakenTrainer(accelerator=accelerator,
                            devices=device,
                            precision=precision,
-                            max_epochs=hyper_params['epochs'] if hyper_params['quit'] == 'dumb' else -1,
+                            max_epochs=hyper_params['epochs'] if hyper_params['quit'] == 'fixed' else -1,
                            min_epochs=hyper_params['min_epochs'],
                            enable_progress_bar=True if not ctx.meta['verbose'] else False,
                            deterministic=ctx.meta['deterministic'],
-                            pb_ignored_metrics=(),
                            **val_check_interval)
-    trainer.fit(model, datamodule=data_module)
+    with threadpool_limits(limits=threads):
+        trainer.fit(model, datamodule=data_module)
 
    if quit == 'early':
        message('Moving best model {0}_{1}.mlmodel ({2}) to {0}_best.mlmodel'.format(
diff --git a/kraken/ketos/recognition.py b/kraken/ketos/recognition.py
index 781fe9f47..050316cc2 100644
--- a/kraken/ketos/recognition.py
+++ b/kraken/ketos/recognition.py
@@ -23,8 +23,8 @@
 import pathlib
 
 from typing import List
+from threadpoolctl import threadpool_limits
 
-from kraken.lib.progress import KrakenProgressBar
 from kraken.lib.exceptions import KrakenInputException
 from kraken.lib.default_specs import RECOGNITION_HYPER_PARAMS, RECOGNITION_SPEC
 from .util import _validate_manifests, _expand_gt, message, to_ptl_device
@@ -55,8 +55,8 @@
              show_default=True,
              default=RECOGNITION_HYPER_PARAMS['quit'],
              type=click.Choice(['early',
-                                 'dumb']),
-              help='Stop condition for training. Set to `early` for early stooping or `dumb` for fixed number of epochs')
+                                 'fixed']),
+              help='Stop condition for training. Set to `early` for early stopping or `fixed` for fixed number of epochs')
 @click.option('-N',
              '--epochs',
              show_default=True,
@@ -156,7 +157,8 @@
 @click.option('-e', '--evaluation-files', show_default=True, default=None, multiple=True,
              callback=_validate_manifests, type=click.File(mode='r', lazy=True),
              help='File(s) with paths to evaluation data. 
Overrides the `-p` parameter') -@click.option('--workers', show_default=True, default=1, help='Number of OpenMP threads and workers when running on CPU.') +@click.option('--workers', show_default=True, default=1, type=click.IntRange(1), help='Number of worker processes.') +@click.option('--threads', show_default=True, default=1, type=click.IntRange(1), help='Maximum size of OpenMP/BLAS thread pool.') @click.option('--load-hyper-parameters/--no-load-hyper-parameters', show_default=True, default=False, help='When loading an existing model, retrieve hyperparameters from the model') @click.option('--repolygonize/--no-repolygonize', show_default=True, @@ -302,7 +303,7 @@ def train(ctx, batch_size, pad, output, spec, append, load, freq, quit, epochs, trainer = KrakenTrainer(accelerator=accelerator, devices=device, precision=precision, - max_epochs=hyper_params['epochs'] if hyper_params['quit'] == 'dumb' else -1, + max_epochs=hyper_params['epochs'] if hyper_params['quit'] == 'fixed' else -1, min_epochs=hyper_params['min_epochs'], freeze_backbone=hyper_params['freeze_backbone'], enable_progress_bar=True if not ctx.meta['verbose'] else False, @@ -311,7 +312,8 @@ def train(ctx, batch_size, pad, output, spec, append, load, freq, quit, epochs, log_dir=log_dir, **val_check_interval) try: - trainer.fit(model) + with threadpool_limits(limits=threads): + trainer.fit(model) except KrakenInputException as e: if e.args[0].startswith('Training data and model codec alphabets mismatch') and resize == 'fail': raise click.BadOptionUsage('resize', 'Mismatched training data for loaded model. Set option `--resize` to `new` or `add`') @@ -338,7 +340,12 @@ def train(ctx, batch_size, pad, output, spec, append, load, freq, quit, epochs, @click.option('-d', '--device', show_default=True, default='cpu', help='Select device to use (cpu, cuda:0, cuda:1, ...)') @click.option('--pad', show_default=True, type=click.INT, default=16, help='Left and right ' 'padding around lines') -@click.option('--workers', show_default=True, default=1, help='Number of OpenMP threads when running on CPU.') +@click.option('--workers', show_default=True, default=1, + type=click.IntRange(1), + help='Number of worker processes when running on CPU.') +@click.option('--threads', show_default=True, default=1, + type=click.IntRange(1), + help='Max size of thread pools for OpenMP/BLAS operations.') @click.option('--reorder/--no-reorder', show_default=True, default=True, help='Reordering of code points to display order') @click.option('--base-dir', show_default=True, default='auto', type=click.Choice(['L', 'R', 'auto']), help='Set base text ' @@ -371,8 +378,8 @@ def train(ctx, batch_size, pad, output, spec, append, load, freq, quit, epochs, 'collections of pre-extracted text line images.') @click.argument('test_set', nargs=-1, callback=_expand_gt, type=click.Path(exists=False, dir_okay=False)) def test(ctx, batch_size, model, evaluation_files, device, pad, workers, - reorder, base_dir, normalization, normalize_whitespace, repolygonize, - force_binarization, format_type, test_set): + threads, reorder, base_dir, normalization, normalize_whitespace, + repolygonize, force_binarization, format_type, test_set): """ Evaluate on a test set. 
""" @@ -384,12 +391,13 @@ def test(ctx, batch_size, model, evaluation_files, device, pad, workers, from kraken.serialization import render_report from kraken.lib import models - from kraken.lib.xml import preparse_xml_data + from kraken.lib.xml import XMLPage from kraken.lib.dataset import (global_align, compute_confusions, PolygonGTDataset, GroundTruthDataset, ImageInputTransforms, ArrowIPCRecognitionDataset, collate_sequences) + from kraken.lib.progress import KrakenProgressBar logger.info('Building test set from {} line images'.format(len(test_set) + len(evaluation_files))) @@ -401,8 +409,6 @@ def test(ctx, batch_size, model, evaluation_files, device, pad, workers, test_set = list(test_set) - # set number of OpenMP threads - next(iter(nn.values())).nn.set_num_threads(1) if evaluation_files: test_set.extend(evaluation_files) @@ -413,7 +419,7 @@ def test(ctx, batch_size, model, evaluation_files, device, pad, workers, if format_type in ['xml', 'page', 'alto']: if repolygonize: message('Repolygonizing data') - test_set = preparse_xml_data(test_set, format_type, repolygonize) + test_set = [{'page': XMLPage(file, filetype=format_type).to_container()} for file in test_set] valid_norm = False DatasetClass = PolygonGTDataset elif format_type == 'binary': @@ -439,62 +445,64 @@ def test(ctx, batch_size, model, evaluation_files, device, pad, workers, reorder = base_dir acc_list = [] - for p, net in nn.items(): - algn_gt: List[str] = [] - algn_pred: List[str] = [] - chars = 0 - error = 0 - message('Evaluating {}'.format(p)) - logger.info('Evaluating {}'.format(p)) - batch, channels, height, width = net.nn.input - ts = ImageInputTransforms(batch, height, width, channels, (pad, 0), valid_norm, force_binarization) - ds = DatasetClass(normalization=normalization, - whitespace_normalization=normalize_whitespace, - reorder=reorder, - im_transforms=ts) - for line in test_set: - try: - ds.add(**line) - except KrakenInputException as e: - logger.info(e) - # don't encode validation set as the alphabets may not match causing encoding failures - ds.no_encode() - ds_loader = DataLoader(ds, - batch_size=batch_size, - num_workers=workers, - pin_memory=True, - collate_fn=collate_sequences) - - with KrakenProgressBar() as progress: - batches = len(ds_loader) - pred_task = progress.add_task('Evaluating', total=batches, visible=True if not ctx.meta['verbose'] else False) - - for batch in ds_loader: - im = batch['image'] - text = batch['target'] - lens = batch['seq_lens'] + + with threadpool_limits(limits=threads): + for p, net in nn.items(): + algn_gt: List[str] = [] + algn_pred: List[str] = [] + chars = 0 + error = 0 + message('Evaluating {}'.format(p)) + logger.info('Evaluating {}'.format(p)) + batch, channels, height, width = net.nn.input + ts = ImageInputTransforms(batch, height, width, channels, (pad, 0), valid_norm, force_binarization) + ds = DatasetClass(normalization=normalization, + whitespace_normalization=normalize_whitespace, + reorder=reorder, + im_transforms=ts) + for line in test_set: try: - pred = net.predict_string(im, lens) - for x, y in zip(pred, text): - chars += len(y) - c, algn1, algn2 = global_align(y, x) - algn_gt.extend(algn1) - algn_pred.extend(algn2) - error += c - except FileNotFoundError as e: - batches -= 1 - progress.update(pred_task, total=batches) - logger.warning('{} {}. 
Skipping.'.format(e.strerror, e.filename)) + ds.add(**line) except KrakenInputException as e: - batches -= 1 - progress.update(pred_task, total=batches) - logger.warning(str(e)) - progress.update(pred_task, advance=1) - - acc_list.append((chars - error) / chars) - confusions, scripts, ins, dels, subs = compute_confusions(algn_gt, algn_pred) - rep = render_report(p, chars, error, confusions, scripts, ins, dels, subs) - logger.info(rep) - message(rep) + logger.info(e) + # don't encode validation set as the alphabets may not match causing encoding failures + ds.no_encode() + ds_loader = DataLoader(ds, + batch_size=batch_size, + num_workers=workers, + pin_memory=True, + collate_fn=collate_sequences) + + with KrakenProgressBar() as progress: + batches = len(ds_loader) + pred_task = progress.add_task('Evaluating', total=batches, visible=True if not ctx.meta['verbose'] else False) + + for batch in ds_loader: + im = batch['image'] + text = batch['target'] + lens = batch['seq_lens'] + try: + pred = net.predict_string(im, lens) + for x, y in zip(pred, text): + chars += len(y) + c, algn1, algn2 = global_align(y, x) + algn_gt.extend(algn1) + algn_pred.extend(algn2) + error += c + except FileNotFoundError as e: + batches -= 1 + progress.update(pred_task, total=batches) + logger.warning('{} {}. Skipping.'.format(e.strerror, e.filename)) + except KrakenInputException as e: + batches -= 1 + progress.update(pred_task, total=batches) + logger.warning(str(e)) + progress.update(pred_task, advance=1) + + acc_list.append((chars - error) / chars) + confusions, scripts, ins, dels, subs = compute_confusions(algn_gt, algn_pred) + rep = render_report(p, chars, error, confusions, scripts, ins, dels, subs) + logger.info(rep) + message(rep) logger.info('Average accuracy: {:0.2f}%, (stddev: {:0.2f})'.format(np.mean(acc_list) * 100, np.std(acc_list) * 100)) message('Average accuracy: {:0.2f}%, (stddev: {:0.2f})'.format(np.mean(acc_list) * 100, np.std(acc_list) * 100)) diff --git a/kraken/ketos/repo.py b/kraken/ketos/repo.py index 32a49ac5e..52c9db8e3 100644 --- a/kraken/ketos/repo.py +++ b/kraken/ketos/repo.py @@ -22,8 +22,6 @@ import click import logging -from kraken.lib.progress import KrakenDownloadProgressBar - from .util import message logging.captureWarnings(True) @@ -52,6 +50,7 @@ def publish(ctx, metadata, access_token, private, model): from kraken import repo from kraken.lib import models + from kraken.lib.progress import KrakenDownloadProgressBar with pkg_resources.resource_stream('kraken', 'metadata.schema.json') as fp: schema = json.load(fp) diff --git a/kraken/ketos/ro.py b/kraken/ketos/ro.py new file mode 100644 index 000000000..006c1b8bd --- /dev/null +++ b/kraken/ketos/ro.py @@ -0,0 +1,297 @@ +# +# Copyright 2022 Benjamin Kiessling +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +""" +kraken.ketos.ro +~~~~~~~~~~~~~~~ + +Command line driver for reading order training, evaluation, and handling. 
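+
+`rotrain` trains a reading order model from XML ground truth; `roadd` merges
+a trained reading order model into an existing segmentation model.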
+""" +import click +import pathlib +import logging + +from PIL import Image +from typing import Dict + +from kraken.lib.exceptions import KrakenInputException +from kraken.lib.default_specs import READING_ORDER_HYPER_PARAMS + +from kraken.ketos.util import _validate_manifests, _expand_gt, message, to_ptl_device + +logging.captureWarnings(True) +logger = logging.getLogger('kraken') + +# raise default max image size to 20k * 20k pixels +Image.MAX_IMAGE_PIXELS = 20000 ** 2 + +@click.command('rotrain') +@click.pass_context +@click.option('-B', '--batch-size', show_default=True, type=click.INT, + default=READING_ORDER_HYPER_PARAMS['batch_size'], help='batch sample size') +@click.option('-o', '--output', show_default=True, type=click.Path(), default='model', help='Output model file') +@click.option('-i', '--load', show_default=True, type=click.Path(exists=True, + readable=True), help='Load existing file to continue training') +@click.option('-F', '--freq', show_default=True, default=READING_ORDER_HYPER_PARAMS['freq'], type=click.FLOAT, + help='Model saving and report generation frequency in epochs ' + 'during training. If frequency is >1 it must be an integer, ' + 'i.e. running validation every n-th epoch.') +@click.option('-q', + '--quit', + show_default=True, + default=READING_ORDER_HYPER_PARAMS['quit'], + type=click.Choice(['early', + 'fixed']), + help='Stop condition for training. Set to `early` for early stopping or `fixed` for fixed number of epochs') +@click.option('-N', + '--epochs', + show_default=True, + default=READING_ORDER_HYPER_PARAMS['epochs'], + help='Number of epochs to train for') +@click.option('--min-epochs', + show_default=True, + default=['min_epochs'], + help='Minimal number of epochs to train for when using early stopping.') +@click.option('--lag', + show_default=True, + default=READING_ORDER_HYPER_PARAMS['lag'], + help='Number of evaluations (--report frequence) to wait before stopping training without improvement') +@click.option('--min-delta', + show_default=True, + default=READING_ORDER_HYPER_PARAMS['min_delta'], + type=click.FLOAT, + help='Minimum improvement between epochs to reset early stopping. By default it scales the delta by the best loss') +@click.option('-d', '--device', show_default=True, default='cpu', help='Select device to use (cpu, cuda:0, cuda:1, ...)') +@click.option('--precision', default='32', type=click.Choice(['32', '16']), help='set tensor precision') +@click.option('--optimizer', + show_default=True, + default=READING_ORDER_HYPER_PARAMS['optimizer'], + type=click.Choice(['Adam', + 'SGD', + 'RMSprop', + 'Lamb']), + help='Select optimizer') +@click.option('-r', '--lrate', show_default=True, default=READING_ORDER_HYPER_PARAMS['lrate'], help='Learning rate') +@click.option('-m', '--momentum', show_default=True, default=READING_ORDER_HYPER_PARAMS['momentum'], help='Momentum') +@click.option('-w', '--weight-decay', show_default=True, + default=READING_ORDER_HYPER_PARAMS['weight_decay'], help='Weight decay') +@click.option('--warmup', show_default=True, type=float, + default=READING_ORDER_HYPER_PARAMS['warmup'], help='Number of samples to ramp up to `lrate` initial learning rate.') +@click.option('--schedule', + show_default=True, + type=click.Choice(['constant', + '1cycle', + 'exponential', + 'cosine', + 'step', + 'reduceonplateau']), + default=READING_ORDER_HYPER_PARAMS['schedule'], + help='Set learning rate scheduler. 
For 1cycle, cycle length is determined by the `--step-size` option.') +@click.option('-g', + '--gamma', + show_default=True, + default=READING_ORDER_HYPER_PARAMS['gamma'], + help='Decay factor for exponential, step, and reduceonplateau learning rate schedules') +@click.option('-ss', + '--step-size', + show_default=True, + default=READING_ORDER_HYPER_PARAMS['step_size'], + help='Number of validation runs between learning rate decay for exponential and step LR schedules') +@click.option('--sched-patience', + show_default=True, + default=READING_ORDER_HYPER_PARAMS['rop_patience'], + help='Minimal number of validation runs between LR reduction for reduceonplateau LR schedule.') +@click.option('--cos-max', + show_default=True, + default=READING_ORDER_HYPER_PARAMS['cos_t_max'], + help='Epoch of minimal learning rate for cosine LR scheduler.') +@click.option('-p', '--partition', show_default=True, default=0.9, + help='Ground truth data partition ratio between train/validation set') +@click.option('-t', '--training-files', show_default=True, default=None, multiple=True, + callback=_validate_manifests, type=click.File(mode='r', lazy=True), + help='File(s) with additional paths to training data') +@click.option('-e', '--evaluation-files', show_default=True, default=None, multiple=True, + callback=_validate_manifests, type=click.File(mode='r', lazy=True), + help='File(s) with paths to evaluation data. Overrides the `-p` parameter') +@click.option('--workers', show_default=True, default=1, type=click.IntRange(1), help='Number of worker proesses.') +@click.option('--threads', show_default=True, default=1, type=click.IntRange(1), help='Maximum size of OpenMP/BLAS thread pool.') +@click.option('--load-hyper-parameters/--no-load-hyper-parameters', show_default=True, default=False, + help='When loading an existing model, retrieve hyper-parameters from the model') +@click.option('-f', '--format-type', type=click.Choice(['xml', 'alto', 'page']), default='xml', + help='Sets the training data format. In ALTO and PageXML mode all ' + 'data is extracted from xml files containing both baselines and a ' + 'link to source images.') +@click.option('--logger', 'pl_logger', show_default=True, type=click.Choice(['tensorboard']), default=None, + help='Logger used by PyTorch Lightning to track metrics such as loss and accuracy.') +@click.option('--log-dir', show_default=True, type=click.Path(exists=True, dir_okay=True, writable=True), + help='Path to directory where the logger will store the logs. If not set, a directory will be created in the current working directory.') +@click.option('--level', show_default=True, type=click.Choice(['baselines', 'regions']), default='baselines', + help='Selects level to train reading order model on.') +@click.option('--reading-order', show_default=True, default=None, + help='Select reading order to train. 
Defaults to `line_implicit`/`region_implicit`') +@click.argument('ground_truth', nargs=-1, callback=_expand_gt, type=click.Path(exists=False, dir_okay=False)) +def rotrain(ctx, batch_size, output, load, freq, quit, epochs, min_epochs, lag, + min_delta, device, precision, optimizer, lrate, momentum, + weight_decay, warmup, schedule, gamma, step_size, sched_patience, + cos_max, partition, training_files, evaluation_files, workers, + threads, load_hyper_parameters, format_type, pl_logger, log_dir, + level, reading_order, ground_truth): + """ + Trains a baseline labeling model for layout analysis + """ + import shutil + + from threadpoolctl import threadpool_limits + + from kraken.lib.ro import ROModel + from kraken.lib.train import KrakenTrainer + from kraken.lib.progress import KrakenProgressBar + + if not (0 <= freq <= 1) and freq % 1.0 != 0: + raise click.BadOptionUsage('freq', 'freq needs to be either in the interval [0,1.0] or a positive integer.') + + if pl_logger == 'tensorboard': + try: + import tensorboard + except ImportError: + raise click.BadOptionUsage('logger', 'tensorboard logger needs the `tensorboard` package installed.') + + if log_dir is None: + log_dir = pathlib.Path.cwd() + + logger.info('Building ground truth set from {} document images'.format(len(ground_truth) + len(training_files))) + + # populate hyperparameters from command line args + hyper_params = READING_ORDER_HYPER_PARAMS.copy() + hyper_params.update({'batch_size': batch_size, + 'freq': freq, + 'quit': quit, + 'epochs': epochs, + 'min_epochs': min_epochs, + 'lag': lag, + 'min_delta': min_delta, + 'optimizer': optimizer, + 'lrate': lrate, + 'momentum': momentum, + 'weight_decay': weight_decay, + 'warmup': warmup, + 'schedule': schedule, + 'gamma': gamma, + 'step_size': step_size, + 'rop_patience': sched_patience, + 'cos_t_max': cos_max, + 'pl_logger': pl_logger,}) + + # disable automatic partition when given evaluation set explicitly + if evaluation_files: + partition = 1 + ground_truth = list(ground_truth) + + # merge training_files into ground_truth list + if training_files: + ground_truth.extend(training_files) + + if len(ground_truth) == 0: + raise click.UsageError('No training data was provided to the train command. Use `-t` or the `ground_truth` argument.') + + try: + accelerator, device = to_ptl_device(device) + except Exception as e: + raise click.BadOptionUsage('device', str(e)) + + if hyper_params['freq'] > 1: + val_check_interval = {'check_val_every_n_epoch': int(hyper_params['freq'])} + else: + val_check_interval = {'val_check_interval': hyper_params['freq']} + + if load: + model = ROModel.load_from_checkpoint(load, + training_data=ground_truth, + evaluation_data=evaluation_files, + partition=partition, + num_workers=workers, + load_hyper_parameters=load_hyper_parameters, + format_type=format_type) + else: + model = ROModel(hyper_params, + output=output, + training_data=ground_truth, + evaluation_data=evaluation_files, + partition=partition, + num_workers=workers, + load_hyper_parameters=load_hyper_parameters, + format_type=format_type, + level=level, + reading_order=reading_order) + + message(f'Training RO on following {level} types:') + for k, v in model.train_set.dataset.class_mapping.items(): + message(f' {k}\t{v}') + + if len(model.train_set) == 0: + raise click.UsageError('No valid training data was provided to the train command. 
Use `-t` or the `ground_truth` argument.') + + trainer = KrakenTrainer(accelerator=accelerator, + devices=device, + max_epochs=hyper_params['epochs'] if hyper_params['quit'] == 'fixed' else -1, + min_epochs=hyper_params['min_epochs'], + enable_progress_bar=True if not ctx.meta['verbose'] else False, + deterministic=ctx.meta['deterministic'], + precision=int(precision), + pl_logger=pl_logger, + log_dir=log_dir, + **val_check_interval) + + with threadpool_limits(limits=threads): + trainer.fit(model) + + if quit == 'early': + message('Moving best model {0}_{1}.mlmodel ({2}) to {0}_best.mlmodel'.format( + output, model.best_epoch, model.best_metric)) + logger.info('Moving best model {0}_{1}.mlmodel ({2}) to {0}_best.mlmodel'.format( + output, model.best_epoch, model.best_metric)) + shutil.copy(f'{output}_{model.best_epoch}.mlmodel', f'{output}_best.mlmodel') + + +@click.command('roadd') +@click.pass_context +@click.option('-o', '--output', show_default=True, type=click.Path(), default='combined_seg.mlmodel', help='Combined output model file') +@click.option('-r', '--ro-model', show_default=True, type=click.Path(exists=True, readable=True), help='Reading order model to load into segmentation model') +@click.option('-i', '--seg-model', show_default=True, type=click.Path(exists=True, readable=True), help='Segmentation model to load') +def roadd(ctx, output, ro_model, seg_model): + """ + Combines a reading order model with a segmentation model. + """ + from kraken.lib import vgsl + from kraken.lib.ro import ROModel + from kraken.lib.train import KrakenTrainer + + message(f'Adding {ro_model} reading order model to {seg_model}.') + ro_net = ROModel.load_from_checkpoint(ro_model) + message('Line classes known to RO model:') + for k, v in ro_net.class_mapping.items(): + message(f' {k}\t{v}') + seg_net = vgsl.TorchVGSLModel.load_model(seg_model) + if seg_net.model_type != 'segmentation': + raise click.UsageError(f'Model {seg_model} is invalid {seg_net.model_type} model (expected `segmentation`).') + message('Line classes known to segmentation model:') + for k, v in seg_net.user_metadata['class_mapping']['baselines'].items(): + message(f' {k}\t{v}') + if ro_net.class_mapping.keys() != seg_net.user_metadata['class_mapping']['baselines'].keys(): + raise click.UsageError(f'Model {seg_model} and {ro_model} class mappings mismatch.') + + seg_net.aux_layers = {'ro_model': ro_net.ro_net} + seg_net.user_metadata['ro_class_mapping'] = ro_net.class_mapping + message(f'Saving combined model to {output}') + seg_net.save_model(output) diff --git a/kraken/ketos/segmentation.py b/kraken/ketos/segmentation.py index 54afa9ae5..9db1060e6 100644 --- a/kraken/ketos/segmentation.py +++ b/kraken/ketos/segmentation.py @@ -24,7 +24,6 @@ from PIL import Image -from kraken.lib.progress import KrakenProgressBar from kraken.lib.exceptions import KrakenInputException from kraken.lib.default_specs import SEGMENTATION_HYPER_PARAMS, SEGMENTATION_SPEC @@ -76,8 +75,8 @@ def _validate_merging(ctx, param, value): show_default=True, default=SEGMENTATION_HYPER_PARAMS['quit'], type=click.Choice(['early', - 'dumb']), - help='Stop condition for training. Set to `early` for early stopping or `dumb` for fixed number of epochs') + 'fixed']), + help='Stop condition for training. 
Set to `early` for early stopping or `fixed` for fixed number of epochs') @click.option('-N', '--epochs', show_default=True, @@ -152,7 +151,8 @@ def _validate_merging(ctx, param, value): @click.option('-e', '--evaluation-files', show_default=True, default=None, multiple=True, callback=_validate_manifests, type=click.File(mode='r', lazy=True), help='File(s) with paths to evaluation data. Overrides the `-p` parameter') -@click.option('--workers', show_default=True, default=1, help='Number of OpenMP threads and workers when running on CPU.') +@click.option('--workers', show_default=True, default=1, type=click.IntRange(1), help='Number of worker proesses.') +@click.option('--threads', show_default=True, default=1, type=click.IntRange(1), help='Maximum size of OpenMP/BLAS thread pool.') @click.option('--load-hyper-parameters/--no-load-hyper-parameters', show_default=True, default=False, help='When loading an existing model, retrieve hyper-parameters from the model') @click.option('--force-binarization/--no-binarization', show_default=True, @@ -219,7 +219,7 @@ def segtrain(ctx, output, spec, line_width, pad, load, freq, quit, epochs, min_epochs, lag, min_delta, device, precision, optimizer, lrate, momentum, weight_decay, warmup, schedule, gamma, step_size, sched_patience, cos_max, partition, training_files, - evaluation_files, workers, load_hyper_parameters, + evaluation_files, workers, threads, load_hyper_parameters, force_binarization, format_type, suppress_regions, suppress_baselines, valid_regions, valid_baselines, merge_regions, merge_baselines, bounding_regions, @@ -229,7 +229,10 @@ def segtrain(ctx, output, spec, line_width, pad, load, freq, quit, epochs, """ import shutil + from threadpoolctl import threadpool_limits + from kraken.lib.train import SegmentationModel, KrakenTrainer + from kraken.lib.progress import KrakenProgressBar if resize != 'fail' and not load: raise click.BadOptionUsage('resize', 'resize option requires loading an existing model') @@ -339,7 +342,7 @@ def segtrain(ctx, output, spec, line_width, pad, load, freq, quit, epochs, trainer = KrakenTrainer(accelerator=accelerator, devices=device, precision=precision, - max_epochs=hyper_params['epochs'] if hyper_params['quit'] == 'dumb' else -1, + max_epochs=hyper_params['epochs'] if hyper_params['quit'] == 'fixed' else -1, min_epochs=hyper_params['min_epochs'], enable_progress_bar=True if not ctx.meta['verbose'] else False, deterministic=ctx.meta['deterministic'], @@ -347,7 +350,8 @@ def segtrain(ctx, output, spec, line_width, pad, load, freq, quit, epochs, log_dir=log_dir, **val_check_interval) - trainer.fit(model) + with threadpool_limits(limits=threads): + trainer.fit(model) if quit == 'early': message('Moving best model {0}_{1}.mlmodel ({2}) to {0}_best.mlmodel'.format( @@ -365,7 +369,10 @@ def segtrain(ctx, output, spec, line_width, pad, load, freq, quit, epochs, callback=_validate_manifests, type=click.File(mode='r', lazy=True), help='File(s) with paths to evaluation data.') @click.option('-d', '--device', show_default=True, default='cpu', help='Select device to use (cpu, cuda:0, cuda:1, ...)') -@click.option('--workers', show_default=True, default=1, help='Number of OpenMP threads when running on CPU.') +@click.option('--workers', default=1, show_default=True, type=click.IntRange(1), + help='Number of worker processes for data loading.') +@click.option('--threads', default=1, show_default=True, type=click.IntRange(1), + help='Size of thread pools for intra-op parallelization') 
@click.option('--force-binarization/--no-binarization', show_default=True, default=False, help='Forces input images to be binary, otherwise ' 'the appropriate color format will be auto-determined through the ' @@ -403,7 +410,7 @@ def segtrain(ctx, output, spec, line_width, pad, load, freq, quit, epochs, @click.option("--threshold", type=click.FloatRange(.01, .99), default=.3, show_default=True, help="Threshold for heatmap binarization. Training threshold is .3, prediction is .5") @click.argument('test_set', nargs=-1, callback=_expand_gt, type=click.Path(exists=False, dir_okay=False)) -def segtest(ctx, model, evaluation_files, device, workers, threshold, +def segtest(ctx, model, evaluation_files, device, workers, threads, threshold, force_binarization, format_type, test_set, suppress_regions, suppress_baselines, valid_regions, valid_baselines, merge_regions, merge_baselines, bounding_regions): @@ -413,6 +420,7 @@ def segtest(ctx, model, evaluation_files, device, workers, threshold, if not model: raise click.UsageError('No model to evaluate given.') + from threadpoolctl import threadpool_limits from torch.utils.data import DataLoader import torch import torch.nn.functional as F @@ -502,46 +510,47 @@ def segtest(ctx, model, evaluation_files, device, workers, threshold, with KrakenProgressBar() as progress: batches = len(ds_loader) pred_task = progress.add_task('Evaluating', total=batches, visible=True if not ctx.meta['verbose'] else False) - for batch in ds_loader: - x, y = batch['image'], batch['target'] - try: - pred, _ = nn.nn(x) - # scale target to output size - y = F.interpolate(y, size=(pred.size(2), pred.size(3))).squeeze(0).bool() - pred = pred.squeeze() > threshold - pred = pred.view(pred.size(0), -1) - y = y.view(y.size(0), -1) - pages.append({ - 'intersections': (y & pred).sum(dim=1, dtype=torch.double), - 'unions': (y | pred).sum(dim=1, dtype=torch.double), - 'corrects': torch.eq(y, pred).sum(dim=1, dtype=torch.double), - 'cls_cnt': y.sum(dim=1, dtype=torch.double), - 'all_n': torch.tensor(y.size(1), dtype=torch.double, device=device) - }) - if lines_idx: - y_baselines = y[lines_idx].sum(dim=0, dtype=torch.bool) - pred_baselines = pred[lines_idx].sum(dim=0, dtype=torch.bool) - pages[-1]["baselines"] = { - 'intersections': (y_baselines & pred_baselines).sum(dim=0, dtype=torch.double), - 'unions': (y_baselines | pred_baselines).sum(dim=0, dtype=torch.double), - } - if regions_idx: - y_regions_idx = y[regions_idx].sum(dim=0, dtype=torch.bool) - pred_regions_idx = pred[regions_idx].sum(dim=0, dtype=torch.bool) - pages[-1]["regions"] = { - 'intersections': (y_regions_idx & pred_regions_idx).sum(dim=0, dtype=torch.double), - 'unions': (y_regions_idx | pred_regions_idx).sum(dim=0, dtype=torch.double), - } - - except FileNotFoundError as e: - batches -= 1 - progress.update(pred_task, total=batches) - logger.warning('{} {}. 
Skipping.'.format(e.strerror, e.filename)) - except KrakenInputException as e: - batches -= 1 - progress.update(pred_task, total=batches) - logger.warning(str(e)) - progress.update(pred_task, advance=1) + with threadpool_limits(limits=threads): + for batch in ds_loader: + x, y = batch['image'], batch['target'] + try: + pred, _ = nn.nn(x) + # scale target to output size + y = F.interpolate(y, size=(pred.size(2), pred.size(3))).squeeze(0).bool() + pred = pred.squeeze() > threshold + pred = pred.view(pred.size(0), -1) + y = y.view(y.size(0), -1) + pages.append({ + 'intersections': (y & pred).sum(dim=1, dtype=torch.double), + 'unions': (y | pred).sum(dim=1, dtype=torch.double), + 'corrects': torch.eq(y, pred).sum(dim=1, dtype=torch.double), + 'cls_cnt': y.sum(dim=1, dtype=torch.double), + 'all_n': torch.tensor(y.size(1), dtype=torch.double, device=device) + }) + if lines_idx: + y_baselines = y[lines_idx].sum(dim=0, dtype=torch.bool) + pred_baselines = pred[lines_idx].sum(dim=0, dtype=torch.bool) + pages[-1]["baselines"] = { + 'intersections': (y_baselines & pred_baselines).sum(dim=0, dtype=torch.double), + 'unions': (y_baselines | pred_baselines).sum(dim=0, dtype=torch.double), + } + if regions_idx: + y_regions_idx = y[regions_idx].sum(dim=0, dtype=torch.bool) + pred_regions_idx = pred[regions_idx].sum(dim=0, dtype=torch.bool) + pages[-1]["regions"] = { + 'intersections': (y_regions_idx & pred_regions_idx).sum(dim=0, dtype=torch.double), + 'unions': (y_regions_idx | pred_regions_idx).sum(dim=0, dtype=torch.double), + } + + except FileNotFoundError as e: + batches -= 1 + progress.update(pred_task, total=batches) + logger.warning('{} {}. Skipping.'.format(e.strerror, e.filename)) + except KrakenInputException as e: + batches -= 1 + progress.update(pred_task, total=batches) + logger.warning(str(e)) + progress.update(pred_task, advance=1) # Accuracy / pixel corrects = torch.stack([x['corrects'] for x in pages], -1).sum(dim=-1) diff --git a/kraken/ketos/transcription.py b/kraken/ketos/transcription.py index 490c0ac4e..dd4402c66 100644 --- a/kraken/ketos/transcription.py +++ b/kraken/ketos/transcription.py @@ -27,7 +27,6 @@ from typing import IO, Any, cast from bidi.algorithm import get_display -from kraken.lib.progress import KrakenProgressBar from .util import message logging.captureWarnings(True) @@ -68,6 +67,7 @@ def extract(ctx, binarize, normalization, normalize_whitespace, reorder, from lxml import html, etree from kraken import binarization + from kraken.lib.progress import KrakenProgressBar try: os.mkdir(output) @@ -172,6 +172,7 @@ def transcription(ctx, text_direction, scale, bw, maxcolseps, from kraken import binarization from kraken.lib import models + from kraken.lib.progress import KrakenProgressBar ti = transcribe.TranscriptionInterface(font, font_style) diff --git a/kraken/ketos/util.py b/kraken/ketos/util.py index e71b53505..b37b298b2 100644 --- a/kraken/ketos/util.py +++ b/kraken/ketos/util.py @@ -54,7 +54,7 @@ def message(msg, **styles): def to_ptl_device(device: str) -> Tuple[str, Optional[List[int]]]: - if any([device == x for x in ['cpu', 'mps']]): + if device in ['cpu', 'mps']: return device, 'auto' elif any([device.startswith(x) for x in ['tpu', 'cuda', 'hpu', 'ipu']]): dev, idx = device.split(':') diff --git a/kraken/kraken.py b/kraken/kraken.py index 6788cdeef..27a0f5cef 100644 --- a/kraken/kraken.py +++ b/kraken/kraken.py @@ -21,18 +21,18 @@ import os import warnings import logging +import dataclasses import pkg_resources -from typing import Dict, Union, List, cast, Any, 
IO, Callable +from PIL import Image from pathlib import Path -from rich.traceback import install from functools import partial -from PIL import Image +from rich.traceback import install +from typing import Dict, Union, List, cast, Any, IO, Callable import click from kraken.lib import log -from kraken.lib.progress import KrakenProgressBar, KrakenDownloadProgressBar warnings.simplefilter('ignore', UserWarning) @@ -45,7 +45,6 @@ APP_NAME = 'kraken' SEGMENTATION_DEFAULT_MODEL = pkg_resources.resource_filename(__name__, 'blla.mlmodel') DEFAULT_MODEL = ['en_best.mlmodel'] -LEGACY_MODEL_DIR = '/usr/local/share/ocropus' # raise default max image size to 20k * 20k pixels Image.MAX_IMAGE_PIXELS = 20000 ** 2 @@ -57,15 +56,9 @@ def message(msg: str, **styles) -> None: def get_input_parser(type_str: str) -> Callable[[str], Dict[str, Any]]: - if type_str == 'alto': - from kraken.lib.xml import parse_alto - return parse_alto - elif type_str == 'page': - from kraken.lib.xml import parse_page - return parse_page - elif type_str == 'xml': - from kraken.lib.xml import parse_xml - return parse_xml + if type_str in ['alto', 'page', 'xml']: + from kraken.lib.xml import XMLPage + return XMLPage elif type_str == 'image': return Image.open @@ -78,7 +71,7 @@ def binarizer(threshold, zoom, escale, border, perc, range, low, high, input, ou ctx = click.get_current_context() if ctx.meta['first_process']: if ctx.meta['input_format_type'] != 'image': - input = get_input_parser(ctx.meta['input_format_type'])(input)['image'] + input = get_input_parser(ctx.meta['input_format_type'])(input).imagename ctx.meta['first_process'] = False else: raise click.UsageError('Binarization has to be the initial process.') @@ -124,14 +117,13 @@ def segmenter(legacy, model, text_direction, scale, maxcolseps, black_colseps, remove_hlines, pad, mask, device, input, output) -> None: import json - from kraken import pageseg - from kraken import blla + from kraken import blla, pageseg ctx = click.get_current_context() if ctx.meta['first_process']: if ctx.meta['input_format_type'] != 'image': - input = get_input_parser(ctx.meta['input_format_type'])(input)['image'] + input = get_input_parser(ctx.meta['input_format_type'])(input).imagename ctx.meta['first_process'] = False if 'base_image' not in ctx.meta: @@ -179,15 +171,20 @@ def segmenter(legacy, model, text_direction, scale, maxcolseps, black_colseps, else: with click.open_file(output, 'w') as fp: fp = cast(IO[Any], fp) - json.dump(res, fp) + json.dump(dataclasses.asdict(res), fp) message('\u2713', fg='green') def recognizer(model, pad, no_segmentation, bidi_reordering, tags_ignore, input, output) -> None: import json + import uuid + import dataclasses from kraken import rpred + from kraken.containers import Segmentation, BBoxLine + + from kraken.lib.progress import KrakenProgressBar ctx = click.get_current_context() @@ -198,12 +195,11 @@ def recognizer(model, pad, no_segmentation, bidi_reordering, tags_ignore, input, if ctx.meta['first_process']: if ctx.meta['input_format_type'] != 'image': doc = get_input_parser(ctx.meta['input_format_type'])(input) - ctx.meta['base_image'] = doc['image'] - doc['text_direction'] = 'horizontal-lr' - if doc['base_dir'] and bidi_reordering is True: - message(f'Setting base text direction for BiDi reordering to {doc["base_dir"]} (from XML input file)') - bidi_reordering = doc['base_dir'] - bounds = doc + ctx.meta['base_image'] = doc.imagename + if doc.base_dir and bidi_reordering is True: + message(f'Setting base text direction for BiDi reordering to 
{doc.base_dir} (from XML input file)') + bidi_reordering = doc.base_dir + bounds = doc.to_container() try: im = Image.open(ctx.meta['base_image']) except IOError as e: @@ -213,14 +209,15 @@ def recognizer(model, pad, no_segmentation, bidi_reordering, tags_ignore, input, with click.open_file(input, 'r') as fp: try: fp = cast(IO[Any], fp) - bounds = json.load(fp) + bounds = Segmentation(**json.load(fp)) except ValueError as e: raise click.UsageError(f'{input} invalid segmentation: {str(e)}') elif not bounds: if no_segmentation: - bounds = {'script_detection': False, - 'text_direction': 'horizontal-lr', - 'boxes': [(0, 0) + im.size]} + bounds = Segmentation(type='bbox', + text_direction='horizontal-lr', + lines=[BBoxLine(id=uuid.uuid4(), + bbox=((0, 0), (0, im.size[1]), im.size, (im.size[0], 0)))]) else: raise click.UsageError('No line segmentation given. Add one with the input or run `segment` first.') elif no_segmentation: @@ -228,7 +225,7 @@ def recognizer(model, pad, no_segmentation, bidi_reordering, tags_ignore, input, tags = set() # script detection - if 'script_detection' in bounds and bounds['script_detection']: + if bounds.script_detection: it = rpred.mm_rpred(model, im, bounds, pad, bidi_reordering=bidi_reordering, tags_ignore=tags_ignore) @@ -243,6 +240,7 @@ def recognizer(model, pad, no_segmentation, bidi_reordering, tags_ignore, input, for pred in it: preds.append(pred) progress.update(pred_task, advance=1) + results = dataclasses.replace(it.bounds, lines=preds, imagename=ctx.meta['base_image']) ctx = click.get_current_context() with click.open_file(output, 'w', encoding='utf-8') as fp: @@ -251,12 +249,10 @@ def recognizer(model, pad, no_segmentation, bidi_reordering, tags_ignore, input, logger.info('Serializing as {} into {}'.format(ctx.meta['output_mode'], output)) if ctx.meta['output_mode'] != 'native': from kraken import serialization - fp.write(serialization.serialize(records=preds, - image_name=ctx.meta['base_image'], + fp.write(serialization.serialize(results=results, image_size=Image.open(ctx.meta['base_image']).size, writing_mode=ctx.meta['text_direction'], scripts=tags, - regions=bounds['regions'] if 'regions' in bounds else None, template=ctx.meta['output_template'], template_source='custom' if ctx.meta['output_mode'] == 'template' else 'native', processing_steps=ctx.meta['steps'])) @@ -305,8 +301,10 @@ def recognizer(model, pad, no_segmentation, bidi_reordering, tags_ignore, input, help='Raises the exception that caused processing to fail in the case of an error') @click.option('-2', '--autocast', default=False, show_default=True, flag_value=True, help='On compatible devices, uses autocast for `segment` which lower the memory usage.') +@click.option('--threads', default=1, show_default=True, type=click.IntRange(1), + help='Size of thread pools for intra-op parallelization') def cli(input, batch_input, suffix, verbose, format_type, pdf_format, - serializer, template, device, raise_on_error, autocast): + serializer, template, device, raise_on_error, autocast, threads): """ Base command for recognition functionality. 
@@ -336,6 +334,7 @@ def cli(input, batch_input, suffix, verbose, format_type, pdf_format, ctx.meta['verbose'] = verbose ctx.meta['steps'] = [] ctx.meta["autocast"] = autocast + ctx.meta['threads'] = threads log.set_logger(logger, level=30 - min(10 * verbose, 20)) @@ -349,6 +348,9 @@ def process_pipeline(subcommands, input, batch_input, suffix, verbose, format_ty import uuid import tempfile + from threadpoolctl import threadpool_limits + from kraken.lib.progress import KrakenProgressBar + ctx = click.get_current_context() input = list(input) @@ -413,7 +415,8 @@ def process_pipeline(subcommands, input, batch_input, suffix, verbose, format_ty for idx, (task, input, output) in enumerate(zip(subcommands, fc, fc[1:])): if len(fc) - 2 == idx: ctx.meta['last_process'] = True - task(input=input, output=output) + with threadpool_limits(limits=ctx.meta['threads']): + task(input=input, output=output) except Exception as e: logger.error(f'Failed processing {io_pair[0]}: {str(e)}') if ctx.meta['raise_failed']: @@ -567,9 +570,7 @@ def _validate_mm(ctx, param, value): show_default=True, type=click.Choice(['horizontal-tb', 'vertical-lr', 'vertical-rl']), help='Sets principal text direction in serialization output') -@click.option('--threads', default=1, show_default=True, type=click.IntRange(1), - help='Number of threads to use for OpenMP parallelization.') -def ocr(ctx, model, pad, reorder, base_dir, no_segmentation, text_direction, threads): +def ocr(ctx, model, pad, reorder, base_dir, no_segmentation, text_direction): """ Recognizes text in line images. """ @@ -581,14 +582,12 @@ def ocr(ctx, model, pad, reorder, base_dir, no_segmentation, text_direction, thr if reorder and base_dir != 'auto': reorder = base_dir - # first we try to find the model in the absolue path, then ~/.kraken, then - # LEGACY_MODEL_DIR + # first we try to find the model in the absolue path, then ~/.kraken nm = {} # type: Dict[str, models.TorchSeqRecognizer] ign_tags = model.pop('ignore') for k, v in model.items(): search = [v, - os.path.join(click.get_app_dir(APP_NAME), v), - os.path.join(LEGACY_MODEL_DIR, v)] + os.path.join(click.get_app_dir(APP_NAME), v)] location = None for loc in search: if os.path.isfile(loc): @@ -613,8 +612,6 @@ def ocr(ctx, model, pad, reorder, base_dir, no_segmentation, text_direction, thr nn = defaultdict(lambda: nm['default']) # type: Dict[str, models.TorchSeqRecognizer] nn.update(nm) nm = nn - # thread count is global so setting it once is sufficient - nm[k].nn.set_num_threads(threads) ctx.meta['steps'].append({'category': 'processing', 'description': 'Text line recognition', @@ -667,6 +664,7 @@ def list_models(ctx): Lists models in the repository. """ from kraken import repo + from kraken.lib.progress import KrakenProgressBar with KrakenProgressBar() as progress: download_task = progress.add_task('Retrieving model list', total=0, visible=True if not ctx.meta['verbose'] else False) @@ -684,6 +682,7 @@ def get(ctx, model_id): Retrieves a model from the repository. 
""" from kraken import repo + from kraken.lib.progress import KrakenDownloadProgressBar try: os.makedirs(click.get_app_dir(APP_NAME)) diff --git a/kraken/lib/arrow_dataset.py b/kraken/lib/arrow_dataset.py index 149c58590..bac1c1a0a 100755 --- a/kraken/lib/arrow_dataset.py +++ b/kraken/lib/arrow_dataset.py @@ -28,9 +28,10 @@ from collections import Counter from typing import Optional, List, Union, Callable, Tuple, Dict from multiprocessing import Pool +from kraken.containers import Segmentation from kraken.lib import functional_im_transforms as F_t from kraken.lib.segmentation import extract_polygons -from kraken.lib.xml import parse_xml, parse_alto, parse_page +from kraken.lib.xml import XMLPage from kraken.lib.util import is_bitonal, make_printable from kraken.lib.exceptions import KrakenInputException from os import extsep, PathLike @@ -43,27 +44,33 @@ def _extract_line(xml_record, skip_empty_lines: bool = True): lines = [] try: - im = Image.open(xml_record['image']) + im = Image.open(xml_record.imagename) except (FileNotFoundError, UnidentifiedImageError): return lines, None, None if is_bitonal(im): im = im.convert('1') - seg_key = 'lines' if 'lines' in xml_record else 'boxes' - recs = xml_record.pop(seg_key) + recs = xml_record.lines.values() for idx, rec in enumerate(recs): + seg = Segmentation(text_direction='horizontal-lr', + imagename=xml_record.imagename, + type=xml_record.type, + lines=[rec], + regions=None, + script_detection=False, + line_orders=None) try: - line_im, line = next(extract_polygons(im, {**xml_record, seg_key: [rec]})) + line_im, line = next(extract_polygons(im, seg)) except KrakenInputException: logger.warning(f'Invalid line {idx} in {im.filename}') continue except Exception as e: logger.warning(f'Unexpected exception {e} from line {idx} in {im.filename}') continue - if not line['text'] and skip_empty_lines: + if not line.text and skip_empty_lines: continue fp = io.BytesIO() line_im.save(fp, format='png') - lines.append({'text': line['text'], 'im': fp.getvalue()}) + lines.append({'text': line.text, 'im': fp.getvalue()}) return lines, im.mode @@ -90,6 +97,7 @@ def parse_path(path: Union[str, PathLike], gt = fp.read().strip('\n\r') if not gt and skip_empty_lines: raise KrakenInputException(f'No text for ground truth line {path}.') + return {'image': path, 'lines': [{'text': gt}]} @@ -135,12 +143,8 @@ def build_binary_dataset(files: Optional[List[Union[str, PathLike, Dict]]] = Non logger.info('Parsing XML files') extract_fn = partial(_extract_line, skip_empty_lines=skip_empty_lines) parse_fn = None - if format_type == 'xml': - parse_fn = parse_xml - elif format_type == 'alto': - parse_fn = parse_alto - elif format_type == 'page': - parse_fn = parse_page + if format_type in ['xml', 'alto', 'page']: + parse_fn = XMLPage elif format_type == 'path': if not ignore_splits: logger.warning('ignore_splits is False and format_type is path. 
Will not serialize splits.') @@ -163,10 +167,13 @@ def build_binary_dataset(files: Optional[List[Union[str, PathLike, Dict]]] = Non logger.warning(f'Invalid input file {doc}') continue try: - name_ext = str(data['image']).split(extsep, 1) - if name_ext[1] == 'gt.txt': - data['image'] = name_ext[0] + '.png' - with open(data['image'], 'rb') as fp: + if format_type in ['xml', 'alto', 'page']: + imagename = data.imagename + else: + name_ext = str(data['image']).split(extsep, 1) + imagename = name_ext[0] + '.png' + data['image'] = imagename + with open(imagename, 'rb') as fp: Image.open(fp) except (FileNotFoundError, UnidentifiedImageError) as e: logger.warning(f'Could not open file {e.filename} in {doc}') @@ -181,9 +188,13 @@ def build_binary_dataset(files: Optional[List[Union[str, PathLike, Dict]]] = Non alphabet = Counter() num_lines = 0 for doc in docs: - for line in doc['lines']: + if format_type in ['xml', 'alto', 'page']: + lines = doc.lines.values() + else: + lines = doc['lines'] + for line in lines: num_lines += 1 - alphabet.update(line['text']) + alphabet.update(line.text if format_type in ['xml', 'alto', 'page'] else line['text']) callback(0, num_lines) diff --git a/kraken/lib/dataset/__init__.py b/kraken/lib/dataset/__init__.py index 960ef8499..c6710d24c 100644 --- a/kraken/lib/dataset/__init__.py +++ b/kraken/lib/dataset/__init__.py @@ -17,4 +17,5 @@ """ from .recognition import ArrowIPCRecognitionDataset, PolygonGTDataset, GroundTruthDataset # NOQA from .segmentation import BaselineSet # NOQA +from .ro import PairWiseROSet, PageWiseROSet #NOQA from .utils import ImageInputTransforms, collate_sequences, global_align, compute_confusions # NOQA diff --git a/kraken/lib/dataset/recognition.py b/kraken/lib/dataset/recognition.py index d381b9d97..0ddb5585e 100644 --- a/kraken/lib/dataset/recognition.py +++ b/kraken/lib/dataset/recognition.py @@ -30,6 +30,7 @@ from torch.utils.data import Dataset from typing import Dict, List, Tuple, Callable, Optional, Any, Union, Literal +from kraken.containers import BaselineLine, BBoxLine, Segmentation from kraken.lib.util import is_bitonal from kraken.lib.codec import PytorchCodec from kraken.lib.segmentation import extract_polygons @@ -56,7 +57,7 @@ def __init__(self): ShiftScaleRotate, OpticalDistortion, ElasticTransform, PixelDropout ) - + self._transforms = Compose([ ToFloat(), PixelDropout(p=0.2), @@ -71,7 +72,7 @@ def __init__(self): ElasticTransform(alpha=64, sigma=25, alpha_affine=0.25, p=0.1), ], p=0.2), ], p=0.5) - + def __call__(self, image): return self._transforms(image=image) @@ -319,54 +320,67 @@ def __init__(self, self.im_mode = '1' - def add(self, *args, **kwargs): + def add(self, + line: Optional[BaselineLine] = None, + page: Optional[Segmentation] = None): """ - Adds a line to the dataset. + Adds an indiviual line or all lines on a page to the dataset. Args: - im (path): Path to the whole page image - text (str): Transcription of the line. - baseline (list): A list of coordinates [[x0, y0], ..., [xn, yn]]. - boundary (list): A polygon mask for the line. + line: BaselineLine container object of a line. + page: Segmentation container object for a page. 
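+
+        Example (illustrative; assumes `seg` is a baseline `Segmentation`
+        parsed from an XML file):
+            >>> ds.add(page=seg)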
""" - if 'preparse' not in kwargs or not kwargs['preparse']: - kwargs = self.parse(*args, **kwargs) - self._images.append((kwargs['image'], kwargs['baseline'], kwargs['boundary'])) - self._gt.append(kwargs['text']) - self.alphabet.update(kwargs['text']) + if line: + self.add_line(line) + if page: + self.add_page(page) + if not (line and page): + raise ValueError('Neither line nor page data provided in dataset builder') - def parse(self, - image: Union[PathLike, str, Image.Image], - text: str, - baseline: List[Tuple[int, int]], - boundary: List[Tuple[int, int]], - *args, - **kwargs): + def add_page(self, page: Segmentation): + """ + Adds all lines on a page to the dataset. + + Invalid lines will be skipped and a warning will be printed. + + Args: + page: Segmentation container object for a page. """ - Parses a sample for the dataset and returns it. + if page.type != 'baselines': + raise ValueError(f'Invalid segmentation of type {page.type} (expected "baselines")') + for line in page.lines: + try: + self.add_line(dataclasses.replace(line, imagename=page.imagename)) + except ValueError as e: + logger.warning(e) - This function is mainly uses for parallelized loading of training data. + def add_line(self, line: BaselineLine): + """ + Adds a line to the dataset. Args: - im (path): Path to the whole page image - text (str): Transcription of the line. - baseline (list): A list of coordinates [[x0, y0], ..., [xn, yn]]. - boundary (list): A polygon mask for the line. + line: BaselineLine container object for a line. + + Raises: + ValueError if the transcription of the line is empty after + transformation or either baseline or bounding polygon are missing. """ - orig_text = text + if line.type != 'baselines': + raise ValueError(f'Invalid line of type {line.type} (expected "baselines")') + + text = line.text for func in self.text_transforms: text = func(text) if not text and self.skip_empty_lines: - raise KrakenInputException(f'Text line "{orig_text}" is empty after transformations') - if not baseline: - raise KrakenInputException('No baseline given for line') - if not boundary: - raise KrakenInputException('No boundary given for line') - return {'text': text, - 'image': image, - 'baseline': baseline, - 'boundary': boundary, - 'preparse': True} + raise ValueError(f'Text line "{line.text}" is empty after transformations') + if not line.baseline: + raise ValueError('No baseline given for line') + if not line.boundary: + raise ValueError('No boundary given for line') + + self._images.append((line.image, line.baseline, line.boundary)) + self._gt.append(text) + self.alphabet.update(text) def encode(self, codec: Optional[PytorchCodec] = None) -> None: """ @@ -432,8 +446,7 @@ class GroundTruthDataset(Dataset): All data is cached in memory. """ - def __init__(self, split: Callable[[Union[PathLike, str]], str] = F_t.default_split, - suffix: str = '.gt.txt', + def __init__(self, normalization: Optional[str] = None, whitespace_normalization: bool = True, skip_empty_lines: bool = True, @@ -444,10 +457,6 @@ def __init__(self, split: Callable[[Union[PathLike, str]], str] = F_t.default_sp Reads a list of image-text pairs and creates a ground truth set. Args: - split: Function for generating the base name without - extensions from paths - suffix: Suffix to attach to image base name for text - retrieval mode: Image color space. Either RGB (color) or L (grayscale/bw). Only L is compatible with vertical scaling/dewarping. 
@@ -466,8 +475,6 @@ def __init__(self, split: Callable[[Union[PathLike, str]], str] = F_t.default_sp tensor suitable for forward passes. augmentation: Enables augmentation. """ - self.suffix = suffix - self.split = partial(F_t.suffix_split, split=split, suffix=suffix) self._images = [] # type: Union[List[Image], List[torch.Tensor]] self._gt = [] # type: List[str] self.alphabet = Counter() # type: Counter @@ -493,35 +500,67 @@ def __init__(self, split: Callable[[Union[PathLike, str]], str] = F_t.default_sp self.im_mode = '1' - def add(self, *args, **kwargs) -> None: + def add(self, + line: Optional[BBoxLine] = None, + page: Optional[Segmentation] = None): """ - Adds a line-image-text pair to the dataset. + Adds an indiviual line or all lines on a page to the dataset. Args: - image (str): Input image path + line: BBoxLine container object of a line. + page: Segmentation container object for a page. """ - if 'preparse' not in kwargs or not kwargs['preparse']: - kwargs = self.parse(*args, **kwargs) - self._images.append(kwargs['image']) - self._gt.append(kwargs['text']) - self.alphabet.update(kwargs['text']) + if line: + self.add_line(line) + if page: + self.add_page(page) + if not (line and page): + raise ValueError('Neither line nor page data provided in dataset builder') - def parse(self, image: Union[PathLike, str, Image.Image], *args, **kwargs) -> Dict: + def add_page(self, page: Segmentation): """ - Parses a sample for this dataset. + Adds all lines on a page to the dataset. - This is mostly used to parallelize populating the dataset. + Invalid lines will be skipped and a warning will be printed. Args: - image (str): Input image path - """ - with open(self.split(image), 'r', encoding='utf-8') as fp: - text = fp.read().strip('\n\r') - for func in self.text_transforms: - text = func(text) - if not text and self.skip_empty_lines: - raise KrakenInputException(f'Text line is empty ({fp.name})') - return {'image': image, 'text': text, 'preparse': True} + page: Segmentation container object for a page. + """ + if page.type != 'bbox': + raise ValueError(f'Invalid segmentation of type {page.type} (expected "bbox")') + for line in page.lines: + try: + self.add_line(dataclasses.replace(line, imagename=page.imagename)) + except ValueError as e: + logger.warning(e) + + def add_line(self, line: BBoxLine): + """ + Adds a line to the dataset. + + Args: + line: BBoxLine container object for a line. + + Raises: + ValueError if the transcription of the line is empty after + transformation or either baseline or bounding polygon are missing. + """ + if line.type != 'bbox': + raise ValueError(f'Invalid line of type {line.type} (expected "bbox")') + + text = line.text + for func in self.text_transforms: + text = func(text) + if not text and self.skip_empty_lines: + raise ValueError(f'Text line "{line.text}" is empty after transformations') + if not line.baseline: + raise ValueError('No baseline given for line') + if not line.boundary: + raise ValueError('No boundary given for line') + + self._images.append(line.image) + self._gt.append(text) + self.alphabet.update(text) def encode(self, codec: Optional[PytorchCodec] = None) -> None: """ diff --git a/kraken/lib/dataset/ro.py b/kraken/lib/dataset/ro.py new file mode 100644 index 000000000..3edb82b88 --- /dev/null +++ b/kraken/lib/dataset/ro.py @@ -0,0 +1,256 @@ +# +# Copyright 2015 Benjamin Kiessling +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +""" +Utility functions for data loading and training of VGSL networks. +""" +import json +import torch +import traceback +import numpy as np +import torch.nn.functional as F +import shapely.geometry as geom + +from math import factorial +from os import path, PathLike +from PIL import Image +from shapely.ops import split +from itertools import groupby +from torchvision import transforms +from collections import defaultdict +from torch.utils.data import Dataset +from typing import Dict, List, Tuple, Sequence, Callable, Any, Union, Literal, Optional + +from kraken.lib.xml import XMLPage + +from kraken.lib.exceptions import KrakenInputException + +__all__ = ['PairWiseROSet', 'PageWiseROSet'] + +import logging + +logger = logging.getLogger(__name__) + + +class PairWiseROSet(Dataset): + """ + Dataset for training a reading order determination model. + + Returns random pairs of lines from the same page. + """ + def __init__(self, files: Sequence[Union[PathLike, str]] = None, + mode: Optional[Literal['alto', 'page', 'xml']] = 'path', + level: Literal['regions', 'baselines'] = 'baselines', + ro_id: Optional[str] = None, + class_mapping: Optional[Dict[str, int]] = None): + """ + Samples pairs lines/regions from XML files for training a reading order + model . + + Args: + mode: Either alto, page, xml, None. In alto, page, and xml + mode the baseline paths and image data is retrieved from an + ALTO/PageXML file. In `None` mode data is iteratively added + through the `add` method. + ro_id: ID of the reading order to sample from. Defaults to + `line_implicit`/`region_implicit`. + """ + super().__init__() + + self._num_pairs = 0 + self.failed_samples = [] + if class_mapping: + self.class_mapping = class_mapping + self.num_classes = len(class_mapping) + 1 + else: + self.num_classes = 1 + self.class_mapping = {} + + self.data = [] + + if mode in ['alto', 'page', 'xml']: + for file in files: + try: + doc = XMLPage(file, filetype=mode) + for tag in doc.tags: + if tag not in self.class_mapping: + self.class_mapping[tag] = self.num_classes + self.num_classes += 1 + except KrakenInputException as e: + files.pop(file) + logger.warning(e) + continue + for file in files: + try: + doc = XMLPage(file, filetype=mode) + if level == 'baselines': + if not ro_id: + ro_id = 'line_implicit' + order = doc.get_sorted_lines(ro_id) + elif level == 'regions': + if not ro_id: + ro_id = 'region_implicit' + order = doc.get_sorted_regions(ro_id) + else: + raise ValueError(f'Invalid RO type {level}') + # traverse RO and substitute features. 
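+                    # Each line contributes num_classes + 6 features: a one-hot
+                    # encoding of its type followed by the normalized center,
+                    # start, and end point coordinates of its baseline.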
+                    w, h = Image.open(doc.imagename).size
+                    sorted_lines = []
+                    for line in order:
+                        line_coords = np.array(line.baseline) / (w, h)
+                        line_center = np.mean(line_coords, axis=0)
+                        cl = torch.zeros(self.num_classes, dtype=torch.float)
+                        # if class is not in class mapping default to None class (idx 0)
+                        cl[self.class_mapping.get(line.tags['type'], 0)] = 1
+                        line_data = {'type': line.tags['type'],
+                                     'features': torch.cat((cl,  # one-hot encoded line type
+                                                            torch.tensor(line_center, dtype=torch.float),  # line center
+                                                            torch.tensor(line_coords[0, :], dtype=torch.float),  # start point coord
+                                                            torch.tensor(line_coords[-1, :], dtype=torch.float),  # end point coord
+                                                            ))
+                                     }
+                        sorted_lines.append(line_data)
+                    if len(sorted_lines) > 1:
+                        self.data.append(sorted_lines)
+                        self._num_pairs += int(factorial(len(sorted_lines))/factorial(len(sorted_lines)-2))
+                    else:
+                        logger.info(f'Page {doc} has less than 2 lines. Skipping')
+                except KrakenInputException as e:
+                    logger.warning(e)
+                    continue
+        else:
+            raise Exception('invalid dataset mode')
+
+    def __getitem__(self, idx):
+        lines = []
+        while len(lines) < 2:
+            lines = self.data[torch.randint(len(self.data), (1,))[0]]
+        idx0, idx1 = 0, 0
+        while idx0 == idx1:
+            idx0, idx1 = torch.randint(len(lines), (2,))
+        x = torch.cat((lines[idx0]['features'], lines[idx1]['features']))
+        y = torch.tensor(0 if idx0 >= idx1 else 1, dtype=torch.float)
+        return {'sample': x, 'target': y}
+
+    def get_feature_dim(self):
+        return 2 * self.num_classes + 12
+
+    def __len__(self):
+        return self._num_pairs
+
+
+class PageWiseROSet(Dataset):
+    """
+    Dataset for training a reading order determination model.
+
+    Returns all lines from the same page.
+    """
+    def __init__(self, files: Sequence[Union[PathLike, str]] = None,
+                 mode: Literal['alto', 'page', 'xml'] = 'xml',
+                 level: Literal['regions', 'baselines'] = 'baselines',
+                 ro_id: Optional[str] = None,
+                 class_mapping: Optional[Dict[str, int]] = None):
+        """
+        Samples all lines/regions of a page from XML files for training a
+        reading order model.
+
+        Args:
+            mode: Either alto, page, or xml (autodetects ALTO/PageXML). The
+                  baseline paths and image data are retrieved from the XML
+                  files.
+            ro_id: ID of the reading order to sample from. Defaults to
+                   `line_implicit`/`region_implicit`.
+        """
+        super().__init__()
+
+        self.failed_samples = []
+        if class_mapping:
+            self.class_mapping = class_mapping
+            self.num_classes = len(class_mapping) + 1
+        else:
+            self.num_classes = 1
+            self.class_mapping = {}
+
+        self.data = []
+
+        if mode in ['alto', 'page', 'xml']:
+            files = list(files)
+            # first pass: build the class mapping and drop unparseable files
+            for file in list(files):
+                try:
+                    doc = XMLPage(file, filetype=mode)
+                    for tag in doc.tags:
+                        if tag not in self.class_mapping:
+                            self.class_mapping[tag] = self.num_classes
+                            self.num_classes += 1
+                except KrakenInputException as e:
+                    files.remove(file)
+                    logger.warning(e)
+                    continue
+            for file in files:
+                try:
+                    doc = XMLPage(file, filetype=mode)
+                    if level == 'baselines':
+                        if not ro_id:
+                            ro_id = 'line_implicit'
+                        order = doc.get_sorted_lines(ro_id)
+                    elif level == 'regions':
+                        if not ro_id:
+                            ro_id = 'region_implicit'
+                        order = doc.get_sorted_regions(ro_id)
+                    else:
+                        raise ValueError(f'Invalid RO type {level}')
+                    # traverse RO and substitute features.
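+                    # Feature construction mirrors PairWiseROSet above; pages
+                    # are kept whole here so that validation can decode a
+                    # complete reading order per page.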
+                    w, h = Image.open(doc.imagename).size
+                    sorted_lines = []
+                    for line in order:
+                        line_coords = np.array(line.baseline) / (w, h)
+                        line_center = np.mean(line_coords, axis=0)
+                        cl = torch.zeros(self.num_classes, dtype=torch.float)
+                        # if class is not in class mapping default to None class (idx 0)
+                        cl[self.class_mapping.get(line.tags['type'], 0)] = 1
+                        line_data = {'type': line.tags['type'],
+                                     'features': torch.cat((cl,  # one-hot encoded line type
+                                                            torch.tensor(line_center, dtype=torch.float),  # line center
+                                                            torch.tensor(line_coords[0, :], dtype=torch.float),  # start point coord
+                                                            torch.tensor(line_coords[-1, :], dtype=torch.float),  # end point coord
+                                                            ))
+                                     }
+                        sorted_lines.append(line_data)
+                    if len(sorted_lines) > 1:
+                        self.data.append(sorted_lines)
+                    else:
+                        logger.info(f'Page {doc} has less than 2 lines. Skipping')
+                except KrakenInputException as e:
+                    logger.warning(e)
+                    continue
+        else:
+            raise Exception('invalid dataset mode')
+
+    def __getitem__(self, idx):
+        xs = []
+        ys = []
+        for i in range(len(self.data[idx])):
+            for j in range(len(self.data[idx])):
+                if i == j and len(self.data[idx]) != 1:
+                    continue
+                xs.append(torch.cat((self.data[idx][i]['features'],
+                                     self.data[idx][j]['features'])))
+                ys.append(torch.tensor(0 if i >= j else 1, dtype=torch.float))
+        return {'sample': torch.stack(xs), 'target': torch.stack(ys), 'num_lines': len(self.data[idx])}
+
+    def get_feature_dim(self):
+        return 2 * self.num_classes + 12
+
+    def __len__(self):
+        return len(self.data)
diff --git a/kraken/lib/dataset/segmentation.py b/kraken/lib/dataset/segmentation.py
index 075507cbf..0d248800e 100644
--- a/kraken/lib/dataset/segmentation.py
+++ b/kraken/lib/dataset/segmentation.py
@@ -33,7 +33,8 @@
 from skimage.draw import polygon
 
-from kraken.lib.xml import parse_alto, parse_page, parse_xml
+from kraken.containers import Segmentation
+from kraken.lib.xml import XMLPage
 
 from kraken.lib.exceptions import KrakenInputException
 
@@ -48,24 +49,20 @@ class BaselineSet(Dataset):
     """
     Dataset for training a baseline/region segmentation model.
     """
-    def __init__(self, imgs: Sequence[Union[PathLike, str]] = None,
-                 suffix: str = '.path',
+    def __init__(self,
                  line_width: int = 4,
                  padding: Tuple[int, int, int, int] = (0, 0, 0, 0),
                  im_transforms: Callable[[Any], torch.Tensor] = transforms.Compose([]),
-                 mode: Optional[Literal['path', 'alto', 'page', 'xml']] = 'path',
+                 mode: Optional[Literal['alto', 'page', 'xml']] = 'xml',
                  augmentation: bool = False,
                  valid_baselines: Sequence[str] = None,
                  merge_baselines: Dict[str, Sequence[str]] = None,
                  valid_regions: Sequence[str] = None,
                  merge_regions: Dict[str, Sequence[str]] = None):
         """
-        Reads a list of image-json pairs and creates a data set.
+        Creates a dataset for a text-line and region segmentation model.
 
         Args:
-            imgs:
-            suffix: Suffix to attach to image base name to load JSON files
-                    from.
             line_width: Height of the baseline in the scaled input.
             padding: Tuple of ints containing the left/right, top/bottom
                      padding of the input images.
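With the path/JSON loading mode removed in the hunks below, a `BaselineSet` is populated exclusively through its `add` method. A minimal sketch of the intended call pattern (assuming `files` holds a list of ALTO/PageXML paths; `XMLPage.to_container()` as introduced elsewhere in this patch):

    from kraken.lib.dataset import BaselineSet
    from kraken.lib.xml import XMLPage

    ds = BaselineSet(line_width=8)
    for f in files:
        # to_container() yields a Segmentation container of type 'baselines'
        ds.add(XMLPage(f).to_container())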
@@ -90,7 +87,6 @@ def __init__(self, imgs: Sequence[Union[PathLike, str]] = None, self.mode = mode self.im_mode = '1' self.pad = padding - self.aug = None self.targets = [] # n-th entry contains semantic of n-th class self.class_mapping = {'aux': {'_start_separator': 0, '_end_separator': 1}, 'baselines': {}, 'regions': {}} @@ -102,60 +98,8 @@ def __init__(self, imgs: Sequence[Union[PathLike, str]] = None, self.mreg_dict = merge_regions if merge_regions is not None else {} self.valid_baselines = valid_baselines self.valid_regions = valid_regions - if mode in ['alto', 'page', 'xml']: - if mode == 'alto': - fn = parse_alto - elif mode == 'page': - fn = parse_page - elif mode == 'xml': - fn = parse_xml - im_paths = [] - self.targets = [] - for img in imgs: - try: - data = fn(img) - im_paths.append(data['image']) - lines = defaultdict(list) - for line in data['lines']: - if valid_baselines is None or set(line['tags'].values()).intersection(valid_baselines): - tags = set(line['tags'].values()).intersection(valid_baselines) if valid_baselines else line['tags'].values() - for tag in tags: - lines[self.mbl_dict.get(tag, tag)].append(line['baseline']) - self.class_stats['baselines'][self.mbl_dict.get(tag, tag)] += 1 - regions = defaultdict(list) - for k, v in data['regions'].items(): - if valid_regions is None or k in valid_regions: - regions[self.mreg_dict.get(k, k)].extend(v) - self.class_stats['regions'][self.mreg_dict.get(k, k)] += len(v) - data['regions'] = regions - self.targets.append({'baselines': lines, 'regions': data['regions']}) - except KrakenInputException as e: - logger.warning(e) - continue - # get line types - imgs = im_paths - # calculate class mapping - line_types = set() - region_types = set() - for page in self.targets: - for line_type in page['baselines'].keys(): - line_types.add(line_type) - for reg_type in page['regions'].keys(): - region_types.add(reg_type) - idx = -1 - for idx, line_type in enumerate(line_types): - self.class_mapping['baselines'][line_type] = idx + self.num_classes - self.num_classes += idx + 1 - idx = -1 - for idx, reg_type in enumerate(region_types): - self.class_mapping['regions'][reg_type] = idx + self.num_classes - self.num_classes += idx + 1 - elif mode == 'path': - pass - elif mode is None: - imgs = [] - else: - raise Exception('invalid dataset mode') + + self.aug = None if augmentation: import cv2 cv2.setNumThreads(0) @@ -179,37 +123,26 @@ def __init__(self, imgs: Sequence[Union[PathLike, str]] = None, ], p=0.2), HueSaturationValue(hue_shift_limit=20, sat_shift_limit=0.1, val_shift_limit=0.1, p=0.3), ], p=0.5) - self.imgs = imgs self.line_width = line_width self.transforms = im_transforms self.seg_type = None - def add(self, - image: Union[PathLike, str, Image.Image], - baselines: List[List[List[Tuple[int, int]]]] = None, - regions: Dict[str, List[List[Tuple[int, int]]]] = None, - *args, - **kwargs): + def add(self, doc: Union[Segmentation, XMLPage]): """ Adds a page to the dataset. Args: - im: Path to the whole page image - baseline: A list containing dicts with a list of coordinates - and tags [{'baseline': [[x0, y0], ..., - [xn, yn]], 'tags': ('script_type',)}, ...] - regions: A dict containing list of lists of coordinates - {'region_type_0': [[x0, y0], ..., [xn, yn]]], - 'region_type_1': ...}. + doc: Either a Segmentation container class or an XMLPage. """ - if self.mode: - raise Exception(f'The `add` method is incompatible with dataset mode {self.mode}') + if doc.type != 'baselines': + raise ValueError(f'{doc} is of type {doc.type}. 
Expected "baselines".') + baselines_ = defaultdict(list) - for line in baselines: - if self.valid_baselines is None or set(line['tags'].values()).intersection(self.valid_baselines): - tags = set(line['tags'].values()).intersection(self.valid_baselines) if self.valid_baselines else line['tags'].values() + for line in doc.lines: + if self.valid_baselines is None or set(line.tags.values()).intersection(self.valid_baselines): + tags = set(line.tags.values()).intersection(self.valid_baselines) if self.valid_baselines else line.tags.values() for tag in tags: - baselines_[tag].append(line['baseline']) + baselines_[tag].append(line.baseline) self.class_stats['baselines'][tag] += 1 if tag not in self.class_mapping['baselines']: @@ -217,7 +150,7 @@ def add(self, self.class_mapping['baselines'][tag] = self.num_classes - 1 regions_ = defaultdict(list) - for k, v in regions.items(): + for k, v in doc.regions.items(): reg_type = self.mreg_dict.get(k, k) if self.valid_regions is None or reg_type in self.valid_regions: regions_[reg_type].extend(v) @@ -231,11 +164,7 @@ def add(self, def __getitem__(self, idx): im = self.imgs[idx] - if self.mode != 'path': - target = self.targets[idx] - else: - with open('{}.path'.format(path.splitext(im)[0]), 'r') as fp: - target = json.load(fp) + target = self.targets[idx] if not isinstance(im, Image.Image): try: logger.debug(f'Attempting to load {im}') diff --git a/kraken/lib/default_specs.py b/kraken/lib/default_specs.py index 4830ee1fb..af08fd1e5 100644 --- a/kraken/lib/default_specs.py +++ b/kraken/lib/default_specs.py @@ -19,6 +19,30 @@ SEGMENTATION_SPEC = '[1,1800,0,3 Cr7,7,64,2,2 Gn32 Cr3,3,128,2,2 Gn32 Cr3,3,128 Gn32 Cr3,3,256 Gn32 Cr3,3,256 Gn32 Lbx32 Lby32 Cr1,1,32 Gn32 Lby32 Lbx32]' # NOQA RECOGNITION_SPEC = '[1,120,0,1 Cr3,13,32 Do0.1,2 Mp2,2 Cr3,13,32 Do0.1,2 Mp2,2 Cr3,9,64 Do0.1,2 Mp2,2 Cr3,9,64 Do0.1,2 S1(1x0)1,3 Lbx200 Do0.1,2 Lbx200 Do0.1,2 Lbx200 Do]' # NOQA +READING_ORDER_HYPER_PARAMS = {'lrate': 0.001, + 'freq': 1.0, + 'batch_size': 15000, + 'epochs': 3000, + 'lag': 300, + 'min_delta': None, + 'quit': 'early', + 'optimizer': 'Adam', + 'momentum': 0.9, + 'weight_decay': 0.01, + 'schedule': 'cosine', + 'completed_epochs': 0, + # lr scheduler params + # step/exp decay + 'step_size': 10, + 'gamma': 0.1, + # reduce on plateau + 'rop_factor': 0.1, + 'rop_patience': 5, + # cosine + 'cos_t_max': 100, + 'warmup': 0, + } + RECOGNITION_PRETRAIN_HYPER_PARAMS = {'pad': 16, 'freq': 1.0, 'batch_size': 64, @@ -84,7 +108,7 @@ SEGMENTATION_HYPER_PARAMS = {'line_width': 8, 'padding': (0, 0), 'freq': 1.0, - 'quit': 'dumb', + 'quit': 'fixed', 'epochs': 50, 'min_epochs': 0, 'lag': 10, diff --git a/kraken/lib/models.py b/kraken/lib/models.py index c8e08dcfb..ac1219b23 100644 --- a/kraken/lib/models.py +++ b/kraken/lib/models.py @@ -217,6 +217,6 @@ def validate_hyper_parameters(hyper_params): """ Validate some model's hyper parameters and modify them in place if need be. 
""" - if (hyper_params['quit'] == 'dumb' and hyper_params['completed_epochs'] >= hyper_params['epochs']): + if (hyper_params['quit'] == 'fixed' and hyper_params['completed_epochs'] >= hyper_params['epochs']): logger.warning('Maximum epochs reached (might be loaded from given model), starting again from 0.') hyper_params['completed_epochs'] = 0 diff --git a/kraken/lib/pretrain/model.py b/kraken/lib/pretrain/model.py index 3b5087e96..61dc04ab4 100644 --- a/kraken/lib/pretrain/model.py +++ b/kraken/lib/pretrain/model.py @@ -45,7 +45,7 @@ from pytorch_lightning.utilities.memory import is_oom_error, garbage_collection_cuda from kraken.lib import vgsl, default_specs, layers -from kraken.lib.xml import preparse_xml_data +from kraken.lib.xml import XMLPage from kraken.lib.codec import PytorchCodec from kraken.lib.dataset import (ArrowIPCRecognitionDataset, GroundTruthDataset, PolygonGTDataset, @@ -108,10 +108,10 @@ def __init__(self, valid_norm = True if format_type in ['xml', 'page', 'alto']: logger.info(f'Parsing {len(training_data)} XML files for training data') - training_data = preparse_xml_data(training_data, format_type, repolygonize) + training_data = [{'page': XMLPage(file, format_type).to_container()} for file in training_data] if evaluation_data: logger.info(f'Parsing {len(evaluation_data)} XML files for validation data') - evaluation_data = preparse_xml_data(evaluation_data, format_type, repolygonize) + evaluation_data = [{'page': XMLPage(file, format_type).to_container()} for file in evaluation_data] if binary_dataset_split: logger.warning('Internal binary dataset splits are enabled but using non-binary dataset files. Will be ignored.') binary_dataset_split = False @@ -144,7 +144,7 @@ def __init__(self, valid_norm = True # format_type is None. Determine training type from length of training data entry elif not format_type: - if len(training_data[0]) >= 4: + if training_data[0].type == 'baselines': DatasetClass = PolygonGTDataset valid_norm = False else: @@ -156,6 +156,22 @@ def __init__(self, if binary_dataset_split: logger.warning('Internal binary dataset splits are enabled but using non-binary dataset files. 
Will be ignored.') binary_dataset_split = False + samples = [] + for sample in training_data: + if isinstance(sample, Segmentation): + samples.append({'page': sample}) + else: + samples.append({'line': sample}) + training_data = samples + if evaluation_data: + samples = [] + for sample in evaluation_data: + if isinstance(sample, Segmentation): + samples.append({'page': sample}) + else: + samples.append({'line': sample}) + evaluation_data = samples + else: raise ValueError(f'format_type {format_type} not in [alto, page, xml, path, binary].') @@ -203,18 +219,12 @@ def _build_dataset(self, skip_empty_lines=False, **kwargs) - if (self.hparams.num_workers and self.hparams.num_workers > 1) and self.hparams.format_type != 'binary': - with Pool(processes=self.hparams.num_workers) as pool: - for im in pool.imap_unordered(partial(_star_fun, dataset.parse), training_data, 5): - logger.debug(f'Adding sample {im} to training set') - if im: - dataset.add(**im) - else: - for im in training_data: - try: - dataset.add(**im) - except KrakenInputException as e: - logger.warning(str(e)) + for sample in training_data: + try: + dataset.add(**sample) + except KrakenInputException as e: + logger.warning(str(e)) + return dataset def train_dataloader(self): diff --git a/kraken/lib/progress.py b/kraken/lib/progress.py index bb2e30b3e..344864d01 100644 --- a/kraken/lib/progress.py +++ b/kraken/lib/progress.py @@ -128,7 +128,7 @@ def _init_progress(self, trainer): def _get_train_description(self, current_epoch: int) -> str: return f"stage {current_epoch}/" \ - f"{self.trainer.max_epochs if self.trainer.model.hparams['quit'] == 'fixed' else '∞'}" + f"{self.trainer.max_epochs if self.trainer.model.hparams.hyper_params['quit'] == 'fixed' else '∞'}" @dataclass class RichProgressBarTheme: @@ -155,4 +155,3 @@ class RichProgressBarTheme: time: Union[str, Style] = DEFAULT_STYLES['progress.elapsed'] processing_speed: Union[str, Style] = DEFAULT_STYLES['progress.data.speed'] metrics: Union[str, Style] = DEFAULT_STYLES['progress.description'] - diff --git a/kraken/lib/ro/__init__.py b/kraken/lib/ro/__init__.py new file mode 100644 index 000000000..4e370b855 --- /dev/null +++ b/kraken/lib/ro/__init__.py @@ -0,0 +1,19 @@ +# +# Copyright 2023 Benjamin Kiessling +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +""" +Tools for trainable reading order. +""" + +from .model import ROModel # NOQA diff --git a/kraken/lib/ro/layers.py b/kraken/lib/ro/layers.py new file mode 100644 index 000000000..a18f3de1e --- /dev/null +++ b/kraken/lib/ro/layers.py @@ -0,0 +1,75 @@ +""" +Layers for VGSL models +""" +import torch +from torch import nn + +from typing import Tuple + +# all tensors are ordered NCHW, the "feature" dimension is C, so the output of +# an LSTM will be put into C same as the filters of a CNN. + +__all__ = ['MLP'] + + +class MLP(nn.Module): + """ + A simple 2 layer MLP for reading order determination. 
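+
+    The network maps the concatenated feature vectors of a line pair to a
+    single logit estimating whether the first line precedes the second.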
+ """ + def __init__(self, feature_size: int, hidden_size: int): + super(MLP, self).__init__() + self.fc1 = nn.Linear(feature_size, hidden_size) + self.relu = nn.ReLU() + self.fc2 = nn.Linear(hidden_size, 1) + self.feature_size = feature_size + self.hidden_size = hidden_size + self.class_mapping = None + + def forward(self, x): + x = self.fc1(x) + x = self.relu(x) + return self.fc2(x) + + def get_shape(self, input: Tuple[int, int, int, int]) -> Tuple[int, int, int, int]: + """ + Calculates the output shape from input 4D tuple NCHW. + """ + return input + + def get_spec(self, name) -> "VGSLBlock": + """ + Generates a VGSL spec block from the layer instance. + """ + return f'[1,0,0,1 RO{{{name}}}{self.feature_size},{self.hidden_size}]' + + def deserialize(self, name: str, spec) -> None: + """ + Sets the weights of an initialized module from a CoreML protobuf spec. + """ + # extract 1st linear projection parameters + lin = [x for x in spec.neuralNetwork.layers if x.name == '{}_mlp_lin_0'.format(name)][0].innerProduct + weights = torch.Tensor(lin.weights.floatValue).resize_as_(self.fc1.weight.data) + bias = torch.Tensor(lin.bias.floatValue) + self.fc1.weight = torch.nn.Parameter(weights) + self.fc1.bias = torch.nn.Parameter(bias) + # extract 2nd linear projection parameters + lin = [x for x in spec.neuralNetwork.layers if x.name == '{}_mlp_lin_1'.format(name)][0].innerProduct + weights = torch.Tensor(lin.weights.floatValue).resize_as_(self.fc2.weight.data) + bias = torch.Tensor(lin.bias.floatValue) + self.fc2.weight = torch.nn.Parameter(weights) + self.fc2.bias = torch.nn.Parameter(bias) + + def serialize(self, name: str, input: str, builder): + """ + Serializes the module using a NeuralNetworkBuilder. + """ + builder.add_inner_product(f'{name}_mlp_lin_0', self.fc1.weight.data.numpy(), + self.fc1.bias.data.numpy(), + self.feature_size, self.hidden_size, + has_bias=True, input_name=input, output_name=f'{name}_mlp_lin_0') + builder.add_activation(f'{name}_mlp_lin_0_relu', 'RELU', f'{name}_mlp_lin_0', f'{name}_mlp_lin_0_relu') + builder.add_inner_product(f'{name}_mlp_lin_1', self.fc2.weight.data.numpy(), + self.fc2.bias.data.numpy(), + self.hidden_size, 1, + has_bias=True, input_name=f'{name}_mlp_lin_0_relu', output_name=f'{name}_mlp_lin_1') + return name diff --git a/kraken/lib/ro/model.py b/kraken/lib/ro/model.py new file mode 100644 index 000000000..4fa957144 --- /dev/null +++ b/kraken/lib/ro/model.py @@ -0,0 +1,248 @@ +# +# Copyright 2023 Benjamin Kiessling +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +""" +Pytorch-lightning modules for reading order training. 
+
+Adapted from:
+"""
+import re
+import math
+import torch
+import logging
+import numpy as np
+import torch.nn.functional as F
+import pytorch_lightning as pl
+
+from os import PathLike
+from torch.optim import lr_scheduler
+from dataclasses import dataclass, field
+from torch.nn import Module
+from typing import Dict, Optional, Sequence, Union, Any, Literal, List
+
+from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
+
+from kraken.lib import vgsl, default_specs, layers
+from kraken.lib.dataset import PairWiseROSet, PageWiseROSet
+from kraken.lib.train import _configure_optimizer_and_lr_scheduler
+from kraken.lib.segmentation import _greedy_order_decoder
+from kraken.lib.ro.layers import MLP
+
+from torch.utils.data import DataLoader, random_split, Subset
+
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class DummyVGSLModel:
+    hyper_params: Dict[str, int] = field(default_factory=dict)
+    user_metadata: Dict[str, List] = field(default_factory=dict)
+    one_channel_mode: Literal['1', 'L'] = '1'
+    ptl_module: Module = None
+    model_type: str = 'unknown'
+
+    def __post_init__(self):
+        self.hyper_params: Dict[str, int] = {'completed_epochs': 0}
+        self.user_metadata: Dict[str, List] = {'accuracy': [], 'metrics': []}
+
+    def save_model(self, filename):
+        self.ptl_module.save_checkpoint(filename)
+
+
+def spearman_footrule_distance(s, t):
+    return (s - t).abs().sum() / (0.5 * (len(s) ** 2 - (len(s) % 2)))
+
+
+class ROModel(pl.LightningModule):
+    def __init__(self,
+                 hyper_params: Dict[str, Any] = None,
+                 output: str = 'model',
+                 training_data: Union[Sequence[Union[PathLike, str]], Sequence[Dict[str, Any]]] = None,
+                 evaluation_data: Optional[Union[Sequence[Union[PathLike, str]], Sequence[Dict[str, Any]]]] = None,
+                 partition: Optional[float] = 0.9,
+                 num_workers: int = 1,
+                 format_type: Literal['alto', 'page', 'xml'] = 'xml',
+                 load_hyper_parameters: bool = False,
+                 level: Literal['baselines', 'regions'] = 'baselines',
+                 reading_order: Optional[str] = None):
+        """
+        A LightningModule encapsulating the training setup for a reading
+        order model.
+
+        Setup parameters (load, training_data, evaluation_data, ....) are
+        named, model hyperparameters (everything in
+        `kraken.lib.default_specs.READING_ORDER_HYPER_PARAMS`) are in the
+        `hyper_params` argument.
+
+        Args:
+            hyper_params (dict): Hyperparameter dictionary containing all fields
+                                 from
+                                 kraken.lib.default_specs.READING_ORDER_HYPER_PARAMS
+            **kwargs: Setup parameters, i.e. CLI parameters of the train() command.
+        """
+        super().__init__()
+        self.hyper_params = default_specs.READING_ORDER_HYPER_PARAMS.copy()
+        if hyper_params:
+            self.hyper_params.update(hyper_params)
+
+        if not evaluation_data:
+            np.random.shuffle(training_data)
+            split_idx = int(partition * len(training_data))
+            training_data, evaluation_data = training_data[:split_idx], training_data[split_idx:]
+        train_set = PairWiseROSet(training_data,
+                                  mode=format_type,
+                                  level=level,
+                                  ro_id=reading_order)
+        self.train_set = Subset(train_set, range(len(train_set)))
+        self.class_mapping = train_set.class_mapping
+        val_set = PageWiseROSet(evaluation_data,
+                                mode=format_type,
+                                class_mapping=train_set.class_mapping,
+                                level=level,
+                                ro_id=reading_order)
+        self.val_set = Subset(val_set, range(len(val_set)))
+
+        if len(self.train_set) == 0 or len(self.val_set) == 0:
+            raise ValueError('No valid training data was provided to the train '
+                             'command. 
Please add valid XML, line, or binary data.') + + logger.info(f'Training set {len(self.train_set)} lines, validation set ' + f'{len(self.val_set)} lines') + + self.output = output + self.criterion = torch.nn.BCEWithLogitsLoss() + + self.num_workers = num_workers + + self.best_epoch = -1 + self.best_metric = torch.inf + + logger.info(f'Creating new RO model') + self.ro_net = MLP(train_set.get_feature_dim(), train_set.get_feature_dim() * 2) + + if 'file_system' in torch.multiprocessing.get_all_sharing_strategies(): + logger.debug('Setting multiprocessing tensor sharing strategy to file_system') + torch.multiprocessing.set_sharing_strategy('file_system') + + self.nn = DummyVGSLModel(ptl_module=self) + + self.val_losses = [] + self.val_spearman = [] + + self.save_hyperparameters() + + def forward(self, x): + return F.sigmoid(self.ro_net(x)) + + def validation_step(self, batch, batch_idx): + xs, ys, num_lines = batch['sample'], batch['target'], batch['num_lines'] + logits = self.ro_net(xs).squeeze() + yhat = F.sigmoid(logits) + order = torch.zeros((num_lines, num_lines)) + idx = 0 + for i in range(num_lines): + for j in range(num_lines): + if i != j: + order[i, j] = yhat[idx] + idx += 1 + path = _greedy_order_decoder(order) + spearman_dist = spearman_footrule_distance(torch.tensor(range(num_lines)), path) + self.log('val_spearman', spearman_dist) + loss = self.criterion(logits, ys.squeeze()) + self.val_losses.append(loss.cpu()) + self.val_spearman.append(spearman_dist.cpu()) + + def on_validation_epoch_end(self): + val_metric = np.mean(self.val_spearman) + val_loss = np.mean(self.val_losses) + self.val_spearman.clear() + self.val_losses.clear() + + if val_metric < self.best_metric: + logger.debug(f'Updating best metric from {self.best_metric} ({self.best_epoch}) to {val_metric} ({self.current_epoch})') + self.best_epoch = self.current_epoch + self.best_metric = val_metric + logger.info(f'validation run: val_spearman {val_metric} val_loss {val_loss}') + self.log('val_spearman', val_metric, on_step=False, on_epoch=True, prog_bar=True, logger=True) + self.log('val_metric', val_metric, on_step=False, on_epoch=True, prog_bar=False, logger=True) + self.log('val_loss', val_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True) + + def training_step(self, batch, batch_idx): + x, y = batch['sample'], batch['target'] + logits = self.ro_net(x) + loss = self.criterion(logits.squeeze(), y) + self.log('loss', loss) + return loss + + def train_dataloader(self): + return DataLoader(self.train_set, + batch_size=self.hyper_params['batch_size'], + num_workers=self.num_workers, + pin_memory=True) + + def val_dataloader(self): + return DataLoader(self.val_set, + batch_size=1, + num_workers=self.num_workers, + pin_memory=True) + + def save_checkpoint(self, filename): + self.trainer.save_checkpoint(filename) + + def configure_callbacks(self): + callbacks = [] + if self.hparams.hyper_params['quit'] == 'early': + callbacks.append(EarlyStopping(monitor='val_metric', + mode='min', + patience=self.hparams.hyper_params['lag'], + stopping_threshold=0.0)) + if self.hparams.hyper_params['pl_logger']: + callbacks.append(LearningRateMonitor(logging_interval='step')) + return callbacks + + # configuration of optimizers and learning rate schedulers + # -------------------------------------------------------- + # + # All schedulers are created internally with a frequency of step to enable + # batch-wise learning rate warmup. In lr_scheduler_step() calls to the + # scheduler are then only performed at the end of the epoch. 
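+    #
+    # A worked example of the warmup interaction (assuming warmup=1000 and
+    # lrate=1e-3 in the hyperparameters): at global_step 99 optimizer_step()
+    # scales the rate to min(1.0, 100/1000) * 1e-3 = 1e-4; from step 999 on
+    # the full 1e-3 applies and lr_scheduler_step() lets the scheduler run.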
+ def configure_optimizers(self): + return _configure_optimizer_and_lr_scheduler(self.hparams.hyper_params, + self.ro_net.parameters(), + len_train_set=len(self.train_set), + loss_tracking_mode='min') + + def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure): + # update params + optimizer.step(closure=optimizer_closure) + + # linear warmup between 0 and the initial learning rate `lrate` in `warmup` + # steps. + if self.hparams.hyper_params['warmup'] and self.trainer.global_step < self.hparams.hyper_params['warmup']: + lr_scale = min(1.0, float(self.trainer.global_step + 1) / self.hparams.hyper_params['warmup']) + for pg in optimizer.param_groups: + pg["lr"] = lr_scale * self.hparams.hyper_params['lrate'] + + def lr_scheduler_step(self, scheduler, metric): + if not self.hparams.hyper_params['warmup'] or self.trainer.global_step >= self.hparams.hyper_params['warmup']: + # step OneCycleLR each batch if not in warmup phase + if isinstance(scheduler, lr_scheduler.OneCycleLR): + scheduler.step() + # step every other scheduler epoch-wise + elif self.trainer.is_last_batch: + if metric is None: + scheduler.step() + else: + scheduler.step(metric) + diff --git a/kraken/lib/ro/util.py b/kraken/lib/ro/util.py new file mode 100644 index 000000000..57fea354b --- /dev/null +++ b/kraken/lib/ro/util.py @@ -0,0 +1,66 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Sequence, Union + +import torch +import random +import numpy as np + + +def positive_integers_with_sum(n, total): + ls = [0] + rv = [] + while len(ls) < n: + c = random.randint(0, total) + ls.append(c) + ls = sorted(ls) + ls.append(total) + for i in range(1, len(ls)): + rv.append(ls[i] - ls[i-1]) + return rv + + +def compute_masks(mask_prob: int, + mask_width: int, + num_neg_samples: int, + seq_lens: Union[torch.Tensor, Sequence[int]]): + """ + Samples num_mask non-overlapping random masks of length mask_width in + sequence of length seq_len. + + Args: + mask_prob: Probability of each individual token being chosen as start + of a masked sequence. Overall number of masks num_masks is + mask_prob * sum(seq_lens) / mask_width. + mask_width: width of each mask + num_neg_samples: Number of samples from unmasked sequence parts (gets + multiplied by num_mask) + seq_lens: sequence lengths + + Returns: + An index array containing 1 for masked bits, 2 for negative samples, + the number of masks, and the actual number of negative samples. 
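+
+    For example, with seq_lens=torch.tensor([10, 10]), mask_prob=0.2,
+    mask_width=2 and num_neg_samples=1: num_masks works out to
+    int(0.2 * 20 // 2) = 2, so two masked spans and two negative-sample
+    spans are laid out over the 20 positions.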
+ """ + mask_samples = np.zeros(sum(seq_lens)) + num_masks = int(mask_prob * sum(seq_lens.numpy()) // mask_width) + num_neg_samples = num_masks * num_neg_samples + num_masks += num_neg_samples + + indices = [x+mask_width for x in positive_integers_with_sum(num_masks, sum(seq_lens)-num_masks*mask_width)] + start = 0 + mask_slices = [] + for i in indices: + i_start = random.randint(start, i+start-mask_width) + mask_slices.append(slice(i_start, i_start+mask_width)) + start += i + + neg_idx = random.sample(range(len(mask_slices)), num_neg_samples) + neg_slices = [mask_slices.pop(idx) for idx in sorted(neg_idx, reverse=True)] + + mask_samples[np.r_[tuple(mask_slices)]] = 1 + mask_samples[np.r_[tuple(neg_slices)]] = 2 + + return mask_samples, num_masks - num_neg_samples, num_neg_samples diff --git a/kraken/lib/segmentation.py b/kraken/lib/segmentation.py index 6f6459ac5..215c4e5d4 100644 --- a/kraken/lib/segmentation.py +++ b/kraken/lib/segmentation.py @@ -16,9 +16,11 @@ Processing for baseline segmenter output """ import PIL +import torch import logging import numpy as np import shapely.geometry as geom +import torch.nn.functional as F from collections import defaultdict @@ -33,12 +35,12 @@ from skimage import draw, filters from skimage.graph import MCP_Connect -from skimage.filters import apply_hysteresis_threshold, sobel +from skimage.filters import sobel from skimage.measure import approximate_polygon, subdivide_polygon, regionprops, label from skimage.morphology import skeletonize from skimage.transform import PiecewiseAffineTransform, SimilarityTransform, AffineTransform, warp -from typing import List, Tuple, Union, Dict, Any, Sequence, Optional +from typing import List, Tuple, Union, Dict, Any, Sequence, Optional, Literal from kraken.lib import default_specs from kraken.lib.exceptions import KrakenInputException @@ -50,7 +52,7 @@ logger = logging.getLogger('kraken') __all__ = ['reading_order', - 'denoising_hysteresis_thresh', + 'neural_reading_order', 'vectorize_lines', 'calculate_polygonal_environment', 'polygonal_reading_order', @@ -60,7 +62,7 @@ 'extract_polygons'] -def reading_order(lines: Sequence[Tuple[slice, slice]], text_direction: str = 'lr') -> np.ndarray: +def reading_order(lines: Sequence[Tuple[slice, slice]], text_direction: Literal['lr', 'rl'] = 'lr') -> np.ndarray: """Given the list of lines (a list of 2D slices), computes the partial reading order. The output is a binary 2D array such that order[i,j] is true if line i comes before line j @@ -131,11 +133,6 @@ def _visit(k): return L -def denoising_hysteresis_thresh(im, low, high, sigma): - im = gaussian_filter(im, sigma) - return apply_hysteresis_threshold(im, low, high) - - def moore_neighborhood(current, backtrack): operations = np.array([[-1, 0], [-1, 1], [0, 1], [1, 1], [1, 0], [1, -1], [0, -1], [-1, -1]]) @@ -740,9 +737,9 @@ def calculate_polygonal_environment(im: PIL.Image.Image = None, return polygons -def polygonal_reading_order(lines: Sequence[Tuple[List[Tuple[int, int]], List[Tuple[int, int]]]], - text_direction: str = 'lr', - regions: Optional[Sequence[List[Tuple[int, int]]]] = None) -> Sequence[Tuple[List[Tuple[int, int]], List[Tuple[int, int]]]]: +def polygonal_reading_order(lines: Sequence[Dict], + text_direction: Literal['lr', 'rl'] = 'lr', + regions: Optional[Sequence[geom.Polygon]] = None) -> Sequence[int]: """ Given a list of baselines and regions, calculates the correct reading order and applies it to the input. 
@@ -755,19 +752,19 @@ def polygonal_reading_order(lines: Sequence[Tuple[List[Tuple[int, int]], List[Tu
             Can be 'lr' or 'rl'
 
     Returns:
-        A reordered input.
+        The indices of the ordered input.
     """
+    lines = [(line['tags']['type'], line['baseline'], line['boundary']) for line in lines]
+
     bounds = []
-    if regions is not None:
-        r = [geom.Polygon(reg) for reg in regions]
-    else:
-        r = []
-    region_lines = [[] for _ in range(len(r))]
+    if regions is None:
+        regions = []
+    region_lines = [[] for _ in range(len(regions))]
     indizes = {}
     for line_idx, line in enumerate(lines):
         s_line = geom.LineString(line[1])
         in_region = False
-        for idx, reg in enumerate(r):
+        for idx, reg in enumerate(regions):
             if is_in_region(s_line, reg):
                 region_lines[idx].append((line_idx, (slice(s_line.bounds[1], s_line.bounds[3]),
                                                      slice(s_line.bounds[0], s_line.bounds[2]))))
@@ -778,8 +775,8 @@ def polygonal_reading_order(lines: Sequence[Tuple[List[Tuple[int, int]], List[Tu
                           slice(s_line.bounds[0], s_line.bounds[2])))
             indizes[line_idx] = ('line', line)
     # order everything in regions
-    intra_region_order = [[] for _ in range(len(r))]
-    for idx, reg in enumerate(r):
+    intra_region_order = [[] for _ in range(len(regions))]
+    for idx, reg in enumerate(regions):
         if len(region_lines[idx]) > 0:
             order = reading_order([x[1] for x in region_lines[idx]], text_direction)
             lsort = topsort(order)
@@ -792,13 +789,13 @@ def polygonal_reading_order(lines: Sequence[Tuple[List[Tuple[int, int]], List[Tu
         lsort = topsort(order)
         sidz = sorted(indizes.keys())
         lsort = [sidz[i] for i in lsort]
-    ordered_lines = []
+    ordered_idxs = []
     for i in lsort:
         if indizes[i][0] == 'line':
-            ordered_lines.append(indizes[i][1])
+            ordered_idxs.append(i)
         else:
-            ordered_lines.extend(lines[x] for x in intra_region_order[indizes[i][1]])
-    return ordered_lines
+            ordered_idxs.extend(intra_region_order[indizes[i][1]])
+    return ordered_idxs
 
 
 def is_in_region(line, region) -> bool:
@@ -817,6 +814,97 @@ def is_in_region(line, region) -> bool:
     return region.contains(l_obj)
 
 
+def neural_reading_order(lines: Sequence[Dict],
+                         text_direction: str = 'lr',
+                         regions: Optional[Sequence[geom.Polygon]] = None,
+                         im_size: Tuple[int, int] = None,
+                         model: 'TorchVGSLModel' = None,
+                         class_mapping: Dict[str, int] = None) -> Sequence[int]:
+    """
+    Given a list of baselines and regions, calculates the correct reading
+    order of the input.
+
+    Args:
+        lines: List of dicts containing the tags, baseline, and boundary of
+               each line.
+        model: torch module scoring the precedence relation between pairs of
+               lines.
+
+    Returns:
+        The indices of the ordered input.
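+
+    Note:
+        The model is evaluated once per ordered pair of lines; the resulting
+        pairwise probabilities are then decoded into a total order by
+        `_greedy_order_decoder` below.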
+ """ + lines = [(line['tags']['type'], line['baseline'], line['boundary']) for line in lines] + # construct all possible pairs + h, w = im_size + features = [] + for i in lines: + for j in lines: + if i == j and len(lines) != 1: + continue + num_classes = len(class_mapping) + 1 + cl_i = torch.zeros(num_classes, dtype=torch.float) + cl_j = torch.zeros(num_classes, dtype=torch.float) + cl_i[class_mapping.get(i[0], 0)] = 1 + cl_j[class_mapping.get(j[0], 0)] = 1 + line_coords_i = np.array(i[1]) / (w, h) + line_center_i = np.mean(line_coords_i, axis=0) + line_coords_j = np.array(j[1]) / (w, h) + line_center_j = np.mean(line_coords_j, axis=0) + features.append(torch.cat((cl_i, + torch.tensor(line_center_i, dtype=torch.float), # lin + torch.tensor(line_coords_i[0, :], dtype=torch.float), + torch.tensor(line_coords_i[-1, :], dtype=torch.float), + cl_j, + torch.tensor(line_center_j, dtype=torch.float), # lin + torch.tensor(line_coords_j[0, :], dtype=torch.float), + torch.tensor(line_coords_j[-1, :], dtype=torch.float)))) + features = torch.stack(features) + output = F.sigmoid(model(features)) + + order = torch.zeros((len(lines), len(lines))) + idx = 0 + for i in range(len(lines)): + for j in range(len(lines)): + if i == j and len(lines) != 1: + continue + order[i, j] = output[idx] + idx += 1 + # decode order relation matrix + path = _greedy_order_decoder(order) + return path + + +def _greedy_order_decoder(P): + """ + A greedy decoder of order-relation matrix. For each position in the + reading order we select the most probable one, then move to the next + position. Most probable for position: + + .. math:: + z^{\\star}_t = \\argmax_{(s,\\nu) \\ni z^{\\star}} + \\prod_{(s',\\nu') \\in z^\\star}{\\tilde{P}(Y=1\\mid s',s)} + \\times \\prod_{\\substack{(s'',\\nu'') \\ni z^\\star\\ + s'' \\ne s}}{\\tilde{P}(r=0\\mid s'',s)}, 1\\le t \\le n + """ + A = P + torch.finfo(torch.float).eps + N = P.shape[0] + A = (A + (1-A).T)/2 + for i in range(A.shape[0]): + A[i, i] = torch.finfo(torch.float).eps + best_path = [] + # use log(p(R\mid s',s)) to shift multiplication to sum + lP = torch.log(A) + for i in range(N): + lP[i, i] = 0 + for t in range(N): + for i in range(N): + idx = torch.argmax(lP.sum(axis=1)) + if idx not in best_path: + best_path.append(idx) + lP[idx, :] = lP[:, idx] + lP[:, idx] = 0 + break + return torch.tensor(best_path) + + def scale_regions(regions: Sequence[Tuple[List[int], List[int]]], scale: Union[float, Tuple[float, float]]) -> Sequence[Tuple[List, List]]: """ @@ -944,29 +1032,20 @@ def compute_polygon_section(baseline: Sequence[Tuple[int, int]], return tuple(o) -def extract_polygons(im: Image.Image, bounds: Dict[str, Any]) -> Image.Image: +def extract_polygons(im: Image.Image, bounds: 'kraken.containers.Segmentation') -> Image.Image: """ Yields the subimages of image im defined in the list of bounding polygons with baselines preserving order. Args: im: Input image - bounds: A list of dicts in baseline:: - - {'type': 'baselines', - 'lines': [{'baseline': [[x_0, y_0], ... [x_n, y_n]], - 'boundary': [[x_0, y_0], ... [x_n, y_n]]}, - ....] - } - - or bounding box format:: - - {'boxes': [[x_0, y_0, x_1, y_1], ...], 'text_direction': 'horizontal-lr'} + bounds: A Segmentation class containing a boundig box or baseline + segmentation. 
Yields: The extracted subimage """ - if 'type' in bounds and bounds['type'] == 'baselines': + if bounds.type == 'baselines': # select proper interpolation scheme depending on shape if im.mode == '1': order = 0 @@ -975,11 +1054,11 @@ def extract_polygons(im: Image.Image, bounds: Dict[str, Any]) -> Image.Image: order = 1 im = np.array(im) - for line in bounds['lines']: - if line['boundary'] is None: + for line in bounds.lines: + if line.boundary is None: raise KrakenInputException('No boundary given for line') - pl = np.array(line['boundary']) - baseline = np.array(line['baseline']) + pl = np.array(line.boundary) + baseline = np.array(line.baseline) c_min, c_max = int(pl[:, 0].min()), int(pl[:, 0].max()) r_min, r_max = int(pl[:, 1].min()), int(pl[:, 1].max()) @@ -1064,11 +1143,11 @@ def extract_polygons(im: Image.Image, bounds: Dict[str, Any]) -> Image.Image: i = Image.fromarray(o.astype('uint8')) yield i.crop(i.getbbox()), line else: - if bounds['text_direction'].startswith('vertical'): + if bounds.text_direction.startswith('vertical'): angle = 90 else: angle = 0 - for box in bounds['boxes']: + for box in bounds.lines: if isinstance(box, tuple): box = list(box) if (box < [0, 0, 0, 0] or box[::2] >= [im.size[0], im.size[0]] or diff --git a/kraken/lib/train.py b/kraken/lib/train.py index be9b70a31..f94568f78 100644 --- a/kraken/lib/train.py +++ b/kraken/lib/train.py @@ -33,7 +33,6 @@ from pytorch_lightning.callbacks import Callback, EarlyStopping, BaseFinetuning, LearningRateMonitor from kraken.lib import models, vgsl, default_specs, progress -from kraken.lib.xml import preparse_xml_data from kraken.lib.util import make_printable from kraken.lib.codec import PytorchCodec from kraken.lib.dataset import (ArrowIPCRecognitionDataset, BaselineSet, @@ -246,7 +245,8 @@ def __init__(self, if hyper_params: hyper_params_.update(hyper_params) - self.save_hyperparameters(hyper_params_) + self.hyper_params = hyper_params_ + self.save_hyperparameters() self.reorder = reorder self.append = append @@ -270,10 +270,10 @@ def __init__(self, valid_norm = True if format_type in ['xml', 'page', 'alto']: logger.info(f'Parsing {len(training_data)} XML files for training data') - training_data = preparse_xml_data(training_data, format_type, repolygonize) + training_data = [{'page': XMLPage(file, format_type).to_container()} for file in training_data] if evaluation_data: logger.info(f'Parsing {len(evaluation_data)} XML files for validation data') - evaluation_data = preparse_xml_data(evaluation_data, format_type, repolygonize) + evaluation_data = [{'page': XMLPage(file, format_type).to_container()} for file in evaluation_data] if binary_dataset_split: logger.warning('Internal binary dataset splits are enabled but using non-binary dataset files. Will be ignored.') binary_dataset_split = False @@ -304,9 +304,9 @@ def __init__(self, logger.info(f'Got {len(evaluation_data)} line strip images for validation data') evaluation_data = [{'image': im} for im in evaluation_data] valid_norm = True - # format_type is None. Determine training type from length of training data entry + # format_type is None. Determine training type from container class types elif not format_type: - if len(training_data[0]) >= 4: + if training_data[0].type == 'baselines': DatasetClass = PolygonGTDataset valid_norm = False else: @@ -318,6 +318,21 @@ def __init__(self, if binary_dataset_split: logger.warning('Internal binary dataset splits are enabled but using non-binary dataset files. 
Will be ignored.') binary_dataset_split = False + samples = [] + for sample in training_data: + if isinstance(sample, Segmentation): + samples.append({'page': sample}) + else: + samples.append({'line': sample}) + training_data = samples + if evaluation_data: + samples = [] + for sample in evaluation_data: + if isinstance(sample, Segmentation): + samples.append({'page': sample}) + else: + samples.append({'line': sample}) + evaluation_data = samples else: raise ValueError(f'format_type {format_type} not in [alto, page, xml, path, binary].') @@ -415,28 +430,22 @@ def _build_dataset(self, DatasetClass, training_data, **kwargs): - dataset = DatasetClass(normalization=self.hparams.normalization, - whitespace_normalization=self.hparams.normalize_whitespace, + dataset = DatasetClass(normalization=self.hparams.hyper_params['normalization'], + whitespace_normalization=self.hparams.hyper_params['normalize_whitespace'], reorder=self.reorder, im_transforms=self.transforms, - augmentation=self.hparams.augment, + augmentation=self.hparams.hyper_params['augment'], **kwargs) - if (self.num_workers and self.num_workers > 1) and self.format_type != 'binary': - with Pool(processes=self.num_workers) as pool: - for im in pool.imap_unordered(partial(_star_fun, dataset.parse), training_data, 5): - logger.debug(f'Adding sample {im} to training set') - if im: - dataset.add(**im) - else: - for im in training_data: - try: - dataset.add(**im) - except KrakenInputException as e: - logger.warning(str(e)) - if self.format_type == 'binary' and self.hparams.normalization: - logger.debug('Rebuilding dataset using unicode normalization') - dataset.rebuild_alphabet() + for sample in training_data: + try: + dataset.add(**sample) + except KrakenInputException as e: + logger.warning(str(e)) + if self.format_type == 'binary' and self.hparams.hyper_params['normalization']: + logger.debug('Rebuilding dataset using unicode normalization') + dataset.rebuild_alphabet() + return dataset def forward(self, x, seq_lens=None): @@ -591,7 +600,7 @@ def setup(self, stage: Optional[str] = None): if self.format_type != 'path' and self.nn.seg_type == 'bbox': logger.warning('Neural network has been trained on bounding box image information but training set is polygonal.') - self.nn.hyper_params = self.hparams + self.nn.hyper_params = self.hparams.hyper_params self.nn.model_type = 'recognition' if not self.nn.seg_type: @@ -605,7 +614,7 @@ def setup(self, stage: Optional[str] = None): def train_dataloader(self): return DataLoader(self.train_set, - batch_size=self.hparams.batch_size, + batch_size=self.hparams.hyper_params['batch_size'], num_workers=self.num_workers, pin_memory=True, shuffle=True, @@ -614,7 +623,7 @@ def train_dataloader(self): def val_dataloader(self): return DataLoader(self.val_set, shuffle=False, - batch_size=self.hparams.batch_size, + batch_size=self.hparams.hyper_params['batch_size'], num_workers=self.num_workers, pin_memory=True, collate_fn=collate_sequences, @@ -622,11 +631,12 @@ def val_dataloader(self): def configure_callbacks(self): callbacks = [] - if self.hparams.quit == 'early': + if self.hparams.hyper_params['quit'] == 'early': callbacks.append(EarlyStopping(monitor='val_accuracy', mode='max', - patience=self.hparams.lag, + patience=self.hparams.hyper_params['lag'], stopping_threshold=1.0)) + return callbacks # configuration of optimizers and learning rate schedulers @@ -636,7 +646,7 @@ def configure_callbacks(self): # batch-wise learning rate warmup. 
In lr_scheduler_step() calls to the
    # scheduler are then only performed at the end of the epoch.
    def configure_optimizers(self):
-        return _configure_optimizer_and_lr_scheduler(self.hparams,
+        return _configure_optimizer_and_lr_scheduler(self.hparams.hyper_params,
                                                      self.nn.nn.parameters(),
                                                      len_train_set=len(self.train_set),
                                                      loss_tracking_mode='max')
@@ -648,13 +658,13 @@ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure):
 
         # linear warmup between 0 and the initial learning rate `lrate` in `warmup`
         # steps.
-        if self.hparams.warmup and self.trainer.global_step < self.hparams.warmup:
-            lr_scale = min(1.0, float(self.trainer.global_step + 1) / self.hparams.warmup)
+        if self.hparams.hyper_params['warmup'] and self.trainer.global_step < self.hparams.hyper_params['warmup']:
+            lr_scale = min(1.0, float(self.trainer.global_step + 1) / self.hparams.hyper_params['warmup'])
             for pg in optimizer.param_groups:
-                pg["lr"] = lr_scale * self.hparams.lrate
+                pg["lr"] = lr_scale * self.hparams.hyper_params['lrate']
 
     def lr_scheduler_step(self, scheduler, metric):
-        if not self.hparams.warmup or self.trainer.global_step >= self.hparams.warmup:
+        if not self.hparams.hyper_params['warmup'] or self.trainer.global_step >= self.hparams.hyper_params['warmup']:
             # step OneCycleLR each batch if not in warmup phase
             if isinstance(scheduler, lr_scheduler.OneCycleLR):
                 scheduler.step()
@@ -759,7 +769,8 @@ def __init__(self,
         hyper_params_.update(hyper_params)
 
         validate_hyper_parameters(hyper_params_)
-        self.save_hyperparameters(hyper_params_)
+        self.hyper_params = hyper_params_
+        self.save_hyperparameters()
 
         if not training_data:
             raise ValueError('No training data provided. Please add some.')
@@ -768,7 +779,7 @@ def __init__(self,
                                height,
                                width,
                                channels,
-                               self.hparams.padding,
+                               self.hparams.hyper_params['padding'],
                                valid_norm=False,
                                force_binarization=force_binarization)
@@ -795,10 +806,10 @@ def __init__(self,
             merge_baselines = None
 
-        train_set = BaselineSet(training_data,
-                                line_width=self.hparams.line_width,
+        train_set = BaselineSet(line_width=self.hparams.hyper_params['line_width'],
                                 im_transforms=transforms,
                                 mode=format_type,
-                                augmentation=self.hparams.augment,
+                                augmentation=self.hparams.hyper_params['augment'],
                                 valid_baselines=valid_baselines,
                                 merge_baselines=merge_baselines,
                                 valid_regions=valid_regions,
@@ -810,7 +821,7 @@ def __init__(self,
         if evaluation_data:
-            val_set = BaselineSet(evaluation_data,
-                                  line_width=self.hparams.line_width,
+            val_set = BaselineSet(line_width=self.hparams.hyper_params['line_width'],
                                   im_transforms=transforms,
                                   mode=format_type,
                                   augmentation=False,
@@ -886,7 +897,7 @@ def on_validation_epoch_end(self):
         self.log('val_mean_acc', mean_accuracy, on_step=False, on_epoch=True, prog_bar=True, logger=True)
         self.log('val_mean_iu', mean_iu, on_step=False, on_epoch=True, prog_bar=True, logger=True)
         self.log('val_freq_iu', freq_iu, on_step=False, on_epoch=True, prog_bar=True, logger=True)
-        self.log('val_metric', mean_iu, on_step=False, on_epoch=True, prog_bar=True, logger=True)
+        self.log('val_metric', mean_iu, on_step=False, on_epoch=True, prog_bar=False, logger=True)
 
         self.val_px_accuracy.reset()
         self.val_mean_accuracy.reset()
@@ -1035,10 +1046,10 @@ def val_dataloader(self):
 
     def configure_callbacks(self):
         callbacks = []
-        if self.hparams.quit == 'early':
+        if self.hparams.hyper_params['quit'] == 'early':
             callbacks.append(EarlyStopping(monitor='val_mean_iu',
                                            mode='max',
                                            patience=self.hparams.hyper_params['lag'],
                                            stopping_threshold=1.0))
        return
callbacks
 
@@ -1050,7 +1061,7 @@ def configure_callbacks(self):
     # batch-wise learning rate warmup. In lr_scheduler_step() calls to the
     # scheduler are then only performed at the end of the epoch.
     def configure_optimizers(self):
-        return _configure_optimizer_and_lr_scheduler(self.hparams,
+        return _configure_optimizer_and_lr_scheduler(self.hparams.hyper_params,
                                                      self.nn.nn.parameters(),
                                                      len_train_set=len(self.train_set),
                                                      loss_tracking_mode='max')
@@ -1061,13 +1072,13 @@ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure):
 
         # linear warmup between 0 and the initial learning rate `lrate` in `warmup`
         # steps.
-        if self.hparams.warmup and self.trainer.global_step < self.hparams.warmup:
-            lr_scale = min(1.0, float(self.trainer.global_step + 1) / self.hparams.warmup)
+        if self.hparams.hyper_params['warmup'] and self.trainer.global_step < self.hparams.hyper_params['warmup']:
+            lr_scale = min(1.0, float(self.trainer.global_step + 1) / self.hparams.hyper_params['warmup'])
             for pg in optimizer.param_groups:
-                pg["lr"] = lr_scale * self.hparams.lrate
+                pg["lr"] = lr_scale * self.hparams.hyper_params['lrate']
 
     def lr_scheduler_step(self, scheduler, metric):
-        if not self.hparams.warmup or self.trainer.global_step >= self.hparams.warmup:
+        if not self.hparams.hyper_params['warmup'] or self.trainer.global_step >= self.hparams.hyper_params['warmup']:
             # step OneCycleLR each batch if not in warmup phase
             if isinstance(scheduler, lr_scheduler.OneCycleLR):
                 scheduler.step()
@@ -1077,51 +1088,63 @@ def lr_scheduler_step(self, scheduler, metric):
 
 def _configure_optimizer_and_lr_scheduler(hparams, params, len_train_set=None, loss_tracking_mode='max'):
+    optimizer = hparams.get("optimizer")
+    lrate = hparams.get("lrate")
+    momentum = hparams.get("momentum")
+    weight_decay = hparams.get("weight_decay")
+    schedule = hparams.get("schedule")
+    gamma = hparams.get("gamma")
+    step_size = hparams.get("step_size")
+    cos_t_max = hparams.get("cos_t_max")
+    rop_factor = hparams.get("rop_factor")
+    rop_patience = hparams.get("rop_patience")
+    epochs = hparams.get("epochs")
+    completed_epochs = hparams.get("completed_epochs")
+
     # XXX: Warmup is not configured here because it needs to be manually done in optimizer_step()
-    logger.debug(f'Constructing {hparams.optimizer} optimizer (lr: {hparams.lrate}, momentum: {hparams.momentum})')
-    if hparams.optimizer == 'Adam':
-        optim = torch.optim.Adam(params, lr=hparams.lrate, weight_decay=hparams.weight_decay)
+    logger.debug(f'Constructing {optimizer} optimizer (lr: {lrate}, momentum: {momentum})')
+    if optimizer == 'Adam':
+        optim = torch.optim.Adam(params, lr=lrate, weight_decay=weight_decay)
     else:
-        optim = getattr(torch.optim, hparams.optimizer)(params,
-                                                        lr=hparams.lrate,
-                                                        momentum=hparams.momentum,
-                                                        weight_decay=hparams.weight_decay)
+        optim = getattr(torch.optim, optimizer)(params,
+                                                lr=lrate,
+                                                momentum=momentum,
+                                                weight_decay=weight_decay)
     lr_sched = {}
-    if hparams.schedule == 'exponential':
-        lr_sched = {'scheduler': lr_scheduler.ExponentialLR(optim, hparams.gamma, last_epoch=hparams.completed_epochs-1),
+    if schedule == 'exponential':
+        lr_sched = {'scheduler': lr_scheduler.ExponentialLR(optim, gamma, last_epoch=completed_epochs-1),
                     'interval': 'step'}
-    elif hparams.schedule == 'cosine':
-        lr_sched = {'scheduler': lr_scheduler.CosineAnnealingLR(optim, hparams.gamma, last_epoch=hparams.completed_epochs-1),
+    elif schedule == 'cosine':
+        lr_sched = {'scheduler': lr_scheduler.CosineAnnealingLR(optim, cos_t_max,
last_epoch=completed_epochs-1), 'interval': 'step'} - elif hparams.schedule == 'step': - lr_sched = {'scheduler': lr_scheduler.StepLR(optim, hparams.step_size, hparams.gamma, last_epoch=hparams.completed_epochs-1), + elif schedule == 'step': + lr_sched = {'scheduler': lr_scheduler.StepLR(optim, step_size, gamma, last_epoch=completed_epochs-1), 'interval': 'step'} - elif hparams.schedule == 'reduceonplateau': + elif schedule == 'reduceonplateau': lr_sched = {'scheduler': lr_scheduler.ReduceLROnPlateau(optim, mode=loss_tracking_mode, - factor=hparams.rop_factor, - patience=hparams.rop_patience), + factor=rop_factor, + patience=rop_patience), 'interval': 'step'} - elif hparams.schedule == '1cycle': - if hparams.epochs <= 0: + elif schedule == '1cycle': + if epochs <= 0: raise ValueError('1cycle learning rate scheduler selected but ' 'number of epochs is less than 0 ' - f'({hparams.epochs}).') - last_epoch = hparams.completed_epochs*len_train_set if hparams.completed_epochs else -1 + f'({epochs}).') + last_epoch = completed_epochs*len_train_set if completed_epochs else -1 lr_sched = {'scheduler': lr_scheduler.OneCycleLR(optim, - max_lr=hparams.lrate, - epochs=hparams.epochs, + max_lr=lrate, + epochs=epochs, steps_per_epoch=len_train_set, last_epoch=last_epoch), 'interval': 'step'} - elif hparams.schedule != 'constant': - raise ValueError(f'Unsupported learning rate scheduler {hparams.schedule}.') + elif schedule != 'constant': + raise ValueError(f'Unsupported learning rate scheduler {schedule}.') ret = {'optimizer': optim} if lr_sched: ret['lr_scheduler'] = lr_sched - if hparams.schedule == 'reduceonplateau': + if schedule == 'reduceonplateau': lr_sched['monitor'] = 'val_metric' lr_sched['strict'] = False lr_sched['reduce_on_plateau'] = True diff --git a/kraken/lib/vgsl.py b/kraken/lib/vgsl.py index 9f745ceb4..ad762da77 100644 --- a/kraken/lib/vgsl.py +++ b/kraken/lib/vgsl.py @@ -137,7 +137,7 @@ def __init__(self, spec: str) -> None: self.build_dropout, self.build_maxpool, self.build_conv, self.build_output, self.build_reshape, self.build_wav2vec2, self.build_groupnorm, self.build_series, - self.build_parallel] + self.build_parallel, self.build_ro] self.codec = None # type: Optional[PytorchCodec] self.criterion = None # type: Any self.nn = layers.MultiParamSequential() @@ -577,6 +577,25 @@ def build_wav2vec2(self, f'{mask_prob}, negative samples {num_negatives}') return fn.get_shape(input), [VGSLBlock(blocks[idx], m.group('type'), m.group('name'), self.idx)], fn + def build_ro(self, + input: Tuple[int, int, int, int], + blocks: List[str], + idx: int) -> Union[Tuple[None, None, None], Tuple[Tuple[int, int, int, int], str, Callable]]: + """ + Builds a RO determination layer. 
+ """ + pattern = re.compile(r'(?PRO)(?P{\w+})(?P\d+),(?P\d+)') + m = pattern.match(blocks[idx]) + if not m: + return None, None, None + feature_size = int(m.group('feature_size')) + hidden_size = int(m.group('hidden_size')) + from kraken.lib import ro + fn = ro.layers.MLP(feature_size, hidden_size) + self.idx += 1 + logger.debug(f'{self.idx}\t\tro\tfeatures {feature_size}, hidden_size {hidden_size}') + return fn.get_shape(input), [VGSLBlock(blocks[idx], m.group('type'), m.group('name'), self.idx)], fn + def build_conv(self, input: Tuple[int, int, int, int], blocks: List[str], diff --git a/kraken/lib/xml.py b/kraken/lib/xml.py index 608cd7dc5..3f60fa7ee 100644 --- a/kraken/lib/xml.py +++ b/kraken/lib/xml.py @@ -17,21 +17,24 @@ """ import re import logging + +from os import PathLike from pathlib import Path from itertools import groupby from lxml import etree from PIL import Image -from typing import Union, Dict, Any, Sequence, Tuple +from typing import Union, Dict, Any, Sequence, Tuple, Literal, Optional, List -from os import PathLike from collections import defaultdict +from kraken.containers import Segmentation, BaselineLine, Region from kraken.lib.segmentation import calculate_polygonal_environment from kraken.lib.exceptions import KrakenInputException logger = logging.getLogger(__name__) -__all__ = ['parse_xml', 'parse_page', 'parse_alto', 'preparse_xml_data'] +__all__ = ['XMLPage'] + # fallback mapping between PAGE region types and tags page_regions = {'TextRegion': 'text', @@ -52,289 +55,489 @@ # same for ALTO alto_regions = {'TextBlock': 'text', - 'IllustrationType': 'illustration', - 'GraphicalElementType': 'graphic', + 'Illustration': 'illustration', + 'GraphicalElement': 'graphic', 'ComposedBlock': 'composed'} -def preparse_xml_data(filenames: Sequence[Union[str, PathLike]], - format_type: str = 'xml', - repolygonize: bool = False) -> Dict[str, Any]: - """ - Loads training data from a set of xml files. - - Extracts line information from Page/ALTO xml files for training of - recognition models. - - Args: - filenames: List of XML files. - format_type: Either `page`, `alto` or `xml` for autodetermination. - repolygonize: (Re-)calculates polygon information using the kraken - algorithm. - - Returns: - A list of dicts {'text': text, 'baseline': [[x0, y0], ...], 'boundary': - [[x0, y0], ...], 'image': PIL.Image}. - """ - training_pairs = [] - if format_type == 'xml': - parse_fn = parse_xml - elif format_type == 'alto': - parse_fn = parse_alto - elif format_type == 'page': - parse_fn = parse_page - else: - raise ValueError(f'invalid format {format_type} for preparse_xml_data') - - for fn in filenames: - try: - data = parse_fn(fn) - except KrakenInputException as e: - logger.warning(e) - continue - try: - with open(data['image'], 'rb') as fp: - Image.open(fp) - except FileNotFoundError as e: - logger.warning(f'Could not open file {e.filename} in {fn}') - continue - if repolygonize: - logger.info('repolygonizing {} lines in {}'.format(len(data['lines']), data['image'])) - data['lines'] = _repolygonize(data['image'], data['lines']) - for line in data['lines']: - training_pairs.append({'image': data['image'], **line}) - return training_pairs - - -def _repolygonize(im: Image.Image, lines: Sequence[Dict[str, Any]]): - """ - Helper function taking an output of the lib.xml parse_* functions and - recalculating the contained polygonization. - - Args: - im (Image.Image): Input image - lines (list): List of dicts [{'boundary': [[x0, y0], ...], 'baseline': [[x0, y0], ...], 'text': 'abcvsd'}, {...] 
- - Returns: - A data structure `lines` with a changed polygonization. - """ - im = Image.open(im).convert('L') - polygons = calculate_polygonal_environment(im, [x['baseline'] for x in lines]) - return [{'boundary': polygon, - 'baseline': orig['baseline'], - 'text': orig['text'], - 'script': orig['script']} for orig, polygon in zip(lines, polygons)] - - -def parse_xml(filename: Union[str, PathLike]) -> Dict[str, Any]: - """ - Parses either a PageXML or ALTO file with autodetermination of the file - format. - - Args: - filename: path to an XML file. - - Returns: - A dict:: - - {'image': impath, - 'lines': [{'boundary': [[x0, y0], ...], - 'baseline': [[x0, y0], ...], - 'text': apdjfqpf', - 'tags': {'type': 'default', ...}}, - ... - {...}], - 'regions': {'region_type_0': [[[x0, y0], ...], ...], ...}} - """ - with open(filename, 'rb') as fp: - try: - doc = etree.parse(fp) - except etree.XMLSyntaxError as e: - raise KrakenInputException(f'Parsing {filename} failed: {e}') - if doc.getroot().tag.endswith('alto'): - return parse_alto(filename) - elif doc.getroot().tag.endswith('PcGts'): - return parse_page(filename) - else: - raise KrakenInputException(f'Unknown XML format in {filename}') - - -def parse_page(filename: Union[str, PathLike]) -> Dict[str, Any]: - """ - Parses a PageXML file, returns the baselines defined in it, and loads the - referenced image. - - Args: - filename: path to a PageXML file. - - Returns: - A dict:: - - {'image': impath, - 'lines': [{'boundary': [[x0, y0], ...], - 'baseline': [[x0, y0], ...], - 'text': apdjfqpf', - 'tags': {'type': 'default', ...}}, - ... - {...}], - 'regions': {'region_type_0': [[[x0, y0], ...], ...], ...}} - """ - def _parse_page_custom(s): - o = {} - s = s.strip() - l_chunks = [l_chunk for l_chunk in s.split('}') if l_chunk.strip()] - if l_chunks: - for chunk in l_chunks: - tag, vals = chunk.split('{') - tag_vals = {} - vals = [val.strip() for val in vals.split(';') if val.strip()] - for val in vals: - key, *val = val.split(':') - tag_vals[key] = ":".join(val) - o[tag.strip()] = tag_vals - return o - - def _parse_coords(coords): - points = [x for x in coords.split(' ')] - points = [int(c) for point in points for c in point.split(',')] - pts = zip(points[::2], points[1::2]) - return [k for k, g in groupby(pts)] +class XMLPage(object): + + type: Literal['baselines', 'bbox'] = 'baselines' + base_dir: Optional[Literal['L', 'R']] = None + imagename: PathLike = None + _orders: Dict[str, Dict[str, Any]] = None + has_tags: bool = False + _tag_set: Optional[Dict] = None + has_splits: bool = False + _split_set: Optional[List] = None + + def __init__(self, + filename: Union[str, PathLike], + filetype: Literal['xml', 'alto', 'page'] = 'xml'): + super().__init__() + self.filename = Path(filename) + self.filetype = filetype + + self._regions = {} + self._lines = {} + self._orders = {'line_implicit': {'order': [], 'is_total': True, 'description': 'Implicit line order derived from element sequence'}, + 'region_implicit': {'order': [], 'is_total': True, 'description': 'Implicit region order derived from element sequence'}} + + if filetype == 'xml': + self._parse_xml() + elif filetype == 'alto': + self._parse_alto() + elif filetype == 'page': + self._parse_page() + + def _parse_xml(self): + with open(self.filename, 'rb') as fp: + try: + doc = etree.parse(fp) + except etree.XMLSyntaxError as e: + raise ValueError(f'Parsing {self.filename} failed: {e}') + if doc.getroot().tag.endswith('alto'): + return self._parse_alto() + elif doc.getroot().tag.endswith('PcGts'): + 
return self._parse_page()
+        else:
+            raise ValueError(f'Unknown XML format in {self.filename}')

-    with open(filename, 'rb') as fp:
-        base_dir = Path(filename).parent
-        try:
-            doc = etree.parse(fp)
-        except etree.XMLSyntaxError as e:
-            raise KrakenInputException('Parsing {} failed: {}'.format(filename, e))
-        image = doc.find('.//{*}Page')
-        if image is None or image.get('imageFilename') is None:
-            raise KrakenInputException('No valid image filename found in PageXML file {}'.format(filename))
-        try:
-            base_direction = {'left-to-right': 'L',
-                              'right-to-left': 'R',
-                              'top-to-bottom': 'L',
-                              'bottom-to-top': 'R',
-                              None: None}[image.get('readingDirection')]
-        except KeyError:
-            logger.warning(f'Invalid value {image.get("readingDirection")} encountered in page-level reading direction.')
-            base_direction = None
-        lines = doc.findall('.//{*}TextLine')
-        data = {'image': base_dir.joinpath(image.get('imageFilename')),
-                'lines': [],
-                'type': 'baselines',
-                'base_dir': base_direction,
-                'regions': {}}
-        # find all image regions
-        regions = []
-        for x in page_regions.keys():
-            regions.extend(doc.findall('.//{{*}}{}'.format(x)))
-        # parse region type and coords
-        region_data = defaultdict(list)
-        for region in regions:
-            coords = region.find('{*}Coords')
-            if coords is not None and not coords.get('points').isspace() and len(coords.get('points')):
-                try:
-                    coords = _parse_coords(coords.get('points'))
-                except Exception:
-                    logger.warning('Region {} without coordinates'.format(region.get('id')))
-                    continue
-            else:
-                logger.warning('Region {} without coordinates'.format(region.get('id')))
-                continue
-            rtype = region.get('type')
-            # parse transkribus-style custom field if possible
-            custom_str = region.get('custom')
-            if not rtype and custom_str:
-                cs = _parse_page_custom(custom_str)
-                if 'structure' in cs and 'type' in cs['structure']:
-                    rtype = cs['structure']['type']
-            # fall back to default region type if nothing is given
-            if not rtype:
-                rtype = page_regions[region.tag.split('}')[-1]]
-            region_data[rtype].append(coords)
-
-        data['regions'] = region_data
-
-        # parse line information
-        tag_set = set(('default',))
-        for line in lines:
-            pol = line.find('./{*}Coords')
-            boundary = None
-            if pol is not None and not pol.get('points').isspace() and len(pol.get('points')):
-                try:
-                    boundary = _parse_coords(pol.get('points'))
-                except Exception:
-                    logger.info('TextLine {} without polygon'.format(line.get('id')))
+    def _parse_alto(self):
+        with open(self.filename, 'rb') as fp:
+            base_directory = self.filename.parent
+            try:
+                doc = etree.parse(fp)
+            except etree.XMLSyntaxError as e:
+                raise ValueError('Parsing {} failed: {}'.format(self.filename, e))
+            image = doc.find('.//{*}fileName')
+            if image is None or not image.text:
+                raise ValueError(f'No valid image filename found in ALTO file {self.filename}')
+            self.imagename = base_directory.joinpath(image.text)
+
+            # find all image regions in order
+            regions = []
+            for el in doc.iterfind('./{*}Layout/{*}Page/{*}PrintSpace/{*}*'):
+                for block_type in alto_regions.keys():
+                    if el.tag.endswith(block_type):
+                        regions.append(el)
+            # find overall dimensions to filter out dummy TextBlocks
+            ps = doc.find('./{*}Layout/{*}Page/{*}PrintSpace')
+            x_min = int(float(ps.get('HPOS')))
+            y_min = int(float(ps.get('VPOS')))
+            width = int(float(ps.get('WIDTH')))
+            height = int(float(ps.get('HEIGHT')))
+
+            # parse tagrefs
+            cls_map = {}
+            tags = doc.find('.//{*}Tags')
+            if tags is not None:
+                for x in ['StructureTag', 'LayoutTag', 'OtherTag']:
+                    for tag in tags.findall('./{{*}}{}'.format(x)):
+                        cls_map[tag.get('ID')] = (x[:-3].lower(), tag.get('LABEL'))
+
+            self._tag_set = set(('default',))
+
+            # parse region type and coords
+            region_data = defaultdict(list)
+            for region in regions:
+                # try to find shape object
+                boundary = None
+                coords = region.find('./{*}Shape/{*}Polygon')
+                if coords is not None:
+                    boundary = self._parse_alto_pointstype(coords.get('POINTS'))
+                elif (region.get('HPOS') is not None and region.get('VPOS') is not None and
+                      region.get('WIDTH') is not None and region.get('HEIGHT') is not None):
+                    # use rectangular definition
+                    x_min = int(float(region.get('HPOS')))
+                    y_min = int(float(region.get('VPOS')))
+                    width = int(float(region.get('WIDTH')))
+                    height = int(float(region.get('HEIGHT')))
+                    boundary = [(x_min, y_min),
+                                (x_min, y_min + height),
+                                (x_min + width, y_min + height),
+                                (x_min + width, y_min)]
+                rtype = region.get('TYPE')
+                # fall back to default region type if nothing is given
+                tagrefs = region.get('TAGREFS')
+                if tagrefs is not None and rtype is None:
+                    for tagref in tagrefs.split():
+                        ttype, rtype = cls_map.get(tagref, (None, None))
+                        if rtype is not None and ttype:
+                            break
+                if rtype is None:
+                    rtype = alto_regions[region.tag.split('}')[-1]]
+                region_id = region.get('ID')
+                region_data[rtype].append(Region(id=region_id, boundary=boundary, tags={'type': rtype}))
+                # register implicit reading order
+                self._orders['region_implicit']['order'].append(region_id)
+
+                # parse lines in region
+                for line in region.iterfind('./{*}TextLine'):
+                    if line.get('BASELINE') is None:
+                        logger.info('TextLine {} without baseline'.format(line.get('ID')))
+                        continue
+                    pol = line.find('./{*}Shape/{*}Polygon')
+                    boundary = None
+                    if pol is not None:
+                        try:
+                            boundary = self._parse_alto_pointstype(pol.get('POINTS'))
+                        except ValueError:
+                            logger.info('TextLine {} without polygon'.format(line.get('ID')))
+                    else:
+                        logger.info('TextLine {} without polygon'.format(line.get('ID')))
+
+                    baseline = None
+                    try:
+                        baseline = self._parse_alto_pointstype(line.get('BASELINE'))
+                    except ValueError:
+                        logger.info('TextLine {} without baseline'.format(line.get('ID')))
+
+                    text = ''
+                    for el in line.xpath(".//*[local-name() = 'String'] | .//*[local-name() = 'SP']"):
+                        text += el.get('CONTENT') if el.get('CONTENT') else ' '
+                    # find line type
+                    tags = {'type': 'default'}
+                    split_type = None
+                    tagrefs = line.get('TAGREFS')
+                    if tagrefs is not None:
+                        for tagref in tagrefs.split():
+                            ttype, ltype = cls_map.get(tagref, (None, None))
+                            if ltype is not None:
+                                self._tag_set.add(ltype)
+                                if ttype == 'other':
+                                    tags['type'] = ltype
+                                else:
+                                    tags[ttype] = ltype
+                                if ltype in ['train', 'validation', 'test']:
+                                    split_type = ltype
+                    self._lines[line.get('ID')] = BaselineLine(id=line.get('ID'),
+                                                               baseline=baseline,
+                                                               boundary=boundary,
+                                                               text=text,
+                                                               tags=tags,
+                                                               split=split_type,
+                                                               regions=[region_id])
+                    # register implicit reading order
+                    self._orders['line_implicit']['order'].append(line.get('ID'))
+
+            self._regions = region_data
+
+        if len(self._tag_set) > 1:
+            self.has_tags = True
         else:
-            logger.info('TextLine {} without polygon'.format(line.get('id')))
-        base = line.find('./{*}Baseline')
-        baseline = None
-        if base is not None and not base.get('points').isspace() and len(base.get('points')):
-            try:
-                baseline = _parse_coords(base.get('points'))
-            except Exception:
-                logger.info('TextLine {} without baseline'.format(line.get('id')))
+            self.has_tags = False
+
+        # parse explicit reading orders if they exist
+        ro_el = doc.find('.//{*}ReadingOrder')
+        if ro_el is not None:
+            reading_orders = ro_el.getchildren()
+            # UnorderedGroup at top-level => treated as multiple reading orders
+            if len(reading_orders) == 1 and reading_orders[0].tag.endswith('UnorderedGroup'):
+                reading_orders = reading_orders[0].getchildren()
+
+            def _parse_group(el):
+                nonlocal is_valid, is_total
+
+                _ro = []
+                if el.tag.endswith('UnorderedGroup'):
+                    _ro = [_parse_group(x) for x in el.iterchildren()]
+                    is_total = False
+                elif el.tag.endswith('OrderedGroup'):
+                    _ro.extend(_parse_group(x) for x in el.iterchildren())
+                else:
+                    ref = el.get('REF')
+                    res = doc.find(f'.//{{*}}*[@ID="{ref}"]')
+                    if res is None:
+                        logger.warning(f'Nonexistent element with ID {ref} in reading order. Skipping RO {ro.get("ID")}.')
+                        is_valid = False
+                        return _ro
+                    tag = res.tag.split('}')[-1]
+                    if tag not in alto_regions.keys() and tag != 'TextLine':
+                        logger.warning(f'Sub-line element with ID {ref} in reading order. Skipping RO {ro.get("ID")}.')
+                        is_valid = False
+                        return _ro
+                    return ref
+                return _ro
+
+            for ro in reading_orders:
+                is_total = True
+                is_valid = True
+                joint_order = _parse_group(ro)
+                if is_valid:
+                    tag = ro.get('TAGREFS')
+                    self._orders[ro.get('ID')] = {'order': joint_order,
+                                                  'is_total': is_total,
+                                                  'description': cls_map[tag] if tag and tag in cls_map else ''}
+        self.filetype = 'alto'
+
+    def _parse_page(self):
+        with open(self.filename, 'rb') as fp:
+            base_directory = self.filename.parent
+
+            try:
+                doc = etree.parse(fp)
+            except etree.XMLSyntaxError as e:
+                raise ValueError(f'Parsing {self.filename} failed: {e}')
+            image = doc.find('.//{*}Page')
+            if image is None or image.get('imageFilename') is None:
+                raise ValueError(f'No valid image filename found in PageXML file {self.filename}')
+            try:
+                self.base_dir = {'left-to-right': 'L',
+                                 'right-to-left': 'R',
+                                 'top-to-bottom': 'L',
+                                 'bottom-to-top': 'R',
+                                 None: None}[image.get('readingDirection')]
+            except KeyError:
+                logger.warning(f'Invalid value {image.get("readingDirection")} encountered in page-level reading direction.')
+            self.imagename = base_directory.joinpath(image.get('imageFilename'))
+            # find all image regions
+            regions = [reg for reg in image.iterfind('./{*}*')]
+            # parse region type and coords
+            region_data = defaultdict(list)
+            tr_region_order = []
+
+            self._tag_set = set(('default',))
+            tmp_transkribus_line_order = defaultdict(list)
+            valid_tr_lo = True
+
+            for region in regions:
+                if not any(region.tag.endswith(k) for k in page_regions.keys()):
                     continue
+                coords = region.find('./{*}Coords')
+                if coords is not None and not coords.get('points').isspace() and len(coords.get('points')):
+                    try:
+                        coords = self._parse_page_coords(coords.get('points'))
+                    except Exception:
+                        logger.warning('Region {} without coordinates'.format(region.get('id')))
+                        coords = None
+                else:
+                    logger.warning('Region {} without coordinates'.format(region.get('id')))
+                    coords = None
+                rtype = region.get('type')
+                # parse transkribus-style custom field if possible
+                custom_str = region.get('custom')
+                if custom_str:
+                    cs = self._parse_page_custom(custom_str)
+                    if not rtype and 'structure' in cs and 'type' in cs['structure']:
+                        rtype = cs['structure']['type']
+                    # transkribus-style reading order
+                    if 'readingOrder' in cs and 'index' in cs['readingOrder']:
+                        tr_region_order.append((region.get('id'), int(cs['readingOrder']['index'])))
+                # fall back to default region type if nothing is given
+                if not rtype:
+                    rtype = page_regions[region.tag.split('}')[-1]]
+                region_data[rtype].append(Region(id=region.get('id'),
boundary=coords, tags={'type': rtype})) + # register implicit reading order + self._orders['region_implicit']['order'].append(region.get('id')) + + # parse line information + for line in region.iterfind('./{*}TextLine'): + pol = line.find('./{*}Coords') + boundary = None + if pol is not None and not pol.get('points').isspace() and len(pol.get('points')): + try: + boundary = self._parse_page_coords(pol.get('points')) + except Exception: + logger.info('TextLine {} without polygon'.format(line.get('id'))) + else: + logger.info('TextLine {} without polygon'.format(line.get('id'))) + base = line.find('./{*}Baseline') + baseline = None + if base is not None and not base.get('points').isspace() and len(base.get('points')): + try: + baseline = self._parse_page_coords(base.get('points')) + except Exception: + logger.info('TextLine {} without baseline'.format(line.get('id'))) + continue + else: + logger.info('TextLine {} without baseline'.format(line.get('id'))) + continue + text = '' + manual_transcription = line.find('./{*}TextEquiv') + if manual_transcription is not None: + transcription = manual_transcription + else: + transcription = line + for el in transcription.findall('.//{*}Unicode'): + if el.text: + text += el.text + # retrieve line tags if custom string is set and contains + tags = {'type': 'default'} + split_type = None + custom_str = line.get('custom') + if custom_str: + cs = self._parse_page_custom(custom_str) + if 'structure' in cs and 'type' in cs['structure']: + tags['type'] = cs['structure']['type'] + self._tag_set.add(tags['type']) + # retrieve data split if encoded in custom string. + if 'split' in cs and 'type' in cs['split'] and cs['split']['type'] in ['train', 'validation', 'test']: + split_type = cs['split']['type'] + tags['split'] = split_type + self._tag_set.add(split_type) + if 'readingOrder' in cs and 'index' in cs['readingOrder']: + # look up region index from parent + reg_cus = self._parse_page_custom(line.getparent().get('custom')) + if 'readingOrder' not in reg_cus or 'index' not in reg_cus['readingOrder']: + logger.warning('Incomplete `custom` attribute reading order found.') + valid_tr_lo = False + else: + tmp_transkribus_line_order[int(reg_cus['readingOrder']['index'])].append((int(cs['readingOrder']['index']), line.get('id'))) + + self._lines[line.get('id')] = BaselineLine(id=line.get('id'), + baseline=baseline, + boundary=boundary, + text=text, + tags=tags, + split=split_type, + regions=[region.get('id')]) + # register implicit reading order + self._orders['line_implicit']['order'].append(line.get('id')) + + # add transkribus-style region order + self._orders['region_transkribus'] = {'order': [x[0] for x in sorted(tr_region_order, key=lambda k: k[1])], + 'is_total': True if len(set(map(lambda x: x[0], tr_region_order))) == len(tr_region_order) else False, + 'description': 'Explicit region order from `custom` attribute'} + + self._regions = region_data + + if tmp_transkribus_line_order: + # sort by regions + tmp_reg_order = sorted(((k, v) for k, v in tmp_transkribus_line_order.items()), key=lambda k: k[0]) + # flatten + tr_line_order = [] + for _, lines in tmp_reg_order: + tr_line_order.extend([x[1] for x in sorted(lines, key=lambda k: k[0])]) + self._orders['line_transkribus'] = {'order': tr_line_order, + 'is_total': True, + 'description': 'Explicit line order from `custom` attribute'} + + # parse explicit reading orders if they exist + ro_el = doc.find('.//{*}ReadingOrder') + if ro_el is not None: + reading_orders = ro_el.getchildren() + # UnorderedGroup at 
top-level => treated as multiple reading orders
+            if len(reading_orders) == 1 and reading_orders[0].tag.endswith('UnorderedGroup'):
+                reading_orders = reading_orders[0].getchildren()
+
+            def _parse_group(el):
+                nonlocal is_total
+
+                _ro = []
+                if el.tag.endswith('UnorderedGroup'):
+                    _ro = [_parse_group(x) for x in el.iterchildren()]
+                    is_total = False
+                elif el.tag.endswith('OrderedGroup'):
+                    _ro.extend(_parse_group(x) for x in el.iterchildren())
+                else:
+                    return el.get('regionRef')
+                return _ro
+
+            for ro in reading_orders:
+                is_total = True
+                self._orders[ro.get('id')] = {'order': _parse_group(ro),
+                                              'is_total': is_total,
+                                              'description': ro.get('caption') if ro.get('caption') else ''}
+
+        if len(self._tag_set) > 1:
+            self.has_tags = True
+        else:
+            self.has_tags = False
+
+        self.filetype = 'page'
+
+    @property
+    def regions(self):
+        return self._regions
+
+    @property
+    def lines(self):
+        return self._lines
+
+    @property
+    def reading_orders(self):
+        return self._orders
+
+    def get_sorted_lines(self, ro='line_implicit'):
+        """
+        Returns ordered baselines from particular reading order.
+        """
+        if ro not in self.reading_orders:
+            raise ValueError(f'Unknown reading order {ro}')
+
+        def _traverse_ro(el):
+            _ro = []
+            if isinstance(el, list):
+                _ro = [_traverse_ro(x) for x in el]
             else:
-                logger.info('TextLine {} without baseline'.format(line.get('id')))
-                continue
-            text = ''
-            manual_transcription = line.find('./{*}TextEquiv')
-            if manual_transcription is not None:
-                transcription = manual_transcription
+                # if line directly append to ro
+                if el in self.lines:
+                    return self.lines[el]
+                # substitute lines if region in RO
+                elif el in [reg.id for regs in self.regions.values() for reg in regs]:
+                    _ro.extend(self.get_sorted_lines_by_region(el))
+                else:
+                    raise ValueError(f'Invalid reading order {ro}')
+            return _ro
+
+        _ro = self.reading_orders[ro]
+        return _traverse_ro(_ro['order'])
+
+    def get_sorted_regions(self, ro='region_implicit'):
+        """
+        Returns ordered regions from particular reading order.
+        """
+        if ro not in self.reading_orders:
+            raise ValueError(f'Unknown reading order {ro}')
+
+        regions = {reg.id: key for key, regs in self.regions.items() for reg in regs}
+
+        def _traverse_ro(el):
+            _ro = []
+            if isinstance(el, list):
+                _ro = [_traverse_ro(x) for x in el]
             else:
-                transcription = line
-            for el in transcription.findall('.//{*}Unicode'):
-                if el.text:
-                    text += el.text
-            # retrieve line tags if custom string is set and contains
-            tags = {'type': 'default'}
-            split_type = None
-            custom_str = line.get('custom')
-            if custom_str:
-                cs = _parse_page_custom(custom_str)
-                if 'structure' in cs and 'type' in cs['structure']:
-                    tags['type'] = cs['structure']['type']
-                    tag_set.add(tags['type'])
-                # retrieve data split if encoded in custom string.
-                if 'split' in cs and 'type' in cs['split'] and cs['split']['type'] in ['train', 'validation', 'test']:
-                    split_type = cs['split']['type']
-                    tags['split'] = split_type
-                    tag_set.add(split_type)
-
-            data['lines'].append({'baseline': baseline,
-                                  'boundary': boundary,
-                                  'text': text,
-                                  'split': split_type,
-                                  'tags': tags})
-    if len(tag_set) > 1:
-        data['script_detection'] = True
-    else:
-        data['script_detection'] = False
-    return data
-
-
-def parse_alto(filename: Union[str, PathLike]) -> Dict[str, Any]:
-    """
-    Parses an ALTO file, returns the baselines defined in it, and loads the
-    referenced image.
-
-    Args:
-        filename: path to an ALTO file.
- - Returns: - A dict:: - - {'image': impath, - 'lines': [{'boundary': [[x0, y0], ...], - 'baseline': [[x0, y0], ...], - 'text': apdjfqpf', - 'tags': {'type': 'default', ...}}, - ... - {...}], - 'regions': {'region_type_0': [[[x0, y0], ...], ...], ...}} - """ - def _parse_pointstype(coords: str) -> Sequence[Tuple[float, float]]: + # if region directly append to ro + if el in regions.keys(): + return [reg for reg in self.regions[regions[el]] if reg.id == el][0] + else: + raise ValueError(f'Invalid reading order {ro}') + return _ro + + _ro = self.reading_orders[ro] + return _traverse_ro(_ro['order']) + + def get_sorted_lines_by_region(self, region, ro='line_implicit'): + """ + Returns ordered lines in region. + """ + if ro not in self.reading_orders: + raise ValueError(f'Unknown reading order {ro}') + if self.reading_orders[ro]['is_total'] is False: + raise ValueError('Fetching lines by region of a non-total order is not supported') + lines = [(id, line) for id, line in self._lines.items() if line.regions[0] == region] + for line in lines: + if line[0] not in self.reading_orders[ro]['order']: + raise ValueError('Fetching lines by region is only possible for flat orders') + return sorted(lines, key=lambda k: self.reading_orders[ro]['order'].index(k[0])) + + def get_lines_by_tag(self, key, value): + return {k: v for k, v in self._lines.items() if v.tags.get(key) == value} + + def get_lines_by_split(self, split: Literal['train', 'validation', 'test']): + return {k: v for k, v in self._lines.items() if v.tags.get('split') == split} + + @property + def tags(self): + return self._tag_set + + @property + def splits(self): + return self._split_set + + @staticmethod + def _parse_alto_pointstype(coords: str) -> Sequence[Tuple[float, float]]: """ ALTO's PointsType is underspecified so a variety of serializations are valid: @@ -353,126 +556,43 @@ def _parse_pointstype(coords: str) -> Sequence[Tuple[float, float]]: pts = zip(points[::2], points[1::2]) return [k for k, g in groupby(pts)] - with open(filename, 'rb') as fp: - base_dir = Path(filename).parent - try: - doc = etree.parse(fp) - except etree.XMLSyntaxError as e: - raise KrakenInputException('Parsing {} failed: {}'.format(filename, e)) - image = doc.find('.//{*}fileName') - if image is None or not image.text: - raise KrakenInputException('No valid filename found in ALTO file') - lines = doc.findall('.//{*}TextLine') - data = {'image': base_dir.joinpath(image.text), - 'lines': [], - 'type': 'baselines', - 'base_dir': None, - 'regions': {}} - # find all image regions - regions = [] - for x in alto_regions.keys(): - regions.extend(doc.findall('./{{*}}Layout/{{*}}Page/{{*}}PrintSpace/{{*}}{}'.format(x))) - # find overall dimensions to filter out dummy TextBlocks - ps = doc.find('./{*}Layout/{*}Page/{*}PrintSpace') - x_min = int(float(ps.get('HPOS'))) - y_min = int(float(ps.get('VPOS'))) - width = int(float(ps.get('WIDTH'))) - height = int(float(ps.get('HEIGHT'))) - page_boundary = [(x_min, y_min), - (x_min, y_min + height), - (x_min + width, y_min + height), - (x_min + width, y_min)] - - # parse tagrefs - cls_map = {} - tags = doc.find('.//{*}Tags') - if tags is not None: - for x in ['StructureTag', 'LayoutTag', 'OtherTag']: - for tag in tags.findall('./{{*}}{}'.format(x)): - cls_map[tag.get('ID')] = (x[:-3].lower(), tag.get('LABEL')) - # parse region type and coords - region_data = defaultdict(list) - for region in regions: - # try to find shape object - coords = region.find('./{*}Shape/{*}Polygon') - if coords is not None: - boundary = 
_parse_pointstype(coords.get('POINTS')) - elif (region.get('HPOS') is not None and region.get('VPOS') is not None and - region.get('WIDTH') is not None and region.get('HEIGHT') is not None): - # use rectangular definition - x_min = int(float(region.get('HPOS'))) - y_min = int(float(region.get('VPOS'))) - width = int(float(region.get('WIDTH'))) - height = int(float(region.get('HEIGHT'))) - boundary = [(x_min, y_min), - (x_min, y_min + height), - (x_min + width, y_min + height), - (x_min + width, y_min)] - else: - continue - rtype = region.get('TYPE') - # fall back to default region type if nothing is given - tagrefs = region.get('TAGREFS') - if tagrefs is not None and rtype is None: - for tagref in tagrefs.split(): - ttype, rtype = cls_map.get(tagref, (None, None)) - if rtype is not None and ttype: - break - if rtype is None: - rtype = alto_regions[region.tag.split('}')[-1]] - if boundary == page_boundary and rtype == 'text': - logger.info('Skipping TextBlock with same size as page image.') - continue - region_data[rtype].append(boundary) - data['regions'] = region_data - - tag_set = set(('default',)) - for line in lines: - if line.get('BASELINE') is None: - logger.info('TextLine {} without baseline'.format(line.get('ID'))) - continue - pol = line.find('./{*}Shape/{*}Polygon') - boundary = None - if pol is not None: - try: - boundary = _parse_pointstype(pol.get('POINTS')) - except ValueError: - logger.info('TextLine {} without polygon'.format(line.get('ID'))) - else: - logger.info('TextLine {} without polygon'.format(line.get('ID'))) + @staticmethod + def _parse_page_custom(s): + o = {} + s = s.strip() + l_chunks = [l_chunk for l_chunk in s.split('}') if l_chunk.strip()] + if l_chunks: + for chunk in l_chunks: + tag, vals = chunk.split('{') + tag_vals = {} + vals = [val.strip() for val in vals.split(';') if val.strip()] + for val in vals: + key, *val = val.split(':') + tag_vals[key] = ":".join(val) + o[tag.strip()] = tag_vals + return o - baseline = None - try: - baseline = _parse_pointstype(line.get('BASELINE')) - except ValueError: - logger.info('TextLine {} without baseline'.format(line.get('ID'))) - - text = '' - for el in line.xpath(".//*[local-name() = 'String'] | .//*[local-name() = 'SP']"): - text += el.get('CONTENT') if el.get('CONTENT') else ' ' - # find line type - tags = {'type': 'default'} - split_type = None - tagrefs = line.get('TAGREFS') - if tagrefs is not None: - for tagref in tagrefs.split(): - ttype, ltype = cls_map.get(tagref, (None, None)) - if ltype is not None: - tag_set.add(ltype) - if ttype == 'other': - tags['type'] = ltype - else: - tags[ttype] = ltype - if ltype in ['train', 'validation', 'test']: - split_type = ltype - data['lines'].append({'baseline': baseline, - 'boundary': boundary, - 'text': text, - 'tags': tags, - 'split': split_type}) - - if len(tag_set) > 1: - data['tags'] = True - else: - data['tags'] = False - return data + @staticmethod + def _parse_page_coords(coords): + points = [x for x in coords.split(' ')] + points = [int(c) for point in points for c in point.split(',')] + pts = zip(points[::2], points[1::2]) + return [k for k, g in groupby(pts)] + + def __str__(self): + return f'XMLPage {self.filename} (format: {self.filetype}, image: {self.imagename})' + + def __repr__(self): + return f'XMLPage(filename={self.filename}, filetype={self.filetype})' + + def to_container(self) -> Segmentation: + """ + Returns a Segmentation object. 
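+
+        A minimal usage sketch (the file name is a placeholder):
+
+        .. code-block:: python
+
+            page = XMLPage('0001.xml')
+            seg = page.to_container()
+            # seg.lines holds BaselineLine records in implicit reading order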
+ """ + return Segmentation(type='baselines', + imagename=self.imagename, + text_direction='horizontal_lr', + script_detection=True, + lines=self.get_sorted_lines(), + regions=self._regions, + line_orders=None) diff --git a/kraken/linegen.py b/kraken/linegen.py index d53daf117..6ea15aa8d 100644 --- a/kraken/linegen.py +++ b/kraken/linegen.py @@ -38,8 +38,6 @@ from scipy.ndimage.interpolation import affine_transform, geometric_transform from PIL import Image, ImageOps -from typing import AnyStr - import logging import ctypes import ctypes.util @@ -104,7 +102,7 @@ class ensureBytes(object): bytes. """ @classmethod - def from_param(cls, value: AnyStr) -> bytes: + def from_param(cls, value: str) -> bytes: if isinstance(value, bytes): return value else: diff --git a/kraken/pageseg.py b/kraken/pageseg.py index 3f52ba339..b98b175b7 100644 --- a/kraken/pageseg.py +++ b/kraken/pageseg.py @@ -19,6 +19,8 @@ Layout analysis methods. """ +import PIL +import uuid import logging import numpy as np @@ -26,12 +28,13 @@ from scipy.ndimage.filters import (gaussian_filter, uniform_filter, maximum_filter) +from kraken.containers import Segmentation, BBoxLine + from kraken.lib import morph, sl from kraken.lib.util import pil2array, is_bitonal, get_im_str from kraken.lib.exceptions import KrakenInputException from kraken.lib.segmentation import reading_order, topsort - __all__ = ['segment'] logger = logging.getLogger(__name__) @@ -301,7 +304,7 @@ def rotate_lines(lines: np.ndarray, angle: float, offset: int) -> np.ndarray: return np.column_stack((x.flatten(), y.flatten())).reshape(-1, 4) -def segment(im, +def segment(im: PIL.Image.Image, text_direction: str = 'horizontal-lr', scale: Optional[float] = None, maxcolseps: float = 2, @@ -309,7 +312,7 @@ def segment(im, no_hlines: bool = True, pad: Union[int, Tuple[int, int]] = 0, mask: Optional[np.ndarray] = None, - reading_order_fn: Callable = reading_order) -> Dict[str, Any]: + reading_order_fn: Callable = reading_order) -> Segmentation: """ Segments a page into text lines. @@ -336,12 +339,9 @@ def segment(im, direction in (`rl`, `lr`). Returns: - A dictionary containing the text direction and a list of reading order - sorted bounding boxes under the key 'boxes': - - .. code-block:: - - {'text_direction': '$dir', 'boxes': [(x1, y1, x2, y2),...]} + A :class:`kraken.containers.Segmentation` class containing reading + order sorted bounding box-type lines as + :class:`kraken.containers.BBoxLine` records. Raises: KrakenInputException: if the input image is not binarized or the text @@ -423,7 +423,12 @@ def segment(im, if isinstance(pad, int): pad = (pad, pad) lines = [(max(x[0]-pad[0], 0), x[1], min(x[2]+pad[1], im.size[0]), x[3]) for x in lines] - - return {'text_direction': text_direction, - 'boxes': rotate_lines(lines, 360-angle, offset).tolist(), - 'script_detection': False} + lines = [BBoxLine(id=str(uuid.uuid4()), bbox=line) for line in rotate_lines(lines, 360-angle, offset).tolist()] + + return Segmentation(text_direction=text_direction, + imagename=getattr(im, 'filename', None), + type='bbox', + regions=None, + line_orders=None, + lines=lines, + script_detection=False) diff --git a/kraken/rpred.py b/kraken/rpred.py index ae80ffa63..823cad5ea 100644 --- a/kraken/rpred.py +++ b/kraken/rpred.py @@ -19,392 +19,39 @@ Generators for recognition on lines images. 
""" import logging +import dataclasses import numpy as np -import bidi.algorithm as bd -from abc import ABC, abstractmethod from PIL import Image from functools import partial from collections import defaultdict from typing import List, Tuple, Optional, Generator, Union, Dict, Sequence +from kraken.containers import BaselineOCRRecord, BBoxOCRRecord, ocr_record, Segmentation from kraken.lib.util import get_im_str, is_bitonal from kraken.lib.models import TorchSeqRecognizer -from kraken.lib.segmentation import extract_polygons, compute_polygon_section +from kraken.lib.segmentation import extract_polygons from kraken.lib.exceptions import KrakenInputException from kraken.lib.dataset import ImageInputTransforms import copy -__all__ = ['ocr_record', 'BaselineOCRRecord', 'BBoxOCRRecord', 'mm_rpred', 'rpred'] +__all__ = ['mm_rpred', 'rpred'] logger = logging.getLogger(__name__) -class ocr_record(ABC): - """ - A record object containing the recognition result of a single line - """ - base_dir = None - - def __init__(self, - prediction: str, - cuts: Sequence[Union[Tuple[int, int], Tuple[Tuple[int, int], - Tuple[int, int], - Tuple[int, int], - Tuple[int, int]]]], - confidences: Sequence[float], - display_order: bool = True) -> None: - self._prediction = prediction - self._cuts = cuts - self._confidences = confidences - self._display_order = display_order - - @property - @abstractmethod - def type(self): - pass - - def __len__(self) -> int: - return len(self._prediction) - - def __str__(self) -> str: - return self._prediction - - @property - def prediction(self) -> str: - return self._prediction - - @property - def cuts(self) -> Sequence: - return self._cuts - - @property - def confidences(self) -> List[float]: - return self._confidences - - def __iter__(self): - self.idx = -1 - return self - - @abstractmethod - def __next__(self) -> Tuple[str, - Union[Sequence[Tuple[int, int]], - Tuple[Tuple[int, int], - Tuple[int, int], - Tuple[int, int], - Tuple[int, int]]], - float]: - pass - - @abstractmethod - def __getitem__(self, key: Union[int, slice]): - pass - - @abstractmethod - def display_order(self, base_dir) -> 'ocr_record': - pass - - @abstractmethod - def logical_order(self, base_dir) -> 'ocr_record': - pass - - -class BaselineOCRRecord(ocr_record): - """ - A record object containing the recognition result of a single line in - baseline format. - - Attributes: - type: 'baselines' to indicate a baseline record - prediction: The text predicted by the network as one continuous string. - cuts: The absolute bounding polygons for each code point in prediction - as a list of tuples [(x0, y0), (x1, y2), ...]. - confidences: A list of floats indicating the confidence value of each - code point. - - Notes: - When slicing the record the behavior of the cuts is changed from - earlier versions of kraken. Instead of returning per-character bounding - polygons a single polygons section of the line bounding polygon - starting at the first and extending to the last code point emitted by - the network is returned. This aids numerical stability when computing - aggregated bounding polygons such as for words. Individual code point - bounding polygons are still accessible through the `cuts` attribute or - by iterating over the record code point by code point. 
- """ - type = 'baselines' - - def __init__(self, prediction: str, - cuts: Sequence[Tuple[int, int]], - confidences: Sequence[float], - line: Dict[str, List], - display_order: bool = True) -> None: - super().__init__(prediction, cuts, confidences, display_order) - if 'baseline' not in line: - raise TypeError('Invalid argument type (non-baseline line)') - self.tags = None if 'tags' not in line else line['tags'] - self.line = line['boundary'] - self.baseline = line['baseline'] - - def __repr__(self) -> str: - return f'pred: {self.prediction} baseline: {self.baseline} boundary: {self.line} confidences: {self.confidences}' - - def __next__(self) -> Tuple[str, int, float]: - if self.idx + 1 < len(self): - self.idx += 1 - return (self.prediction[self.idx], - compute_polygon_section(self.baseline, - self.line, - self.cuts[self.idx][0], - self.cuts[self.idx][1]), - self.confidences[self.idx]) - else: - raise StopIteration - - def _get_raw_item(self, key: int): - if key < 0: - key += len(self) - if key >= len(self): - raise IndexError('Index (%d) is out of range' % key) - return (self.prediction[key], - self._cuts[key], - self.confidences[key]) - - def __getitem__(self, key: Union[int, slice]): - if isinstance(key, slice): - recs = [self._get_raw_item(i) for i in range(*key.indices(len(self)))] - prediction = ''.join([x[0] for x in recs]) - flat_offsets = sum((tuple(x[1]) for x in recs), ()) - cut = compute_polygon_section(self.baseline, - self.line, - min(flat_offsets), - max(flat_offsets)) - confidence = np.mean([x[2] for x in recs]) - return (prediction, cut, confidence) - - elif isinstance(key, int): - pred, cut, confidence = self._get_raw_item(key) - return (pred, - compute_polygon_section(self.baseline, self.line, cut[0], cut[1]), - confidence) - else: - raise TypeError('Invalid argument type') - - @property - def cuts(self) -> Sequence[Tuple[int, int]]: - return tuple([compute_polygon_section(self.baseline, self.line, cut[0], cut[1]) for cut in self._cuts]) - - def logical_order(self, base_dir: Optional[str] = None) -> 'BaselineOCRRecord': - """ - Returns the OCR record in Unicode logical order, i.e. in the order the - characters in the line would be read by a human. - - Args: - base_dir: An optional string defining the base direction (also - called paragraph direction) for the BiDi algorithm. Valid - values are 'L' or 'R'. If None is given the default - auto-resolution will be used. - """ - if self._display_order: - return self._reorder(base_dir) - else: - return self - - def display_order(self, base_dir: Optional[str] = None) -> 'BaselineOCRRecord': - """ - Returns the OCR record in Unicode display order, i.e. ordered from left - to right inside the line. - - Args: - base_dir: An optional string defining the base direction (also - called paragraph direction) for the BiDi algorithm. Valid - values are 'L' or 'R'. If None is given the default - auto-resolution will be used. - """ - if self._display_order: - return self - else: - return self._reorder(base_dir) - - def _reorder(self, base_dir: Optional[str] = None) -> 'BaselineOCRRecord': - """ - Reorder the record using the BiDi algorithm. 
- """ - storage = bd.get_empty_storage() - - if base_dir not in ('L', 'R'): - base_level = bd.get_base_level(self._prediction) - else: - base_level = {'L': 0, 'R': 1}[base_dir] - - storage['base_level'] = base_level - storage['base_dir'] = ('L', 'R')[base_level] - - bd.get_embedding_levels(self._prediction, storage) - bd.explicit_embed_and_overrides(storage) - bd.resolve_weak_types(storage) - bd.resolve_neutral_types(storage, False) - bd.resolve_implicit_levels(storage, False) - for i, j in enumerate(zip(self._prediction, self._cuts, self._confidences)): - storage['chars'][i]['record'] = j - bd.reorder_resolved_levels(storage, False) - bd.apply_mirroring(storage, False) - prediction = '' - cuts = [] - confidences = [] - for ch in storage['chars']: - # code point may have been mirrored - prediction = prediction + ch['ch'] - cuts.append(ch['record'][1]) - confidences.append(ch['record'][2]) - line = {'boundary': self.line, 'baseline': self.baseline} - rec = BaselineOCRRecord(prediction, cuts, confidences, line) - rec.tags = self.tags - rec.base_dir = base_dir - rec._display_order = not self._display_order - return rec - - -class BBoxOCRRecord(ocr_record): - """ - A record object containing the recognition result of a single line in - bbox format. - """ - type = 'box' - - def __init__(self, prediction: str, - cuts: Sequence[Tuple[Tuple[int, int], - Tuple[int, int], - Tuple[int, int], - Tuple[int, int]]], - confidences: Sequence[float], - line: Tuple[Tuple[int, int], - Tuple[int, int], - Tuple[int, int], - Tuple[int, int]], - display_order: bool = True) -> None: - super().__init__(prediction, cuts, confidences, display_order) - if 'baseline' in line: - raise TypeError('Invalid argument type (baseline line)') - self.line = line - - def __repr__(self) -> str: - return f'pred: {self.prediction} line: {self.line} confidences: {self.confidences}' - - def __next__(self) -> Tuple[str, int, float]: - if self.idx + 1 < len(self): - self.idx += 1 - return (self.prediction[self.idx], - self.cuts[self.idx], - self.confidences[self.idx]) - else: - raise StopIteration - - def _get_raw_item(self, key: int): - if key < 0: - key += len(self) - if key >= len(self): - raise IndexError('Index (%d) is out of range' % key) - return (self.prediction[key], - self.cuts[key], - self.confidences[key]) - - def __getitem__(self, key: Union[int, slice]): - if isinstance(key, slice): - recs = [self._get_raw_item(i) for i in range(*key.indices(len(self)))] - prediction = ''.join([x[0] for x in recs]) - box = [x[1] for x in recs] - flat_box = [point for pol in box for point in pol] - flat_box = [x for point in flat_box for x in point] - min_x, max_x = min(flat_box[::2]), max(flat_box[::2]) - min_y, max_y = min(flat_box[1::2]), max(flat_box[1::2]) - cut = ((min_x, min_y), (max_x, min_y), (max_x, max_y), (min_x, max_y)) - confidence = np.mean([x[2] for x in recs]) - return (prediction, cut, confidence) - elif isinstance(key, int): - return self._get_raw_item(key) - else: - raise TypeError('Invalid argument type') - - def logical_order(self, base_dir: Optional[str] = None) -> 'BBoxOCRRecord': - """ - Returns the OCR record in Unicode logical order, i.e. in the order the - characters in the line would be read by a human. - - Args: - base_dir: An optional string defining the base direction (also - called paragraph direction) for the BiDi algorithm. Valid - values are 'L' or 'R'. If None is given the default - auto-resolution will be used. 
- """ - if self._display_order: - return self._reorder(base_dir) - else: - return self - - def display_order(self, base_dir: Optional[str] = None) -> 'BBoxOCRRecord': - """ - Returns the OCR record in Unicode display order, i.e. ordered from left - to right inside the line. - - Args: - base_dir: An optional string defining the base direction (also - called paragraph direction) for the BiDi algorithm. Valid - values are 'L' or 'R'. If None is given the default - auto-resolution will be used. - """ - if self._display_order: - return self - else: - return self._reorder(base_dir) - - def _reorder(self, base_dir: Optional[str] = None) -> 'BBoxOCRRecord': - storage = bd.get_empty_storage() - - if base_dir not in ('L', 'R'): - base_level = bd.get_base_level(self.prediction) - else: - base_level = {'L': 0, 'R': 1}[base_dir] - - storage['base_level'] = base_level - storage['base_dir'] = ('L', 'R')[base_level] - - bd.get_embedding_levels(self.prediction, storage) - bd.explicit_embed_and_overrides(storage) - bd.resolve_weak_types(storage) - bd.resolve_neutral_types(storage, False) - bd.resolve_implicit_levels(storage, False) - for i, j in enumerate(zip(self.prediction, self.cuts, self.confidences)): - storage['chars'][i]['record'] = j - bd.reorder_resolved_levels(storage, False) - bd.apply_mirroring(storage, False) - prediction = '' - cuts = [] - confidences = [] - for ch in storage['chars']: - # code point may have been mirrored - prediction = prediction + ch['ch'] - cuts.append(ch['record'][1]) - confidences.append(ch['record'][2]) - # carry over whole line information - rec = BBoxOCRRecord(prediction, cuts, confidences, self.line) - rec.base_dir = base_dir - rec._display_order = not self._display_order - return rec - - class mm_rpred(object): """ Multi-model version of kraken.rpred.rpred """ def __init__(self, - nets: Dict[str, TorchSeqRecognizer], + nets: Dict[Tuple[str, str], TorchSeqRecognizer], im: Image.Image, - bounds: dict, + bounds: Segmentation, pad: int = 16, bidi_reordering: Union[bool, str] = True, - tags_ignore: Optional[List[str]] = None) -> Generator[ocr_record, None, None]: + tags_ignore: Optional[List[Tuple[str, str]]] = None) -> Generator[ocr_record, None, None]: """ Multi-model version of kraken.rpred.rpred. @@ -413,20 +60,19 @@ def __init__(self, these lines. Args: - nets (dict): A dict mapping tag values to TorchSegRecognizer - objects. Recommended to be an defaultdict. - im (PIL.Image.Image): Image to extract text from - bounds (dict): A dictionary containing a 'boxes' entry - with a list of lists of coordinates (script, (x0, y0, - x1, y1)) of a text line in the image and an entry - 'text_direction' containing - 'horizontal-lr/rl/vertical-lr/rl'. - pad (int): Extra blank padding to the left and right of text line - bidi_reordering (bool|str): Reorder classes in the ocr_record according to - the Unicode bidirectional algorithm for - correct display. Set to L|R to - override default text direction. - tags_ignore (list): List of tag values to ignore during recognition + nets: A dict mapping tag key-value pairs to TorchSegRecognizer + objects. Recommended to be an defaultdict. + im: Image to extract text from + bounds: A Segmentation data class containing either bounding box or + baseline type segmentation. + pad: Extra blank padding to the left and right of text line + bidi_reordering: Reorder classes in the ocr_record according to the + Unicode bidirectional algorithm for correct + display. Set to L|R to override default text + direction. 
+        tags_ignore: List of tag key-value pairs to ignore during
+                     recognition
+
     Yields:
         An ocr_record containing the recognized text, absolute character
         positions, and confidence values for each character.
@@ -445,36 +91,34 @@ def __init__(self,
         if not tags_ignore:
             tags_ignore = []

-        if ('type' in bounds and bounds['type'] not in seg_types) or len(seg_types) > 1:
+        if bounds.type not in seg_types or len(seg_types) > 1:
             logger.warning(f'Recognizers with segmentation types {seg_types} will be '
-                           f'applied to segmentation of type {bounds["type"] if "type" in bounds else None}. '
+                           f'applied to segmentation of type {bounds.type}. '
                            f'This will likely result in severely degraded performance')

         one_channel_modes = set(recognizer.nn.one_channel_mode for recognizer in nets.values())
         if '1' in one_channel_modes and len(one_channel_modes) > 1:
             raise KrakenInputException('Mixing binary and non-binary recognition models is not supported.')
         elif '1' in one_channel_modes and not is_bitonal(im):
             logger.warning('Running binary models on non-binary input image '
-                           '(mode {}). This will result in severely degraded '
-                           'performance'.format(im.mode))
-        if 'type' in bounds and bounds['type'] == 'baselines':
+                           f'(mode {im.mode}). This will result in severely degraded '
+                           'performance')
+
+        self.len = len(bounds.lines)
+        self.line_iter = iter(bounds.lines)
+
+        if bounds.type == 'baselines':
             valid_norm = False
-            self.len = len(bounds['lines'])
-            self.seg_key = 'lines'
             self.next_iter = self._recognize_baseline_line
-            self.line_iter = iter(bounds['lines'])
-            tags = set()
-            for x in bounds['lines']:
-                tags.update(x['tags'].values())
         else:
             valid_norm = True
-            self.len = len(bounds['boxes'])
-            self.seg_key = 'boxes'
             self.next_iter = self._recognize_box_line
-            self.line_iter = iter(bounds['boxes'])
-            tags = set(x[0] for line in bounds['boxes'] for x in line)
+
+        tags = set()
+        for x in bounds.lines:
+            tags.update(x.tags.items())

         im_str = get_im_str(im)
-        logger.info('Running {} multi-script recognizers on {} with {} lines'.format(len(nets), im_str, self.len))
+        logger.info(f'Running {len(nets)} multi-script recognizers on {im_str} with {self.len} lines')

         filtered_tags = []
         miss = []
@@ -486,12 +130,12 @@ def __init__(self,
             tags = filtered_tags

         if miss:
-            raise KrakenInputException('Missing models for tags {}'.format(set(miss)))
+            raise KrakenInputException(f'Missing models for tags {set(miss)}')

         # build dictionary for line preprocessing
         self.ts = {}
         for tag in tags:
-            logger.debug('Loading line transforms for {}'.format(tag))
+            logger.debug(f'Loading line transforms for {tag}')
             network = nets[tag]
             batch, channels, height, width = network.nn.input
             self.ts[tag] = ImageInputTransforms(batch, height, width, channels, (pad, 0), valid_norm)
@@ -504,70 +148,75 @@ def __init__(self,
         self.tags_ignore = tags_ignore

     def _recognize_box_line(self, line):
-        flat_box = [point for box in line['boxes'][0] for point in box[1]]
+        flat_box = line.bbox
         xmin, xmax = min(flat_box[::2]), max(flat_box[::2])
         ymin, ymax = min(flat_box[1::2]), max(flat_box[1::2])
         line_bbox = ((xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin))
         prediction = ''
         cuts = []
         confidences = []
-        for tag, (box, coords) in zip(map(lambda x: x[0], line['boxes'][0]),
-                                      extract_polygons(self.im, {'text_direction': line['text_direction'],
-                                                                 'boxes': map(lambda x: x[1], line['boxes'][0])})):
-            self.box = box
-            # skip if tag is set to ignore
-            if self.tags_ignore is not None and tag in self.tags_ignore:
-                logger.warning(f'Ignoring {tag} line segment.')
-                continue
-            # check if boxes are non-zero in any dimension
-            if 0 in box.size:
-                logger.warning(f'bbox {coords} with zero dimension. Emitting empty record.')
-                return BBoxOCRRecord('', (), (), coords)
-            # try conversion into tensor
-            try:
-                logger.debug('Preparing run.')
-                line = self.ts[tag](box)
-            except Exception:
-                logger.warning(f'Conversion of line {coords} failed. Emitting empty record..')
-                return BBoxOCRRecord('', (), (), coords)
-
-            # check if line is non-zero
-            if line.max() == line.min():
-                logger.warning('Empty run. Emitting empty record.')
-                return BBoxOCRRecord('', (), (), coords)
-
-            _, net = self._resolve_tags_to_model({'type': tag}, self.nets)
-
-            logger.debug(f'Forward pass with model {tag}.')
-            preds = net.predict(line.unsqueeze(0))[0]
-
-            # calculate recognized LSTM locations of characters
-            logger.debug('Convert to absolute coordinates')
-            # calculate recognized LSTM locations of characters
-            # scale between network output and network input
-            self.net_scale = line.shape[2]/net.outputs.shape[2]
-            # scale between network input and original line
-            self.in_scale = box.size[0]/(line.shape[2]-2*self.pad)
-
-            pred = ''.join(x[0] for x in preds)
-            pos = []
-            conf = []
-
-            for _, start, end, c in preds:
-                if self.bounds['text_direction'].startswith('horizontal'):
-                    xmin = coords[0] + self._scale_val(start, 0, self.box.size[0])
-                    xmax = coords[0] + self._scale_val(end, 0, self.box.size[0])
-                    pos.append([[xmin, coords[1]], [xmin, coords[3]], [xmax, coords[3]], [xmax, coords[1]]])
-                else:
-                    ymin = coords[1] + self._scale_val(start, 0, self.box.size[1])
-                    ymax = coords[1] + self._scale_val(end, 0, self.box.size[1])
-                    pos.append([[coords[0], ymin], [coords[2], ymin], [coords[2], ymax], [coords[0], ymax]])
-                conf.append(c)
-            prediction += pred
-            cuts.extend(pos)
-            confidences.extend(conf)
-
-        rec = BBoxOCRRecord(prediction, cuts, confidences, line_bbox)
+        line.text_direction = self.bounds.text_direction
+
+        if self.tags_ignore is not None:
+            for tag in line.tags.values():
+                if tag in self.tags_ignore:
+                    logger.info(f'Ignoring line segment with tags {line.tags} based on {tag}.')
+                    return BBoxOCRRecord('', (), (), line)
+
+        tag, net = self._resolve_tags_to_model(line.tags, self.nets)
+
+        box, coords = next(extract_polygons(self.im, line))
+        self.box = box
+
+        # check if boxes are non-zero in any dimension
+        if 0 in box.size:
+            logger.warning(f'bbox {line} with zero dimension. Emitting empty record.')
+            return BBoxOCRRecord('', (), (), line)
+        # try conversion into tensor
+        try:
+            logger.debug('Preparing run.')
+            ts_box = self.ts[tag](box)
+        except Exception:
+            logger.warning(f'Conversion of line {line} failed. Emitting empty record.')
+            return BBoxOCRRecord('', (), (), line)
+
+        # check if line is non-zero
+        if ts_box.max() == ts_box.min():
+            logger.warning('Empty run. Emitting empty record.')
+            return BBoxOCRRecord('', (), (), line)
+
+        logger.debug(f'Forward pass with model {tag}.')
+        preds = net.predict(ts_box.unsqueeze(0))[0]
+
+        # calculate recognized LSTM locations of characters
+        logger.debug('Convert to absolute coordinates')
+        # scale between network output and network input
+        self.net_scale = ts_box.shape[2]/net.outputs.shape[2]
+        # scale between network input and original line
+        self.in_scale = box.size[0]/(ts_box.shape[2]-2*self.pad)
+
+        pred = ''.join(x[0] for x in preds)
+        pos = []
+        conf = []
+
+        for _, start, end, c in preds:
+            if self.bounds.text_direction.startswith('horizontal'):
+                xmin = coords[0] + self._scale_val(start, 0, self.box.size[0])
+                xmax = coords[0] + self._scale_val(end, 0, self.box.size[0])
+                pos.append([[xmin, coords[1]], [xmin, coords[3]], [xmax, coords[3]], [xmax, coords[1]]])
+            else:
+                ymin = coords[1] + self._scale_val(start, 0, self.box.size[1])
+                ymax = coords[1] + self._scale_val(end, 0, self.box.size[1])
+                pos.append([[coords[0], ymin], [coords[2], ymin], [coords[2], ymax], [coords[0], ymax]])
+            conf.append(c)
+        prediction += pred
+        cuts.extend(pos)
+        confidences.extend(conf)
+
+        rec = BBoxOCRRecord(prediction, cuts, confidences, line)
         if self.bidi_reordering:
             logger.debug('BiDi reordering record.')
             return rec.logical_order(base_dir=self.bidi_reordering if self.bidi_reordering in ('L', 'R') else None)
@@ -577,41 +226,43 @@ def _recognize_box_line(self, line):

     def _recognize_baseline_line(self, line):
         if self.tags_ignore is not None:
-            for tag in line['lines'][0]['tags'].values():
+            for tag in line.tags.values():
                 if tag in self.tags_ignore:
-                    logger.info(f'Ignoring line segment with tags {line["lines"][0]["tags"]} based on {tag}.')
-                    return BaselineOCRRecord('', [], [], line['lines'][0])
+                    logger.info(f'Ignoring line segment with tags {line.tags} based on {tag}.')
+                    return BaselineOCRRecord('', [], [], line)
+
+        seg = dataclasses.replace(self.bounds, lines=[line])

         try:
-            box, coords = next(extract_polygons(self.im, line))
+            box, coords = next(extract_polygons(self.im, seg))
         except KrakenInputException as e:
             logger.warning(f'Extracting line failed: {e}')
-            return BaselineOCRRecord('', [], [], line['lines'][0])
+            return BaselineOCRRecord('', [], [], line)

         self.box = box

-        tag, net = self._resolve_tags_to_model(coords['tags'], self.nets)
+        tag, net = self._resolve_tags_to_model(line.tags, self.nets)

         # check if boxes are non-zero in any dimension
         if 0 in box.size:
-            logger.warning(f'bbox {coords} with zero dimension. Emitting empty record.')
-            return BaselineOCRRecord('', [], [], coords)
+            logger.warning(f'{line} with zero dimension. Emitting empty record.')
+            return BaselineOCRRecord('', [], [], line)

         # try conversion into tensor
         try:
-            line = self.ts[tag](box)
+            ts_box = self.ts[tag](box)
         except Exception as e:
             logger.warning(f'Tensor conversion failed with {e}. Emitting empty record.')
-            return BaselineOCRRecord('', [], [], coords)
+            return BaselineOCRRecord('', [], [], line)

         # check if line is non-zero
-        if line.max() == line.min():
+        if ts_box.max() == ts_box.min():
             logger.warning('Empty line after tensor conversion. 
@@ -577,41 +226,43 @@ def _recognize_box_line(self, line):
 
     def _recognize_baseline_line(self, line):
         if self.tags_ignore is not None:
-            for tag in line['lines'][0]['tags'].values():
+            for tag in line.tags.values():
                 if tag in self.tags_ignore:
-                    logger.info(f'Ignoring line segment with tags {line["lines"][0]["tags"]} based on {tag}.')
-                    return BaselineOCRRecord('', [], [], line['lines'][0])
+                    logger.info(f'Ignoring line segment with tags {line.tags} based on {tag}.')
+                    return BaselineOCRRecord('', [], [], line)
+
+        seg = dataclasses.replace(self.bounds, lines=[line])
         try:
-            box, coords = next(extract_polygons(self.im, line))
+            box, coords = next(extract_polygons(self.im, seg))
         except KrakenInputException as e:
             logger.warning(f'Extracting line failed: {e}')
-            return BaselineOCRRecord('', [], [], line['lines'][0])
+            return BaselineOCRRecord('', [], [], line)
 
         self.box = box
-        tag, net = self._resolve_tags_to_model(coords['tags'], self.nets)
+        tag, net = self._resolve_tags_to_model(line.tags, self.nets)
         # check if boxes are non-zero in any dimension
         if 0 in box.size:
-            logger.warning(f'bbox {coords} with zero dimension. Emitting empty record.')
-            return BaselineOCRRecord('', [], [], coords)
+            logger.warning(f'{line} with zero dimension. Emitting empty record.')
+            return BaselineOCRRecord('', [], [], line)
         # try conversion into tensor
         try:
-            line = self.ts[tag](box)
+            ts_box = self.ts[tag](box)
         except Exception as e:
             logger.warning(f'Tensor conversion failed with {e}. Emitting empty record.')
-            return BaselineOCRRecord('', [], [], coords)
+            return BaselineOCRRecord('', [], [], line)
         # check if line is non-zero
-        if line.max() == line.min():
+        if ts_box.max() == ts_box.min():
             logger.warning('Empty line after tensor conversion. Emitting empty record.')
-            return BaselineOCRRecord('', [], [], coords)
+            return BaselineOCRRecord('', [], [], line)
 
-        preds = net.predict(line.unsqueeze(0))[0]
+        preds = net.predict(ts_box.unsqueeze(0))[0]
         # calculate recognized LSTM locations of characters
         # scale between network output and network input
-        self.net_scale = line.shape[2]/net.outputs.shape[2]
+        self.net_scale = ts_box.shape[2]/net.outputs.shape[2]
         # scale between network input and original line
-        self.in_scale = box.size[0]/(line.shape[2]-2*self.pad)
+        self.in_scale = box.size[0]/(ts_box.shape[2]-2*self.pad)
 
         # XXX: fix bounding box calculation ocr_record for multi-codepoint labels.
         pred = ''.join(x[0] for x in preds)
@@ -621,7 +272,7 @@ def _recognize_baseline_line(self, line):
             pos.append((self._scale_val(start, 0, self.box.size[0]),
                         self._scale_val(end, 0, self.box.size[0])))
             conf.append(c)
-        rec = BaselineOCRRecord(pred, pos, conf, coords)
+        rec = BaselineOCRRecord(pred, pos, conf, line)
         if self.bidi_reordering:
             logger.debug('BiDi reordering record.')
             return rec.logical_order(base_dir=self.bidi_reordering if self.bidi_reordering in ('L', 'R') else None)
@@ -630,9 +281,7 @@ def _recognize_baseline_line(self, line):
             return rec.display_order(None)
 
     def __next__(self):
-        bound = self.bounds
-        bound[self.seg_key] = [next(self.line_iter)]
-        return self.next_iter(bound)
+        return self.next_iter(next(self.line_iter))
 
     def __iter__(self):
         return self
@@ -646,51 +295,40 @@ def _scale_val(self, val, min_val, max_val):
 
 def rpred(network: TorchSeqRecognizer,
           im: Image.Image,
-          bounds: dict,
+          bounds: Segmentation,
           pad: int = 16,
           bidi_reordering: Union[bool, str] = True) -> Generator[ocr_record, None, None]:
     """
     Uses a TorchSeqRecognizer and a segmentation to recognize text
 
     Args:
-        network (kraken.lib.models.TorchSeqRecognizer): A TorchSeqRecognizer
-                                                        object
-        im (PIL.Image.Image): Image to extract text from
-        bounds (dict): A dictionary containing a 'boxes' entry with a list of
-                       coordinates (x0, y0, x1, y1) of a text line in the image
-                       and an entry 'text_direction' containing
-                       'horizontal-lr/rl/vertical-lr/rl'.
-        pad (int): Extra blank padding to the left and right of text line.
-                   Auto-disabled when expected network inputs are incompatible
-                   with padding.
-        bidi_reordering (bool|str): Reorder classes in the ocr_record according to
-                                    the Unicode bidirectional algorithm for correct
-                                    display. Set to L|R to change base text
-                                    direction.
+        network: A TorchSeqRecognizer object
+        im: Image to extract text from
+        bounds: A Segmentation class instance containing either a baseline or
+                bbox segmentation.
+        pad: Extra blank padding to the left and right of text line.
+             Auto-disabled when expected network inputs are incompatible with
+             padding.
+        bidi_reordering: Reorder classes in the ocr_record according to the
+                         Unicode bidirectional algorithm for correct display.
+                         Set to L|R to change base text direction.
+
     Yields:
         An ocr_record containing the recognized text, absolute character
        positions, and confidence values for each character.
    """
-    bounds = copy.deepcopy(bounds)
-    if 'boxes' in bounds:
-        boxes = bounds['boxes']
-        rewrite_boxes = []
-        for box in boxes:
-            rewrite_boxes.append([('default', box)])
-        bounds['boxes'] = rewrite_boxes
-        bounds['script_detection'] = True
     return mm_rpred(defaultdict(lambda: network), im, bounds, pad, bidi_reordering)
 
 
 def _resolve_tags_to_model(tags: Sequence[Dict[str, str]],
-                           model_map: Dict[str, TorchSeqRecognizer],
+                           model_map: Dict[Tuple[str, str], TorchSeqRecognizer],
                            default: Optional[TorchSeqRecognizer] = None) -> TorchSeqRecognizer:
     """
     Resolves a sequence of tags
     """
-    for tag in tags.values():
+    for tag in tags.items():
         if tag in model_map:
             return tag, model_map[tag]
     if default:
-        return next(tags.values()), default
+        return next(iter(tags.items())), default
-    raise KrakenInputException('No model for tags {}'.format(tags))
+    raise KrakenInputException(f'No model for tags {tags}')
diff --git a/kraken/serialization.py b/kraken/serialization.py
index c741923da..ae6679762 100644
--- a/kraken/serialization.py
+++ b/kraken/serialization.py
@@ -23,7 +23,7 @@ from pkg_resources import get_distribution
 
 from collections import Counter
 
-from kraken.rpred import BaselineOCRRecord, BBoxOCRRecord, ocr_record
+from kraken.containers import Segmentation, ProcessingStep
 from kraken.lib.util import make_printable
 from kraken.lib.segmentation import is_in_region
 
@@ -70,95 +70,92 @@ def max_bbox(boxes: Iterable[Sequence[int]]) -> Tuple[int, int, int, int]:
     return o
 
 
-def serialize(records: Sequence[ocr_record],
-              image_name: Union[PathLike, str] = None,
+def serialize(results: Segmentation,
               image_size: Tuple[int, int] = (0, 0),
               writing_mode: Literal['horizontal-tb', 'vertical-lr', 'vertical-rl'] = 'horizontal-tb',
               scripts: Optional[Iterable[str]] = None,
-              regions: Optional[Dict[str, List[List[Tuple[int, int]]]]] = None,
               template: [PathLike, str] = 'alto',
               template_source: Literal['native', 'custom'] = 'native',
-              processing_steps: Optional[List[Dict[str, Union[Dict, str, float, int, bool]]]] = None) -> str:
+              processing_steps: Optional[List[ProcessingStep]] = None) -> str:
     """
-    Serializes a list of ocr_records into an output document.
+    Serializes recognition and segmentation results into an output document.
 
-    Serializes a list of predictions and their corresponding positions by doing
-    some hOCR-specific preprocessing and then renders them through one of
-    several jinja2 templates.
+    Serializes a Segmentation container object containing either segmentation
+    or recognition results into an output document. The rendering is performed
+    with jinja2 templates that can either be shipped with kraken
+    (`template_source` == 'native') or custom (`template_source` == 'custom').
 
     Note:
         Empty records are ignored for serialization purposes.
 
     Args:
-        records: List of kraken.rpred.ocr_record
-        image_name: Name of the source image
+        results: Segmentation container object
         image_size: Dimensions of the source image
         writing_mode: Sets the principal layout of lines and the
                       direction in which blocks progress. Valid values are
                       horizontal-tb, vertical-rl, and vertical-lr.
         scripts: List of scripts contained in the OCR records
-        regions: Dictionary mapping region types to a list of region polygons.
         template: Selector for the serialization format. May be 'hocr',
                   'alto', 'page' or any template found in the template
                   directory. If template_source is set to `custom` a path to a
                   template is expected.
         template_source: Switch to enable loading of custom templates from
                          outside the kraken package.
-        processing_steps: A list of dictionaries describing the processing
-                          kraken performed on the inputs::
-
-                              {'category': 'preprocessing',
-                               'description': 'natural language description of process',
-                               'settings': {'arg0': 'foo', 'argX': 'bar'}
-                              }
+        processing_steps: A list of ProcessingStep container classes describing
+                          the processing kraken performed on the inputs.
 
     Returns:
         The rendered template
     """
-    logger.info(f'Serialize {len(records)} records from {image_name} with template {template}.')
+    logger.info(f'Serialize {len(results.lines)} records from {results.imagename} with template {template}.')
     page = {'entities': [],
             'size': image_size,
-            'name': image_name,
+            'name': results.imagename,
             'writing_mode': writing_mode,
             'scripts': scripts,
             'date': datetime.datetime.now(datetime.timezone.utc).isoformat(),
-            'base_dir': [rec.base_dir for rec in records][0] if len(records) else None}  # type: dict
+            'base_dir': [rec.base_dir for rec in results.lines][0] if len(results.lines) else None,
+            'seg_type': results.type}  # type: dict
     metadata = {'processing_steps': processing_steps,
                 'version': get_distribution('kraken').version}
 
     seg_idx = 0
     char_idx = 0
-    region_map = {}
-    idx = 0
-    if regions is not None:
-        for id, regs in regions.items():
-            for reg in regs:
-                region_map[idx] = (id, geom.Polygon(reg), reg)
-                idx += 1
 
     # build region and line type dict
     types = []
-    for line in records:
-        if hasattr(line, 'tags') and line.tags is not None:
-            types.extend(line.tags.values())
-    page['types'] = list(set(types))
-    if regions is not None:
-        page['types'].extend(list(regions.keys()))
-
-    is_in_reg = -1
-    for idx, record in enumerate(records):
-        if record.type == 'baselines':
-            l_obj = geom.LineString(record.baseline)
-        else:
-            l_obj = geom.LineString(record.line)
-        reg = list(filter(lambda x: is_in_region(l_obj, x[1][1]), region_map.items()))
-        if len(reg) == 0:
-            cur_ent = page['entities']
-        elif reg[0][0] != is_in_reg:
-            reg = reg[0]
-            is_in_reg = reg[0]
-            region = {'index': reg[0],
-                      'bbox': [int(x) for x in reg[1][1].bounds],
-                      'boundary': [list(x) for x in reg[1][2]],
-                      'region_type': reg[1][0],
+    for line in results.lines:
+        if line.tags is not None:
+            types.extend((k, v) for k, v in line.tags.items())
+    page['line_types'] = list(set(types))
+    page['region_types'] = list(results.regions.keys())
+
+    # map reading order indices to line IDs
+    ros = []
+    for ro in results.line_orders:
+        ros.append([results.lines[idx].id for idx in ro])
+    page['line_orders'] = ros
+
+    # build region ID to region dict
+    reg_dict = {}
+    for key, regs in results.regions.items():
+        for reg in regs:
+            reg_dict[reg.id] = reg
+
+    regs_with_lines = set()
+    prev_reg = None
+    for idx, record in enumerate(results.lines):
+        # line not in region
+        if len(record.regions) == 0:
+            cur_ent = page['entities']
+        # line not in same region as previous line
+        elif prev_reg != record.regions[0]:
+            prev_reg = record.regions[0]
+            reg = reg_dict[record.regions[0]]
+            regs_with_lines.add(reg.id)
+            region = {'id': reg.id,
+                      'bbox': max_bbox([reg.boundary]),
+                      'boundary': [list(x) for x in reg.boundary],
+                      'tags': reg.tags,
                       'lines': [],
                       'type': 'region'
                      }
@@ -167,20 +164,19 @@ def serialize(records: Sequence[ocr_record],
 
         # set field to indicate the availability of baseline segmentation in
         # addition to bounding boxes
-        if record.type == 'baselines':
-            page['seg_type'] = 'baselines'
         line = {'index': idx,
-                'bbox': max_bbox([record.line]),
+                'bbox': max_bbox([record.boundary] if record.type == 'baselines' else record.bbox),
                 'cuts': record.cuts,
                 'confidences': record.confidences,
                 'recognition': [],
-                'boundary': [list(x) for x in record.line],
+                'boundary': [list(x) for x in record.boundary],
                 'type': 'line'
                }
-        if hasattr(record, 'tags') and record.tags is not None:
+        if record.tags is not None:
            line['tags'] = record.tags
        if record.type == 'baselines':
            line['baseline'] = [list(x) for x in record.baseline]
+
        splits = regex.split(r'(\s+)', record.prediction)
        line_offset = 0
        logger.debug(f'Record contains {len(splits)} segments')
@@ -213,18 +209,19 @@ def serialize(records: Sequence[ocr_record],
                line_offset += len(segment)
        cur_ent.append(line)
 
-    # No records but there are regions -> serialize all regions
-    if not records and regions:
-        logger.debug(f'No lines given but {len(region_map)}. Serialize all regions.')
-        for reg in region_map.items():
-            region = {'index': reg[0],
-                      'bbox': [int(x) for x in reg[1][1].bounds],
-                      'boundary': [list(x) for x in reg[1][2]],
-                      'region_type': reg[1][0],
-                      'lines': [],
-                      'type': 'region'
-                     }
-            page['entities'].append(region)
+    # serialize all remaining (line-less) regions
+    for reg_id in regs_with_lines:
+        reg_dict.pop(reg_id)
+    logger.debug(f'Serializing {len(reg_dict)} regions without lines.')
+    for reg in reg_dict.values():
+        region = {'id': reg.id,
+                  'bbox': max_bbox([reg.boundary]),
+                  'boundary': [list(x) for x in reg.boundary],
+                  'tags': reg.tags,
+                  'lines': [],
+                  'type': 'region'
+                 }
+        page['entities'].append(region)
 
     if template_source == 'native':
         logger.debug('Initializing native jinja environment.')
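The reworked `serialize()` pulls everything from the `Segmentation` container, so the old `image_name` and `regions` arguments disappear. A sketch of the new call pattern, continuing the recognition example earlier; the `ProcessingStep` field values are assumptions based on how the container is consumed above, and it presumes `seg` has been populated with recognition results:

```python
from kraken.containers import ProcessingStep
from kraken.serialization import serialize

steps = [ProcessingStep(id='_1',
                        category='processing',
                        description='Text line recognition',
                        settings={'pad': 16, 'bidi_reordering': True})]

# seg/im come from the rpred sketch above; the records produced there
# would normally be attached to the Segmentation before serialization.
alto = serialize(seg,
                 image_size=im.size,
                 template='alto',
                 processing_steps=steps)

with open('page.alto.xml', 'w') as fp:
    fp.write(alto)
```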
- """ - if 'type' in segresult and segresult['type'] == 'baselines': - records = [BaselineOCRRecord('', (), (), bl) for bl in segresult['lines']] - else: - records = [] - for line in segresult['boxes']: - xmin, xmax = min(line[::2]), max(line[::2]) - ymin, ymax = min(line[1::2]), max(line[1::2]) - records.append(BBoxOCRRecord('', (), (), ((xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin)))) - return serialize(records, - image_name=image_name, - image_size=image_size, - regions=segresult['regions'] if 'regions' in segresult else None, - template=template, - template_source=template_source, - processing_steps=processing_steps) - - def render_report(model: str, chars: int, errors: int, diff --git a/kraken/templates/alto b/kraken/templates/alto index 05d7ab192..ccddf4182 100644 --- a/kraken/templates/alto +++ b/kraken/templates/alto @@ -49,7 +49,7 @@ {% if metadata.processing_steps %} {% for step in metadata.processing_steps %} - + {{ proc_type_table[step.category] }} {{ step.description }} {% for k, v in step.settings.items() %}{{k}}: {{v}}; {% endfor %} @@ -71,10 +71,34 @@ {% endif %} - {% for reg_type in page.types %} - + {% for type, label in page.line_types %} + + {% endfor %} + {% for label in page.region_types %} + {% endfor %} + {% if len(page.line_orders) > 0 %} + + {% if len(page.line_orders) == 1 %} + + {% for id in page.line_orders[0] %} + + {% endfor %} + + {% else %} + + {% for ro in page.line_orders %} + + {% for id in ro %} + + {% endfor %} + + {% endfor %} + + {% endif %} + + {% endif %} diff --git a/kraken/transcribe.py b/kraken/transcribe.py index e9a8067b8..5b39ee2f7 100644 --- a/kraken/transcribe.py +++ b/kraken/transcribe.py @@ -18,8 +18,6 @@ from kraken.lib.exceptions import KrakenInputException from kraken.lib.util import get_im_str -from typing import List - from jinja2 import Environment, PackageLoader from io import BytesIO diff --git a/setup.cfg b/setup.cfg index 60c3994d6..f26ab7957 100644 --- a/setup.cfg +++ b/setup.cfg @@ -60,6 +60,7 @@ install_requires = pyarrow pytorch-lightning~=2.0.0 torchmetrics>=0.10.0 + threadpoolctl~=3.2.0 rich [options.extras_require] diff --git a/tests/resources/bsb00084914_00007.xml b/tests/resources/bsb00084914_00007.xml new file mode 100644 index 000000000..311751ad1 --- /dev/null +++ b/tests/resources/bsb00084914_00007.xml @@ -0,0 +1,1074 @@ + + + + pixel + + bsb00084914_00007.jpg + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tests/resources/cPAS-2000.xml b/tests/resources/cPAS-2000.xml new file mode 100644 index 000000000..d9f844121 --- /dev/null +++ b/tests/resources/cPAS-2000.xml @@ -0,0 +1,410 @@ + + + + TRP + 2018-12-24T11:28:19+07:00 + 2019-02-05T09:16:48Z + + + + + + + + + + + + 
diff --git a/tests/resources/bsb00084914_00007.xml b/tests/resources/bsb00084914_00007.xml
new file mode 100644
index 000000000..311751ad1
--- /dev/null
+++ b/tests/resources/bsb00084914_00007.xml
@@ -0,0 +1,1074 @@
+<!-- 1,074 lines of ALTO XML (MeasurementUnit: pixel; source image bsb00084914_00007.jpg); element markup not preserved in this copy -->
diff --git a/tests/resources/cPAS-2000.xml b/tests/resources/cPAS-2000.xml
new file mode 100644
index 000000000..d9f844121
--- /dev/null
+++ b/tests/resources/cPAS-2000.xml
@@ -0,0 +1,410 @@
+<!-- 410 lines of PAGE XML (Creator: TRP; Created: 2018-12-24T11:28:19+07:00; LastChange: 2019-02-05T09:16:48Z); element markup not preserved in this copy -->
diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py
new file mode 100644
index 000000000..fff086c12
--- /dev/null
+++ b/tests/test_arrow_dataset.py
@@ -0,0 +1,36 @@
+# -*- coding: utf-8 -*-
+
+import unittest
+import json
+
+import kraken
+
+from pytest import raises
+from pathlib import Path
+
+from kraken.lib import xml
+from kraken.lib.arrow_dataset import build_binary_dataset
+
+thisfile = Path(__file__).resolve().parent
+resources = thisfile / 'resources'
+
+
+class TestKrakenArrowCompilation(unittest.TestCase):
+    """
+    Tests for binary datasets
+    """
+    def setUp(self):
+        self.xml = resources / '170025120000003,0074.xml'
+        self.bls = xml.XMLPage(self.xml)
+        self.box_lines = [resources / '000236.png']
+
+    def test_build_path_dataset(self):
+        pass
+
+    def test_build_xml_dataset(self):
+        pass
+
+    def test_build_obj_dataset(self):
+        pass
+
+    def test_build_empty_dataset(self):
+        pass
diff --git a/tests/test_xml.py b/tests/test_xml.py
new file mode 100644
index 000000000..611959d2e
--- /dev/null
+++ b/tests/test_xml.py
@@ -0,0 +1,59 @@
+# -*- coding: utf-8 -*-
+import json
+import unittest
+import tempfile
+import numpy as np
+
+from pathlib import Path
+from pytest import raises
+
+from kraken.lib import xml
+
+thisfile = Path(__file__).resolve().parent
+resources = thisfile / 'resources'
+
+
+class TestXMLParser(unittest.TestCase):
+    """
+    Tests XML (ALTO/PAGE) parsing
+    """
+    def setUp(self):
+        self.page_doc = resources / 'cPAS-2000.xml'
+        self.alto_doc = resources / 'bsb00084914_00007.xml'
+
+    def test_page_parsing(self):
+        """
+        Test parsing of PAGE XML files with reading order.
+        """
+        doc = xml.XMLPage(self.page_doc, filetype='page')
+        self.assertEqual(len(doc.baselines), 97)
+        self.assertEqual(len([item for x in doc.regions.values() for item in x]), 4)
+
+    def test_alto_parsing(self):
+        """
+        Test parsing of ALTO XML files with reading order.
+        """
+        doc = xml.XMLPage(self.alto_doc, filetype='alto')
+        self.assertEqual(doc.filetype, 'alto')
+
+    def test_auto_parsing(self):
+        """
+        Test parsing of PAGE and ALTO XML files with auto-format determination.
+        """
+        doc = xml.XMLPage(self.page_doc, filetype='xml')
+        self.assertEqual(doc.filetype, 'page')
+        doc = xml.XMLPage(self.alto_doc, filetype='xml')
+        self.assertEqual(doc.filetype, 'alto')
+
+    def test_failure_page_alto_parsing(self):
+        """
+        Test that parsing ALTO files with PAGE as format fails.
+        """
+        with raises(ValueError):
+            xml.XMLPage(self.alto_doc, filetype='page')
+
+    def test_failure_alto_page_parsing(self):
+        """
+        Test that parsing PAGE files with ALTO as format fails.
+        """
+        with raises(ValueError):
+            xml.XMLPage(self.page_doc, filetype='alto')
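The new tests exercise `kraken.lib.xml.XMLPage` directly. For reference, a short usage sketch of the parser against one of the fixtures above; the `to_container()` conversion to a `Segmentation` is an assumption about the container API introduced in this changeset:

```python
from kraken.lib import xml

page = xml.XMLPage('tests/resources/cPAS-2000.xml', filetype='xml')
print(page.filetype)        # 'page' -- auto-detected from the root element
print(len(page.baselines))  # 97 baselines in this fixture

seg = page.to_container()   # Segmentation usable with rpred()/serialize()
print(seg.type, len(seg.lines))
```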