diff --git a/kraken/contrib/extract_lines.py b/kraken/contrib/extract_lines.py index 95233263c..78f779af9 100755 --- a/kraken/contrib/extract_lines.py +++ b/kraken/contrib/extract_lines.py @@ -9,8 +9,9 @@ 'link to source images.') @click.option('-i', '--model', default=None, show_default=True, type=click.Path(exists=True), help='Baseline detection model to use. Overrides format type and expects image files as input.') +@click.option('--legacy-polygons', is_flag=True, help='Use the legacy polygon extractor.') @click.argument('files', nargs=-1) -def cli(format_type, model, files): +def cli(format_type, model, legacy_polygons, files): """ A small script extracting rectified line polygons as defined in either ALTO or PageXML files or run a model to do the same. @@ -37,7 +38,7 @@ def cli(format_type, model, files): data = xml.XMLPage(doc, format_type) if len(data.lines) > 0: bounds = data.to_container() - for idx, (im, box) in enumerate(segmentation.extract_polygons(Image.open(bounds.imagename), bounds)): + for idx, (im, box) in enumerate(segmentation.extract_polygons(Image.open(bounds.imagename), bounds, legacy=legacy_polygons)): click.echo('.', nl=False) im.save('{}.{}.jpg'.format(splitext(bounds.imagename)[0], idx)) with open('{}.{}.gt.txt'.format(splitext(bounds.imagename)[0], idx), 'w') as fp: @@ -61,7 +62,7 @@ def cli(format_type, model, files): click.echo(f'Processing {doc} ', nl=False) full_im = Image.open(doc) bounds = blla.segment(full_im, model=net) - for idx, (im, box) in enumerate(segmentation.extract_polygons(full_im, bounds)): + for idx, (im, box) in enumerate(segmentation.extract_polygons(full_im, bounds, legacy=legacy_polygons)): click.echo('.', nl=False) im.save('{}.{}.jpg'.format(splitext(doc)[0], idx)) diff --git a/kraken/ketos/dataset.py b/kraken/ketos/dataset.py index a4df23400..06154d78a 100644 --- a/kraken/ketos/dataset.py +++ b/kraken/ketos/dataset.py @@ -55,9 +55,11 @@ help='Minimum number of records per RecordBatch written to the ' 'output file. Larger batches require more transient memory ' 'but slightly improve reading performance.') +@click.option('--legacy-polygons', show_default=True, default=False, is_flag=True, + help='Use the old polygon extractor.') @click.argument('ground_truth', nargs=-1, type=click.Path(exists=True, dir_okay=False)) def compile(ctx, output, workers, format_type, files, random_split, force_type, - save_splits, skip_empty_lines, recordbatch_size, ground_truth): + save_splits, skip_empty_lines, recordbatch_size, ground_truth, legacy_polygons): """ Precompiles a binary dataset from a collection of XML files. """ @@ -91,6 +93,7 @@ def compile(ctx, output, workers, format_type, files, random_split, force_type, force_type, recordbatch_size, skip_empty_lines, - lambda advance, total: progress.update(extract_task, total=total, advance=advance)) + lambda advance, total: progress.update(extract_task, total=total, advance=advance), + legacy_polygons=legacy_polygons) message(f'Output file written to {output}') diff --git a/kraken/ketos/pretrain.py b/kraken/ketos/pretrain.py index a8c404dac..5d6055849 100644 --- a/kraken/ketos/pretrain.py +++ b/kraken/ketos/pretrain.py @@ -133,7 +133,7 @@ @click.option('-e', '--evaluation-files', show_default=True, default=None, multiple=True, callback=_validate_manifests, type=click.File(mode='r', lazy=True), help='File(s) with paths to evaluation data. 
Overrides the `-p` parameter') -@click.option('--workers', show_default=True, default=1, type=click.IntRange(1), help='Number of worker processes.') +@click.option('--workers', show_default=True, default=1, type=click.IntRange(0), help='Number of worker processes.') @click.option('--threads', show_default=True, default=1, type=click.IntRange(1), help='Maximum size of OpenMP/BLAS thread pool.') @click.option('--load-hyper-parameters/--no-load-hyper-parameters', show_default=True, default=False, help='When loading an existing model, retrieve hyperparameters from the model') @@ -179,6 +179,7 @@ default=RECOGNITION_PRETRAIN_HYPER_PARAMS['logit_temp'], help='Multiplicative factor for the logits used in contrastive loss.') @click.argument('ground_truth', nargs=-1, callback=_expand_gt, type=click.Path(exists=False, dir_okay=False)) +@click.option('--legacy-polygons', show_default=True, default=False, is_flag=True, help='Use the legacy polygon extractor.') def pretrain(ctx, batch_size, pad, output, spec, load, freq, quit, epochs, min_epochs, lag, min_delta, device, precision, optimizer, lrate, momentum, weight_decay, warmup, schedule, gamma, step_size, sched_patience, @@ -186,7 +187,7 @@ def pretrain(ctx, batch_size, pad, output, spec, load, freq, quit, epochs, evaluation_files, workers, threads, load_hyper_parameters, repolygonize, force_binarization, format_type, augment, mask_probability, mask_width, num_negatives, logit_temp, - ground_truth): + ground_truth, legacy_polygons): """ Trains a model from image-text pairs. """ @@ -258,7 +259,8 @@ def pretrain(ctx, batch_size, pad, output, spec, load, freq, quit, epochs, output=output, spec=spec, model=load, - load_hyper_parameters=load_hyper_parameters) + load_hyper_parameters=load_hyper_parameters, + legacy_polygons=legacy_polygons) data_module = PretrainDataModule(batch_size=hyper_params.pop('batch_size'), pad=hyper_params.pop('pad'), @@ -273,7 +275,8 @@ def pretrain(ctx, batch_size, pad, output, spec, load, freq, quit, epochs, channels=model.channels, repolygonize=repolygonize, force_binarization=force_binarization, - format_type=format_type) + format_type=format_type, + legacy_polygons=legacy_polygons,) model.len_train_set = len(data_module.train_dataloader()) diff --git a/kraken/ketos/recognition.py b/kraken/ketos/recognition.py index 849408162..559d86d3b 100644 --- a/kraken/ketos/recognition.py +++ b/kraken/ketos/recognition.py @@ -21,6 +21,8 @@ import logging import pathlib from typing import List +from functools import partial +import warnings import click from threadpoolctl import threadpool_limits @@ -157,7 +159,7 @@ @click.option('-e', '--evaluation-files', show_default=True, default=None, multiple=True, callback=_validate_manifests, type=click.File(mode='r', lazy=True), help='File(s) with paths to evaluation data. 
Overrides the `-p` parameter')
-@click.option('--workers', show_default=True, default=1, type=click.IntRange(1), help='Number of worker processes.')
+@click.option('--workers', show_default=True, default=1, type=click.IntRange(0), help='Number of worker processes.')
 @click.option('--threads', show_default=True, default=1, type=click.IntRange(1), help='Maximum size of OpenMP/BLAS thread pool.')
 @click.option('--load-hyper-parameters/--no-load-hyper-parameters', show_default=True, default=False,
               help='When loading an existing model, retrieve hyperparameters from the model')
@@ -190,6 +192,7 @@
 @click.option('--log-dir', show_default=True, type=click.Path(exists=True, dir_okay=True, writable=True),
               help='Path to directory where the logger will store the logs. If not set, a directory will be created in the current working directory.')
 @click.argument('ground_truth', nargs=-1, callback=_expand_gt, type=click.Path(exists=False, dir_okay=False))
+@click.option('--legacy-polygons', show_default=True, default=False, is_flag=True, help='Use the legacy polygon extractor.')
 def train(ctx, batch_size, pad, output, spec, append, load, freq, quit, epochs,
           min_epochs, lag, min_delta, device, precision, optimizer, lrate, momentum,
           weight_decay, warmup, freeze_backbone, schedule, gamma, step_size,
@@ -197,7 +200,7 @@ def train(ctx, batch_size, pad, output, spec, append, load, freq, quit, epochs,
           normalize_whitespace, codec, resize, reorder, base_dir,
           training_files, evaluation_files, workers, threads, load_hyper_parameters,
           repolygonize, force_binarization, format_type, augment,
-          pl_logger, log_dir, ground_truth):
+          pl_logger, log_dir, ground_truth, legacy_polygons):
     """
     Trains a model from image-text pairs.
     """
@@ -300,7 +303,19 @@ def train(ctx, batch_size, pad, output, spec, append, load, freq, quit, epochs,
                              force_binarization=force_binarization,
                              format_type=format_type,
                              codec=codec,
-                             resize=resize)
+                             resize=resize,
+                             legacy_polygons=legacy_polygons)
+
+    # Force an upgrade to the new polygon extractor if neither the CLI flag
+    # nor the dataset requires the legacy one
+    if model.nn and model.nn.use_legacy_polygons:
+        if not legacy_polygons and not model.legacy_polygons:
+            # upgrade to new polygon extractor
+            logger.warning('The model will be flagged to use the new polygon extractor.')
+            model.nn.use_legacy_polygons = False
+    if not model.nn and legacy_polygons != model.legacy_polygons:
+        logger.warning(f'Dataset was compiled with legacy polygon extractor: {model.legacy_polygons}, '
+                       f'the new model will be flagged to use the {"legacy" if model.legacy_polygons else "new"} method.')
+        legacy_polygons = model.legacy_polygons
 
     trainer = KrakenTrainer(accelerator=accelerator,
                             devices=device,
@@ -349,7 +364,7 @@ def train(ctx, batch_size, pad, output, spec, append, load, freq, quit, epochs,
 @click.option('--pad', show_default=True, type=click.INT, default=16, help='Left and right '
               'padding around lines')
 @click.option('--workers', show_default=True, default=1,
-              type=click.IntRange(1),
+              type=click.IntRange(0),
               help='Number of worker processes when running on CPU.')
 @click.option('--threads', show_default=True, default=1,
               type=click.IntRange(1),
@@ -387,9 +402,10 @@ def train(ctx, batch_size, pad, output, spec, append, load, freq, quit, epochs,
 @click.option('--fixed-splits/--ignore-fixed-split', show_default=True, default=False,
               help='Whether to honor fixed splits in binary datasets.')
 @click.argument('test_set', nargs=-1, callback=_expand_gt, type=click.Path(exists=False, dir_okay=False))
+@click.option('--no-legacy-polygons', show_default=True, default=False, is_flag=True, help='Force disable the legacy polygon extractor.')
 def test(ctx, batch_size, model, evaluation_files, device, pad, workers,
          threads, reorder, base_dir, normalization, normalize_whitespace,
-         repolygonize, force_binarization, format_type, fixed_splits, test_set):
+         repolygonize, force_binarization, format_type, fixed_splits, test_set, no_legacy_polygons):
     """
     Evaluate on a test set.
     """
@@ -410,11 +426,28 @@ def test(ctx, batch_size, model, evaluation_files, device, pad, workers,
     logger.info('Building test set from {} line images'.format(len(test_set) + len(evaluation_files)))
 
+    legacy_polygons = None
+    incoherent_legacy_polygons = False
+
     nn = {}
     for p in model:
         message('Loading model {}\t'.format(p), nl=False)
         nn[p] = models.load_any(p, device)
         message('\u2713', fg='green')
+        model_legacy_polygons = nn[p].nn.use_legacy_polygons
+        if legacy_polygons is None:
+            legacy_polygons = model_legacy_polygons
+        elif legacy_polygons != model_legacy_polygons:
+            incoherent_legacy_polygons = True
+
+    if incoherent_legacy_polygons and not no_legacy_polygons:
+        logger.warning('Models use different polygon extractors. Legacy polygon extractor will be used; use --no-legacy-polygons to force disable it.')
+        legacy_polygons = True
+    elif no_legacy_polygons:
+        legacy_polygons = False
+
+    if legacy_polygons:
+        warnings.warn('Using legacy polygon extractor, as the model was not trained with the new method. Please retrain your model to get performance improvements.')
 
     pin_ds_mem = False
     if device != 'cpu':
@@ -440,7 +473,7 @@ def test(ctx, batch_size, model, evaluation_files, device, pad, workers,
             message('Repolygonizing data')
         test_set = [{'page': XMLPage(file, filetype=format_type).to_container()} for file in test_set]
         valid_norm = False
-        DatasetClass = PolygonGTDataset
+        DatasetClass = partial(PolygonGTDataset, legacy_polygons=legacy_polygons)
     elif format_type == 'binary':
         DatasetClass = ArrowIPCRecognitionDataset
         if repolygonize:
@@ -485,6 +518,13 @@ def test(ctx, batch_size, model, evaluation_files, device, pad, workers,
                 ds.add(**line)
             except ValueError as e:
                 logger.info(e)
+
+        if hasattr(ds, 'legacy_polygons_status'):
+            if ds.legacy_polygons_status != legacy_polygons:
+                warnings.warn(
+                    f'Binary dataset was compiled with legacy polygon extractor: {ds.legacy_polygons_status}, '
+                    f'while expecting data extracted with {"legacy" if legacy_polygons else "new"} method. Results may be inaccurate.')
+
         # don't encode validation set as the alphabets may not match causing encoding failures
         ds.no_encode()
         ds_loader = DataLoader(ds,
diff --git a/kraken/ketos/ro.py b/kraken/ketos/ro.py
index 9ff26d57e..33191d596 100644
--- a/kraken/ketos/ro.py
+++ b/kraken/ketos/ro.py
@@ -123,7 +123,7 @@
 @click.option('-e', '--evaluation-files', show_default=True, default=None,
               multiple=True, callback=_validate_manifests, type=click.File(mode='r', lazy=True),
               help='File(s) with paths to evaluation data. Overrides the `-p` parameter')
-@click.option('--workers', show_default=True, default=1, type=click.IntRange(1), help='Number of worker proesses.')
+@click.option('--workers', show_default=True, default=1, type=click.IntRange(0), help='Number of worker processes.')
 @click.option('--threads', show_default=True, default=1, type=click.IntRange(1), help='Maximum size of OpenMP/BLAS thread pool.')
 @click.option('--load-hyper-parameters/--no-load-hyper-parameters', show_default=True, default=False,
               help='When loading an existing model, retrieve hyper-parameters from the model')
diff --git a/kraken/ketos/segmentation.py b/kraken/ketos/segmentation.py
index 4d6cdfaeb..f1391e358 100644
--- a/kraken/ketos/segmentation.py
+++ b/kraken/ketos/segmentation.py
@@ -159,7 +159,7 @@ def _validate_merging(ctx, param, value):
 @click.option('-e', '--evaluation-files', show_default=True, default=None,
               multiple=True, callback=_validate_manifests, type=click.File(mode='r', lazy=True),
               help='File(s) with paths to evaluation data. Overrides the `-p` parameter')
-@click.option('--workers', show_default=True, default=1, type=click.IntRange(1), help='Number of worker proesses.')
+@click.option('--workers', show_default=True, default=1, type=click.IntRange(0), help='Number of worker processes.')
 @click.option('--threads', show_default=True, default=1, type=click.IntRange(1), help='Maximum size of OpenMP/BLAS thread pool.')
 @click.option('--load-hyper-parameters/--no-load-hyper-parameters', show_default=True, default=False,
               help='When loading an existing model, retrieve hyper-parameters from the model')
@@ -382,7 +382,7 @@ def segtrain(ctx, output, spec, line_width, pad, load, freq, quit, epochs,
               callback=_validate_manifests, type=click.File(mode='r', lazy=True),
               help='File(s) with paths to evaluation data.')
 @click.option('-d', '--device', show_default=True, default='cpu', help='Select device to use (cpu, cuda:0, cuda:1, ...)')
-@click.option('--workers', default=1, show_default=True, type=click.IntRange(1),
+@click.option('--workers', default=1, show_default=True, type=click.IntRange(0),
               help='Number of worker processes for data loading.')
 @click.option('--threads', default=1, show_default=True, type=click.IntRange(1),
               help='Size of thread pools for intra-op parallelization')
diff --git a/kraken/kraken.py b/kraken/kraken.py
index 9c33dbf29..61b653880 100644
--- a/kraken/kraken.py
+++ b/kraken/kraken.py
@@ -227,10 +227,12 @@ def recognizer(model, pad, no_segmentation, bidi_reordering, tags_ignore, input,
     if bounds.script_detection:
         it = rpred.mm_rpred(model, im, bounds, pad,
                             bidi_reordering=bidi_reordering,
-                            tags_ignore=tags_ignore)
+                            tags_ignore=tags_ignore,
+                            no_legacy_polygons=ctx.meta['no_legacy_polygons'])
     else:
         it = rpred.rpred(model['default'], im, bounds, pad,
-                         bidi_reordering=bidi_reordering)
+                         bidi_reordering=bidi_reordering,
+                         no_legacy_polygons=ctx.meta['no_legacy_polygons'])
 
     preds = []
@@ -302,8 +304,10 @@ def recognizer(model, pad, no_segmentation, bidi_reordering, tags_ignore, input,
               help='On compatible devices, uses autocast for `segment` which lower the memory usage.')
 @click.option('--threads', default=1, show_default=True, type=click.IntRange(1),
               help='Size of thread pools for intra-op parallelization')
+@click.option('--no-legacy-polygons', 'no_legacy_polygons', is_flag=True, default=False,
+              help="Force disable legacy polygon extraction")
 def cli(input, batch_input, suffix, verbose, format_type, pdf_format,
-        serializer, template, device, raise_on_error, autocast, threads):
+        serializer, template,
device, raise_on_error, autocast, threads, no_legacy_polygons): """ Base command for recognition functionality. @@ -334,6 +338,8 @@ def cli(input, batch_input, suffix, verbose, format_type, pdf_format, ctx.meta['steps'] = [] ctx.meta["autocast"] = autocast ctx.meta['threads'] = threads + ctx.meta['no_legacy_polygons'] = no_legacy_polygons + log.set_logger(logger, level=30 - min(10 * verbose, 20)) diff --git a/kraken/lib/arrow_dataset.py b/kraken/lib/arrow_dataset.py index c9159a916..d3b4a9288 100755 --- a/kraken/lib/arrow_dataset.py +++ b/kraken/lib/arrow_dataset.py @@ -44,7 +44,7 @@ logger = logging.getLogger(__name__) -def _extract_line(xml_record, skip_empty_lines: bool = True): +def _extract_line(xml_record, skip_empty_lines: bool = True, legacy_polygons: bool = False): lines = [] try: im = Image.open(xml_record.imagename) @@ -62,7 +62,7 @@ def _extract_line(xml_record, skip_empty_lines: bool = True): script_detection=False, line_orders=[]) try: - line_im, line = next(extract_polygons(im, seg)) + line_im, line = next(extract_polygons(im, seg, legacy=legacy_polygons)) except KrakenInputException: logger.warning(f'Invalid line {idx} in {im.filename}') continue @@ -113,7 +113,8 @@ def build_binary_dataset(files: Optional[List[Union[str, 'PathLike', Dict]]] = N force_type: Optional[str] = None, recordbatch_size: int = 100, skip_empty_lines: bool = True, - callback: Callable[[int, int], None] = lambda chunk, lines: None) -> None: + callback: Callable[[int, int], None] = lambda chunk, lines: None, + legacy_polygons: bool = False) -> None: """ Parses XML files and dumps the baseline-style line images and text into a binary dataset. @@ -141,10 +142,11 @@ def build_binary_dataset(files: Optional[List[Union[str, 'PathLike', Dict]]] = N skip_empty_lines: Do not compile empty text lines into the dataset. callback: Function called every time a new recordbatch is flushed into the Arrow IPC file. + legacy_polygons: Use legacy polygon extraction code. 
""" logger.info('Parsing XML files') - extract_fn = partial(_extract_line, skip_empty_lines=skip_empty_lines) + extract_fn = partial(_extract_line, skip_empty_lines=skip_empty_lines, legacy_polygons=legacy_polygons) parse_fn = None if format_type in ['xml', 'alto', 'page']: parse_fn = XMLPage @@ -216,6 +218,7 @@ def build_binary_dataset(files: Optional[List[Union[str, 'PathLike', Dict]]] = N 'image_type': 'raw', 'splits': ['train', 'eval', 'test'], 'im_mode': '1', + 'legacy_polygons': legacy_polygons, 'counts': Counter({'all': 0, 'train': 0, 'validation': 0, @@ -309,6 +312,7 @@ def _make_record_batch(line_cache): f"image_type: {metadata['lines']['image_type']}\n" f"splits: {metadata['lines']['splits']}\n" f"im_mode: {metadata['lines']['im_mode']}\n" + f"legacy_polygons: {metadata['lines']['legacy_polygons']}\n" f"lines: {metadata['lines']['counts']}\n") with pa.memory_map(tmp_file, 'rb') as source: diff --git a/kraken/lib/dataset/recognition.py b/kraken/lib/dataset/recognition.py index a80f6cd89..fde975cc5 100644 --- a/kraken/lib/dataset/recognition.py +++ b/kraken/lib/dataset/recognition.py @@ -121,6 +121,7 @@ def __init__(self, self.arrow_table = None self.codec = None self.skip_empty_lines = skip_empty_lines + self.legacy_polygons_status = None self.seg_type = None # built text transformations @@ -174,6 +175,12 @@ def add(self, file: Union[str, 'PathLike']) -> None: if self.seg_type == 'bbox' and metadata['image_type'] == 'raw': self.transforms.valid_norm = True + legacy_polygons = metadata.get('legacy_polygons', True) + if self.legacy_polygons_status is None: + self.legacy_polygons_status = legacy_polygons + elif self.legacy_polygons_status != legacy_polygons: + self.legacy_polygons_status = "mixed" + self.alphabet.update(metadata['alphabet']) num_lines = metadata['counts'][self._split_filter] if self._split_filter else metadata['counts']['all'] if self._split_filter: @@ -284,7 +291,8 @@ def __init__(self, skip_empty_lines: bool = True, reorder: Union[bool, Literal['L', 'R']] = True, im_transforms: Callable[[Any], torch.Tensor] = transforms.Compose([]), - augmentation: bool = False) -> None: + augmentation: bool = False, + legacy_polygons: bool=False) -> None: """ Creates a dataset for a polygonal (baseline) transcription model. 
@@ -307,6 +315,7 @@ def __init__(self,
         self.aug = None
         self.skip_empty_lines = skip_empty_lines
         self.failed_samples = set()
+        self.legacy_polygons = legacy_polygons
         self.seg_type = 'baselines'
         # built text transformations
@@ -424,8 +433,8 @@ def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
                                                          boundary=item[0][2])],
                                      script_detection=True,
                                      regions={},
-                                     line_orders=[])
-                                     ))
+                                     line_orders=[]),
+                        legacy=self.legacy_polygons))
             im = self.transforms(im)
             if im.shape[0] == 3:
                 im_mode = 'RGB'
diff --git a/kraken/lib/pretrain/model.py b/kraken/lib/pretrain/model.py
index 4685e1851..68626cf49 100644
--- a/kraken/lib/pretrain/model.py
+++ b/kraken/lib/pretrain/model.py
@@ -32,6 +32,7 @@
 import math
 import re
 from itertools import chain
+from functools import partial
 from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Union
 
 import numpy as np
@@ -87,7 +88,8 @@ def __init__(self,
                  force_binarization: bool = False,
                  format_type: str = 'path',
                  pad: int = 16,
-                 augment: bool = default_specs.RECOGNITION_PRETRAIN_HYPER_PARAMS['augment']):
+                 augment: bool = default_specs.RECOGNITION_PRETRAIN_HYPER_PARAMS['augment'],
+                 legacy_polygons: bool = False):
         """
         A LightningDataModule encapsulating text-less training data for
         unsupervised recognition model pretraining.
@@ -106,6 +108,8 @@ def __init__(self,
         super().__init__()
         self.save_hyperparameters()
 
+        self.legacy_polygons = legacy_polygons
+
         DatasetClass = GroundTruthDataset
         valid_norm = True
         if format_type in ['xml', 'page', 'alto']:
@@ -117,7 +121,7 @@ def __init__(self,
             if binary_dataset_split:
                 logger.warning('Internal binary dataset splits are enabled but using non-binary dataset files. Will be ignored.')
                 binary_dataset_split = False
-            DatasetClass = PolygonGTDataset
+            DatasetClass = partial(PolygonGTDataset, legacy_polygons=legacy_polygons)
             valid_norm = False
         elif format_type == 'binary':
             DatasetClass = ArrowIPCRecognitionDataset
@@ -147,7 +151,7 @@ def __init__(self,
         # format_type is None. Determine training type from length of training data entry
         elif not format_type:
             if training_data[0].type == 'baselines':
-                DatasetClass = PolygonGTDataset
+                DatasetClass = partial(PolygonGTDataset, legacy_polygons=legacy_polygons)
                 valid_norm = False
             else:
                 if force_binarization:
@@ -205,6 +209,19 @@ def __init__(self,
                                'set. (Will disable alphabet mismatch detection.)')
             self.train_set, self.val_set = random_split(train_set, (train_len, val_len))
 
+        if format_type == 'binary':
+            legacy_train_status = train_set.legacy_polygons_status
+            if val_set and val_set.legacy_polygons_status != legacy_train_status:
+                logger.warning(
+                    f'Train and validation set have different legacy polygon status: {legacy_train_status} and {val_set.legacy_polygons_status}. '
+                    'Train set status prevails.')
+            if legacy_train_status == "mixed":
+                logger.warning('Mixed legacy polygon status in training dataset. Consider recompilation.')
+                legacy_train_status = False
+            if legacy_polygons != legacy_train_status:
+                logger.warning(f'Setting dataset legacy polygon status to {legacy_train_status} based on training set.')
+                self.legacy_polygons = legacy_train_status
+
         if len(self.train_set) == 0 or len(self.val_set) == 0:
             raise ValueError('No valid training data was provided to the train '
                              'command.
Please add valid XML, line, or binary data.') @@ -255,7 +272,8 @@ def __init__(self, spec: str = default_specs.RECOGNITION_SPEC, model: Optional[Union['PathLike', str]] = None, load_hyper_parameters: bool = False, - len_train_set: int = -1): + len_train_set: int = -1, + legacy_polygons: bool = False): """ A LightningModule encapsulating the unsupervised pretraining setup for a text recognition model. @@ -273,10 +291,15 @@ def __init__(self, """ super().__init__() hyper_params_ = default_specs.RECOGNITION_PRETRAIN_HYPER_PARAMS + self.legacy_polygons = legacy_polygons + if model: logger.info(f'Loading existing model from {model} ') self.nn = vgsl.TorchVGSLModel.load_model(model) + # apply legacy polygon parameter + self.nn.use_legacy_polygons = legacy_polygons + if self.nn.model_type not in [None, 'recognition']: raise ValueError(f'Model {model} is of type {self.nn.model_type} while `recognition` is expected.') @@ -430,6 +453,7 @@ def setup(self, stage: Optional[str] = None): else: logger.info(f'Creating new model {self.spec}') self.nn = vgsl.TorchVGSLModel(self.spec) + self.nn.use_legacy_polygons = self.legacy_polygons # initialize weights self.nn.init_weights() diff --git a/kraken/lib/segmentation.py b/kraken/lib/segmentation.py index 8c979da53..9fad17a60 100644 --- a/kraken/lib/segmentation.py +++ b/kraken/lib/segmentation.py @@ -18,15 +18,15 @@ import logging from collections import defaultdict from typing import (TYPE_CHECKING, Dict, List, Literal, Optional, Sequence, - Tuple, Union) + Tuple, Union, TypeVar, Any, Generator) import numpy as np import shapely.geometry as geom import torch import torch.nn.functional as F -from PIL import Image +from PIL import Image, ImageDraw from scipy.ndimage import (binary_erosion, distance_transform_cdt, - gaussian_filter, maximum_filter) + gaussian_filter, maximum_filter, affine_transform) from scipy.signal import convolve2d from scipy.spatial.distance import pdist, squareform from shapely.ops import nearest_points, unary_union @@ -38,15 +38,18 @@ subdivide_polygon) from skimage.morphology import skeletonize from skimage.transform import (AffineTransform, PiecewiseAffineTransform, - SimilarityTransform, warp) + warp) from kraken.lib import default_specs from kraken.lib.exceptions import KrakenInputException if TYPE_CHECKING: - from kraken.containers import Segmentation + from kraken.containers import Segmentation, BBoxLine, BaselineLine from kraken.lib.vgsl import TorchVGSLModel + +_T_pil_or_np = TypeVar('_T_pil_or_np', Image.Image, np.ndarray) + logger = logging.getLogger('kraken') __all__ = ['reading_order', @@ -356,8 +359,6 @@ def vectorize_regions(im: np.ndarray, threshold: float = 0.5): labelled = label(bin) boundaries = [] for x in regionprops(labelled): - if x.area < 32: - continue boundary = boundary_tracing(x) if len(boundary) > 2: boundaries.append(geom.Polygon(boundary)) @@ -371,19 +372,32 @@ def vectorize_regions(im: np.ndarray, threshold: float = 0.5): return [np.array(x.coords, dtype=np.uint)[:, [1, 0]].tolist() for x in boundaries] -def _rotate(image, angle, center, scale, cval=0): +def _rotate(image: _T_pil_or_np, + angle: float, + center: Any, + scale: float, + cval: int = 0, + order: int = 0) -> Tuple[AffineTransform, _T_pil_or_np]: """ - Rotate function taken mostly from scikit image. Main difference is that - this one allows dimensional scaling and records the final translation - to ensure no image content is lost. This is needed to rotate the seam - back into the original image. 
+ Rotate an image at an angle with optional scaling + Args: + image (PIL.Image.Image or (H, W, C) np.ndarray): Input image + angle (float): Angle in radians + center (tuple): unused + scale (float): x-Axis scaling factor + cval (int): Padding value + order (int): Interpolation order + Returns: + A tuple containing the transformation matrix and the rotated image. + Note: this function is much faster applied on PIL images than on numpy ndarrays. """ - rows, cols = image.shape[0], image.shape[1] - tform1 = SimilarityTransform(translation=center) - tform2 = SimilarityTransform(rotation=angle) - tform3 = SimilarityTransform(translation=-center) - tform4 = AffineTransform(scale=(1/scale, 1)) - tform = tform4 + tform3 + tform2 + tform1 + if isinstance(image, Image.Image): + rows, cols = image.height, image.width + else: + rows, cols = image.shape[:2] + assert len(image.shape) == 3 or len(image.shape) == 2, 'Image must be 2D or 3D' + + tform = AffineTransform(rotation=angle, scale=(1/scale, 1)) corners = np.array([ [0, 0], [0, rows - 1], @@ -397,13 +411,25 @@ def _rotate(image, angle, center, scale, cval=0): maxr = corners[:, 1].max() out_rows = maxr - minr + 1 out_cols = maxc - minc + 1 - output_shape = np.around((out_rows, out_cols)) + output_shape = tuple(int(o) for o in np.around((out_rows, out_cols))) # fit output image in new shape - translation = (minc, minr) - tform5 = SimilarityTransform(translation=translation) - tform = tform5 + tform - tform.params[2] = (0, 0, 1) - return tform, warp(image, tform, output_shape=output_shape, order=0, cval=cval, clip=False, preserve_range=True) + translation = tform([[minc, minr]]) + tform = AffineTransform(rotation=angle, scale=(1/scale, 1), translation=[f for f in translation.flatten()]) + + if isinstance(image, Image.Image): + # PIL is much faster than scipy + pdata = tform.params.flatten().tolist()[:6] + resample = {0: Image.NEAREST, 1: Image.BILINEAR, 2: Image.BICUBIC, 3: Image.BICUBIC}.get(order, Image.NEAREST) + return tform, image.transform(output_shape[::-1], Image.AFFINE, data=pdata, resample=resample, fillcolor=cval) + + # params for scipy + # swap X and Y axis for scipy + pdata = tform.params.copy()[[1, 0, 2], :][:, [1, 0, 2]] + # we copy the translation vector + offset = pdata[:2, 2].copy() + # scipy expects a 3x3 *linear* matrix (to include channel axis), we don't want the channel axis to be modified + pdata[:2, 2] = 0 + return tform, affine_transform(image, pdata, offset=(*offset, 0), output_shape=(*output_shape, *image.shape[2:]), cval=cval, order=order) def line_regions(line, regions): @@ -458,7 +484,6 @@ def _calc_seam(baseline, polygon, angle, im_feats, bias=150): level. 
""" MASK_VAL = 99999 - r, c = draw.polygon(polygon[:, 1], polygon[:, 0]) c_min, c_max = int(polygon[:, 0].min()), int(polygon[:, 0].max()) r_min, r_max = int(polygon[:, 1].min()), int(polygon[:, 1].max()) patch = im_feats[r_min:r_max+2, c_min:c_max+2].copy() @@ -472,8 +497,7 @@ def _calc_seam(baseline, polygon, angle, im_feats, bias=150): mask[line_locs] = 0 dist_bias = distance_transform_cdt(mask) # absolute mask - mask = np.ones_like(patch, dtype=bool) - mask[r-r_min, c-c_min] = False + mask = np.array(make_polygonal_mask(polygon-(r_min, c_min)), patch.shape[1::-1]) > 128 # dilate mask to compensate for aliasing during rotation mask = binary_erosion(mask, border_value=True, iterations=2) # combine weights with features @@ -1025,7 +1049,99 @@ def compute_polygon_section(baseline: Sequence[Tuple[int, int]], return tuple(o) -def extract_polygons(im: Image.Image, bounds: 'Segmentation') -> Image.Image: +def _bevelled_warping_envelope(baseline: np.ndarray, + output_bl_start: Tuple[float, float], + output_shape: Tuple[int, int]) -> Tuple[List[Tuple[int, int]], List[Tuple[int, int]]]: + """ + Calculates the source and target envelope for a piecewise affine transform + """ + def _as_int_tuple(x): + return tuple(int(i) for i in x) + + envelope_dy = [-output_bl_start[1], output_shape[0] - output_bl_start[1]] + diff_bl = np.diff(baseline, axis=0) + diff_bl_normed = diff_bl / np.linalg.norm(diff_bl, axis=1)[:, None] + l_bl = len(baseline) + cum_lens = np.cumsum([0] + np.linalg.norm(diff_bl, axis=1).tolist()) + + bl_seg_normals = np.array([-diff_bl_normed[:, 1], diff_bl_normed[:, 0]]).T + ini_point = baseline[0] - diff_bl_normed[0] * output_bl_start[0] + source_envelope = [ + _as_int_tuple(ini_point + envelope_dy[0]*bl_seg_normals[0]), + _as_int_tuple(ini_point + envelope_dy[1]*bl_seg_normals[0]), + ] + target_envelope = [ + (0, 0), + (0, output_shape[0]) + ] + MAX_BEVEL_WIDTH = output_shape[0] / 3 + BEVEL_STEP_WIDTH = MAX_BEVEL_WIDTH / 2 + + for k in range(l_bl-2): + pt = baseline[k+1] + seg_prev = baseline[k] - pt + seg_next = baseline[k+2] - pt + bevel_prev = seg_prev / max(2., np.linalg.norm(seg_prev) / MAX_BEVEL_WIDTH) + bevel_next = seg_next / max(2., np.linalg.norm(seg_next) / MAX_BEVEL_WIDTH) + bevel_nsteps = max(1, np.round((np.linalg.norm(bevel_prev) + np.linalg.norm(bevel_next)) / BEVEL_STEP_WIDTH)) + l_prev = np.linalg.norm(bevel_prev) + l_next = np.linalg.norm(bevel_next) + for i in range(int(bevel_nsteps)+1): + # bezier interp + t = i / bevel_nsteps + tpt = pt + (1-t)**2 * bevel_prev + t**2 * bevel_next + tx = output_bl_start[0] + cum_lens[k+1] - (1-t)**2 * l_prev + t**2 * l_next + tnormal = (1-t) * bl_seg_normals[k] + t * bl_seg_normals[k+1] + tnormal /= np.linalg.norm(tnormal) + source_points = [_as_int_tuple(tpt + envelope_dy[0]*tnormal), _as_int_tuple(tpt + envelope_dy[1]*tnormal)] + target_points = [(int(tx), 0), (int(tx), output_shape[0])] + # avoid duplicate points leading to singularities + if source_points[0] == source_envelope[-2] or source_points[1] == source_envelope[-1] or target_points[0] == target_envelope[-2]: + continue + source_envelope += source_points + target_envelope += target_points + + end_point = baseline[-1] + diff_bl_normed[-1]*(output_shape[1]-cum_lens[-1]-output_bl_start[0]) + source_envelope += [ + end_point + envelope_dy[0]*bl_seg_normals[-1], + end_point + envelope_dy[1]*bl_seg_normals[-1], + ] + target_envelope += [ + (output_shape[1], 0), + (output_shape[1], output_shape[0]) + ] + return source_envelope, target_envelope + + +def 
+def make_polygonal_mask(polygon: np.ndarray, shape: Tuple[int, int]) -> Image.Image:
+    """
+    Creates a mask from a polygon.
+
+    Args:
+        polygon: A polygon as a list of points.
+        shape: The shape of the mask to create.
+
+    Returns:
+        A PIL.Image.Image instance containing the mask.
+    """
+    mask = Image.new('L', shape, 0)
+    ImageDraw.Draw(mask).polygon([tuple(p) for p in polygon.astype(int).tolist()], fill=255, width=2)
+    return mask
+
+
+def apply_polygonal_mask(img: Image.Image, polygon: np.ndarray, cval: int = 0) -> Image.Image:
+    """
+    Extract the polygonal mask of an image.
+    """
+    mask = make_polygonal_mask(polygon, img.size)
+    out = Image.new(img.mode, (img.width, img.height), cval)
+    out.paste(img, mask=mask)
+    return out
+
+
+def extract_polygons(im: Image.Image,
+                     bounds: "Segmentation",
+                     legacy: bool = False) -> Generator[Tuple[Image.Image, Union["BBoxLine", "BaselineLine"]], None, None]:
     """
     Yields the subimages of image im defined in the list of bounding polygons
     with baselines preserving order.
@@ -1034,9 +1150,10 @@ def extract_polygons(im: Image.Image, bounds: 'Segmentation') -> Image.Image:
         im: Input image
         bounds: A Segmentation class containing a bounding box or baseline
                 segmentation.
+        legacy: Use the old, slow, and deprecated path
 
     Yields:
-        The extracted subimage
+        The extracted subimage, and the corresponding bounding box or baseline
     """
     if bounds.type == 'baselines':
         # select proper interpolation scheme depending on shape
@@ -1045,7 +1162,6 @@
             im = im.convert('L')
         else:
             order = 1
-        im = np.array(im)
 
         for line in bounds.lines:
             if line.boundary is None:
@@ -1055,85 +1171,170 @@
             c_min, c_max = int(pl[:, 0].min()), int(pl[:, 0].max())
             r_min, r_max = int(pl[:, 1].min()), int(pl[:, 1].max())
 
-            if (pl < 0).any() or (pl.max(axis=0)[::-1] >= im.shape[:2]).any():
+            # the legacy path below rebinds `im` to an array on its first
+            # iteration, so derive the shape from whichever type is held
+            imshape = np.array(im.shape[:2]) if isinstance(im, np.ndarray) else np.array([im.height, im.width])
+
+            if (pl < 0).any() or (pl.max(axis=0)[::-1] >= imshape).any():
                 raise KrakenInputException('Line polygon outside of image bounds')
-            if (baseline < 0).any() or (baseline.max(axis=0)[::-1] >= im.shape[:2]).any():
+            if (baseline < 0).any() or (baseline.max(axis=0)[::-1] >= imshape).any():
                 raise KrakenInputException('Baseline outside of image bounds')
 
-            # fast path for straight baselines requiring only rotation
-            if len(baseline) == 2:
-                baseline = baseline.astype(float)
-                # calculate direction vector
-                lengths = np.linalg.norm(np.diff(baseline.T), axis=0)
-                p_dir = np.mean(np.diff(baseline.T) * lengths/lengths.sum(), axis=1)
-                p_dir = (p_dir.T / np.sqrt(np.sum(p_dir**2, axis=-1)))
-                angle = np.arctan2(p_dir[1], p_dir[0])
-                patch = im[r_min:r_max+1, c_min:c_max+1].copy()
-                offset_polygon = pl - (c_min, r_min)
-                r, c = draw.polygon(offset_polygon[:, 1], offset_polygon[:, 0])
-                mask = np.zeros(patch.shape[:2], dtype=bool)
-                mask[r, c] = True
-                patch[np.invert(mask)] = 0
-                extrema = offset_polygon[(0, -1), :]
-                # scale line image to max 600 pixel width
-                tform, rotated_patch = _rotate(patch, angle, center=extrema[0], scale=1.0, cval=0)
-                i = Image.fromarray(rotated_patch.astype('uint8'))
-            # normal slow path with piecewise affine transformation
-            else:
-                if len(pl) > 50:
-                    pl = approximate_polygon(pl, 2)
-                full_polygon = subdivide_polygon(pl, preserve_ends=True)
-                pl = geom.MultiPoint(full_polygon)
-
-                bl = zip(baseline[:-1:], baseline[1::])
-                bl = [geom.LineString(x) for x in bl]
-                cum_lens = np.cumsum([0] + [line.length for line
in bl]) - # distance of intercept from start point and number of line segment - control_pts = [] - for point in pl.geoms: - npoint = np.array(point.coords)[0] - line_idx, dist, intercept = min(((idx, line.project(point), - np.array(line.interpolate(line.project(point)).coords)) for idx, line in enumerate(bl)), - key=lambda x: np.linalg.norm(npoint-x[2])) - # absolute distance from start of line - line_dist = cum_lens[line_idx] + dist - intercept = np.array(intercept) - # side of line the point is at - side = np.linalg.det(np.array([[baseline[line_idx+1][0]-baseline[line_idx][0], - npoint[0]-baseline[line_idx][0]], - [baseline[line_idx+1][1]-baseline[line_idx][1], - npoint[1]-baseline[line_idx][1]]])) - side = np.sign(side) - # signed perpendicular distance from the rectified distance - per_dist = side * np.linalg.norm(npoint-intercept) - control_pts.append((line_dist, per_dist)) - # calculate baseline destination points - bl_dst_pts = baseline[0] + np.dstack((cum_lens, np.zeros_like(cum_lens)))[0] - # calculate bounding polygon destination points - pol_dst_pts = np.array([baseline[0] + (line_dist, per_dist) for line_dist, per_dist in control_pts]) - # extract bounding box patch - c_dst_min, c_dst_max = int(pol_dst_pts[:, 0].min()), int(pol_dst_pts[:, 0].max()) - r_dst_min, r_dst_max = int(pol_dst_pts[:, 1].min()), int(pol_dst_pts[:, 1].max()) - output_shape = np.around((r_dst_max - r_dst_min + 1, c_dst_max - c_dst_min + 1)) - patch = im[r_min:r_max+1, c_min:c_max+1].copy() - # offset src points by patch shape - offset_polygon = full_polygon - (c_min, r_min) - offset_baseline = baseline - (c_min, r_min) - # offset dst point by dst polygon shape - offset_bl_dst_pts = bl_dst_pts - (c_dst_min, r_dst_min) - offset_pol_dst_pts = pol_dst_pts - (c_dst_min, r_dst_min) - # mask out points outside bounding polygon - mask = np.zeros(patch.shape[:2], dtype=bool) - r, c = draw.polygon(offset_polygon[:, 1], offset_polygon[:, 0]) - mask[r, c] = True - patch[np.invert(mask)] = 0 - # estimate piecewise transform - src_points = np.concatenate((offset_baseline, offset_polygon)) - dst_points = np.concatenate((offset_bl_dst_pts, offset_pol_dst_pts)) - tform = PiecewiseAffineTransform() - tform.estimate(src_points, dst_points) - o = warp(patch, tform.inverse, output_shape=output_shape, preserve_range=True, order=order) - i = Image.fromarray(o.astype('uint8')) + if legacy: + im = np.array(im) + # Old, slow, and deprecated path + # fast path for straight baselines requiring only rotation + if len(baseline) == 2: + baseline = baseline.astype(float) + # calculate direction vector + lengths = np.linalg.norm(np.diff(baseline.T), axis=0) + p_dir = np.mean(np.diff(baseline.T) * lengths/lengths.sum(), axis=1) + p_dir = (p_dir.T / np.sqrt(np.sum(p_dir**2, axis=-1))) + angle = np.arctan2(p_dir[1], p_dir[0]) + patch = im[r_min:r_max+1, c_min:c_max+1].copy() + offset_polygon = pl - (c_min, r_min) + r, c = draw.polygon(offset_polygon[:, 1], offset_polygon[:, 0]) + mask = np.zeros(patch.shape[:2], dtype=bool) + mask[r, c] = True + patch[np.invert(mask)] = 0 + extrema = offset_polygon[(0, -1), :] + # scale line image to max 600 pixel width + tform, rotated_patch = _rotate(patch, angle, center=extrema[0], scale=1.0, cval=0) + i = Image.fromarray(rotated_patch.astype('uint8')) + # normal slow path with piecewise affine transformation + else: + if len(pl) > 50: + pl = approximate_polygon(pl, 2) + full_polygon = subdivide_polygon(pl, preserve_ends=True) + pl = geom.MultiPoint(full_polygon) + + bl = zip(baseline[:-1:], 
baseline[1::]) + bl = [geom.LineString(x) for x in bl] + cum_lens = np.cumsum([0] + [line.length for line in bl]) + # distance of intercept from start point and number of line segment + control_pts = [] + for point in pl.geoms: + npoint = np.array(point.coords)[0] + line_idx, dist, intercept = min(((idx, line.project(point), + np.array(line.interpolate(line.project(point)).coords)) for idx, line in enumerate(bl)), + key=lambda x: np.linalg.norm(npoint-x[2])) + # absolute distance from start of line + line_dist = cum_lens[line_idx] + dist + intercept = np.array(intercept) + # side of line the point is at + side = np.linalg.det(np.array([[baseline[line_idx+1][0]-baseline[line_idx][0], + npoint[0]-baseline[line_idx][0]], + [baseline[line_idx+1][1]-baseline[line_idx][1], + npoint[1]-baseline[line_idx][1]]])) + side = np.sign(side) + # signed perpendicular distance from the rectified distance + per_dist = side * np.linalg.norm(npoint-intercept) + control_pts.append((line_dist, per_dist)) + # calculate baseline destination points + bl_dst_pts = baseline[0] + np.dstack((cum_lens, np.zeros_like(cum_lens)))[0] + # calculate bounding polygon destination points + pol_dst_pts = np.array([baseline[0] + (line_dist, per_dist) for line_dist, per_dist in control_pts]) + # extract bounding box patch + c_dst_min, c_dst_max = int(pol_dst_pts[:, 0].min()), int(pol_dst_pts[:, 0].max()) + r_dst_min, r_dst_max = int(pol_dst_pts[:, 1].min()), int(pol_dst_pts[:, 1].max()) + output_shape = np.around((r_dst_max - r_dst_min + 1, c_dst_max - c_dst_min + 1)) + patch = im[r_min:r_max+1, c_min:c_max+1].copy() + # offset src points by patch shape + offset_polygon = full_polygon - (c_min, r_min) + offset_baseline = baseline - (c_min, r_min) + # offset dst point by dst polygon shape + offset_bl_dst_pts = bl_dst_pts - (c_dst_min, r_dst_min) + offset_pol_dst_pts = pol_dst_pts - (c_dst_min, r_dst_min) + # mask out points outside bounding polygon + mask = np.zeros(patch.shape[:2], dtype=bool) + r, c = draw.polygon(offset_polygon[:, 1], offset_polygon[:, 0]) + mask[r, c] = True + patch[np.invert(mask)] = 0 + # estimate piecewise transform + src_points = np.concatenate((offset_baseline, offset_polygon)) + dst_points = np.concatenate((offset_bl_dst_pts, offset_pol_dst_pts)) + tform = PiecewiseAffineTransform() + tform.estimate(src_points, dst_points) + o = warp(patch, tform.inverse, output_shape=output_shape, preserve_range=True, order=order) + i = Image.fromarray(o.astype('uint8')) + + else: # if not legacy + # new, fast, and efficient path + # fast path for straight baselines requiring only rotation + if len(baseline) == 2: + baseline = baseline.astype(float) + # calculate direction vector + lengths = np.linalg.norm(np.diff(baseline.T), axis=0) + p_dir = np.mean(np.diff(baseline.T) * lengths/lengths.sum(), axis=1) + p_dir = (p_dir.T / np.sqrt(np.sum(p_dir**2, axis=-1))) + angle = np.arctan2(p_dir[1], p_dir[0]) + # crop out bounding box + patch = im.crop((c_min, r_min, c_max+1, r_max+1)) + offset_polygon = pl - (c_min, r_min) + patch = apply_polygonal_mask(patch, offset_polygon, cval=0) + extrema = offset_polygon[(0, -1), :] + tform, i = _rotate(patch, angle, center=extrema[0], scale=1.0, cval=0, order=order) + # normal slow path with piecewise affine transformation + else: + if len(pl) > 50: + pl = approximate_polygon(pl, 2) + full_polygon = subdivide_polygon(pl, preserve_ends=True) + + # baseline segment vectors + diff_bl = np.diff(baseline, axis=0) + diff_bl_norms = np.linalg.norm(diff_bl, axis=1) + diff_bl_normed = diff_bl / 
diff_bl_norms[:, None]
+
+                l_poly = len(full_polygon)
+                cum_lens = np.cumsum([0] + np.linalg.norm(diff_bl, axis=1).tolist())
+
+                # calculate baseline destination points
+                bl_dst_pts = baseline[0] + np.dstack((cum_lens, np.zeros_like(cum_lens)))[0]
+
+                # calculate bounding polygon destination points
+                # diff[k, p] = polygon[p] - baseline[k]
+                poly_bl_diff = full_polygon[None, :] - baseline[:-1, None]
+                # local x coordinates of polygon points on baseline segments
+                # x[k, p] = (polygon[p] - baseline[k]) . (baseline[k+1] - baseline[k]) / |baseline[k+1] - baseline[k]|
+                poly_bl_x = np.einsum('kpm,km->kp', poly_bl_diff, diff_bl_normed)
+                # distance to baseline segments
+                poly_bl_segdist = np.maximum(-poly_bl_x, poly_bl_x - diff_bl_norms[:, None])
+                # closest baseline segment index
+                poly_closest_bl = np.argmin(poly_bl_segdist, axis=0)
+                poly_bl_x = poly_bl_x[poly_closest_bl, np.arange(l_poly)]
+                poly_bl_diff = poly_bl_diff[poly_closest_bl, np.arange(l_poly)]
+                # signed distance between polygon points and baseline segments (to get y coordinates)
+                poly_bl_y = np.cross(diff_bl_normed[poly_closest_bl], poly_bl_diff)
+                # final destination points
+                pol_dst_pts = np.array(
+                    [cum_lens[poly_closest_bl] + poly_bl_x, poly_bl_y]
+                ).T + baseline[:1]
+
+                # extract bounding box patch
+                c_dst_min, c_dst_max = int(pol_dst_pts[:, 0].min()), int(pol_dst_pts[:, 0].max())
+                r_dst_min, r_dst_max = int(pol_dst_pts[:, 1].min()), int(pol_dst_pts[:, 1].max())
+                output_shape = np.around((r_dst_max - r_dst_min + 1, c_dst_max - c_dst_min + 1))
+                patch = im.crop((c_min, r_min, c_max+1, r_max+1))
+                # offset src points by patch shape
+                offset_polygon = full_polygon - (c_min, r_min)
+                offset_baseline = baseline - (c_min, r_min)
+                # offset dst point by dst polygon shape
+                offset_bl_dst_pts = bl_dst_pts - (c_dst_min, r_dst_min)
+                # mask out points outside bounding polygon
+                patch = apply_polygonal_mask(patch, offset_polygon, cval=0)
+
+                # estimate piecewise transform by beveling angles
+                source_envelope, target_envelope = _bevelled_warping_envelope(offset_baseline, offset_bl_dst_pts[0], output_shape)
+                # mesh for PIL, as (box, quad) tuples: box is (NW, SE) and quad is (NW, SW, SE, NE)
+                deform_mesh = [
+                    (
+                        (*target_envelope[i], *target_envelope[i+3]),
+                        (*source_envelope[i], *source_envelope[i+1], *source_envelope[i+3], *source_envelope[i+2])
+                    )
+                    for i in range(0, len(source_envelope)-3, 2)
+                ]
+                # warp
+                resample = {0: Image.NEAREST, 1: Image.BILINEAR, 2: Image.BICUBIC, 3: Image.BICUBIC}.get(order, Image.NEAREST)
+                i = patch.transform((output_shape[1], output_shape[0]), Image.MESH, data=deform_mesh, resample=resample)
+
+            yield i.crop(i.getbbox()), line
     else:
         if bounds.text_direction.startswith('vertical'):
diff --git a/kraken/lib/train.py b/kraken/lib/train.py
index 1ba45eb00..da1f4ff79 100644
--- a/kraken/lib/train.py
+++ b/kraken/lib/train.py
@@ -20,6 +20,7 @@
 import warnings
 from typing import (TYPE_CHECKING, Any, Callable, Dict, Literal, Optional,
                     Sequence, Union)
+from functools import partial
 
 import numpy as np
 import pytorch_lightning as pl
@@ -216,7 +217,8 @@ def __init__(self,
                  force_binarization: bool = False,
                  format_type: Literal['path', 'alto', 'page', 'xml', 'binary'] = 'path',
                  codec: Optional[Dict] = None,
-                 resize: Literal['fail', 'both', 'new', 'add', 'union'] = 'fail'):
+                 resize: Literal['fail', 'both', 'new', 'add', 'union'] = 'fail',
+                 legacy_polygons: bool = False):
         """
         A LightningModule encapsulating the training setup for a text
         recognition model.
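
For readers unfamiliar with the PIL mesh transform driving the new `extract_polygons()` path above: `Image.transform(..., Image.MESH, ...)` consumes a list of (box, quad) pairs, where the box is a destination rectangle `(left, top, right, bottom)` and the quad lists the source corners as eight coordinates in NW, SW, SE, NE order. A minimal stand-alone demonstration; the image size and the identity quads are arbitrary and not taken from the patch:

```python
# Minimal demonstration of the (box, quad) mesh format consumed by
# Image.transform with Image.MESH: each destination box is filled by
# warping the source quadrilateral given as NW, SW, SE, NE corners.
from PIL import Image

src = Image.new('L', (100, 50), 128)
mesh = [
    ((0, 0, 50, 50),                 # left half of the destination
     (0, 0, 0, 50, 50, 50, 50, 0)),  # identity quad over the source
    ((50, 0, 100, 50),               # right half of the destination
     (50, 0, 50, 50, 100, 50, 100, 0)),
]
out = src.transform((100, 50), Image.MESH, data=mesh, resample=Image.BILINEAR)
```

The `deform_mesh` comprehension above builds exactly such pairs from consecutive envelope points, which is why the envelope lists interleave top and bottom points two at a time.
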
@@ -233,6 +235,7 @@ def __init__(self,
             **kwargs: Setup parameters, i.e. CLI parameters of the train() command.
         """
         super().__init__()
+        self.legacy_polygons = legacy_polygons
         hyper_params_ = default_specs.RECOGNITION_HYPER_PARAMS.copy()
         if model:
             logger.info(f'Loading existing model from {model} ')
@@ -284,7 +287,7 @@
             if binary_dataset_split:
                 logger.warning('Internal binary dataset splits are enabled but using non-binary dataset files. Will be ignored.')
                 binary_dataset_split = False
-            DatasetClass = PolygonGTDataset
+            DatasetClass = partial(PolygonGTDataset, legacy_polygons=legacy_polygons)
             valid_norm = False
         elif format_type == 'binary':
             DatasetClass = ArrowIPCRecognitionDataset
@@ -314,7 +317,7 @@
         # format_type is None. Determine training type from container class types
         elif not format_type:
             if training_data[0].type == 'baselines':
-                DatasetClass = PolygonGTDataset
+                DatasetClass = partial(PolygonGTDataset, legacy_polygons=legacy_polygons)
                 valid_norm = False
             else:
                 if force_binarization:
@@ -375,6 +378,7 @@
             logger.debug('Setting multiprocessing tensor sharing strategy to file_system')
             torch.multiprocessing.set_sharing_strategy('file_system')
 
+        val_set = None
         if evaluation_data:
             train_set = self._build_dataset(DatasetClass, training_data)
             self.train_set = Subset(train_set, range(len(train_set)))
@@ -399,6 +403,19 @@
             raise ValueError('No valid training data was provided to the train '
                              'command. Please add valid XML, line, or binary data.')
 
+        if format_type == 'binary':
+            legacy_train_status = train_set.legacy_polygons_status
+            if val_set and val_set.legacy_polygons_status != legacy_train_status:
+                logger.warning(
+                    f'Train and validation set have different legacy polygon status: {legacy_train_status} and {val_set.legacy_polygons_status}. '
+                    'Train set status prevails.')
+            if legacy_train_status == "mixed":
+                logger.warning('Mixed legacy polygon status in training dataset.
Consider recompilation.') + legacy_train_status = False + if legacy_polygons != legacy_train_status: + logger.warning(f'Setting dataset legacy polygon status to {legacy_train_status} based on training set.') + self.legacy_polygons = legacy_train_status + logger.info(f'Training set {len(self.train_set)} lines, validation set ' f'{len(self.val_set)} lines, alphabet {len(train_set.alphabet)} ' 'symbols') @@ -592,6 +609,7 @@ def setup(self, stage: Optional[str] = None): logger.info(f'Creating new model {self.spec} with {self.train_set.dataset.codec.max_label+1} outputs') self.spec = '[{} O1c{}]'.format(self.spec[1:-1], self.train_set.dataset.codec.max_label + 1) self.nn = vgsl.TorchVGSLModel(self.spec) + self.nn.use_legacy_polygons = self.legacy_polygons # initialize weights self.nn.init_weights() self.nn.add_codec(self.train_set.dataset.codec) diff --git a/kraken/lib/vgsl.py b/kraken/lib/vgsl.py index 9d1cfc6cc..93a7d0cb3 100644 --- a/kraken/lib/vgsl.py +++ b/kraken/lib/vgsl.py @@ -140,7 +140,8 @@ def __init__(self, spec: str) -> None: 'seg_type': None, 'one_channel_mode': None, 'model_type': None, - 'hyper_params': {}} + 'hyper_params': {}, + 'legacy_polygons': False} # enable new polygons by default on new models self._aux_layers = nn.ModuleDict() self.idx = -1 @@ -311,7 +312,8 @@ def _deserialize_layers(name, layer): 'seg_type': 'bbox', 'one_channel_mode': '1', 'model_type': None, - 'hyper_params': {}} + 'hyper_params': {}, + 'legacy_polygons': True} # disable new polygons by default on load if 'kraken_meta' in mlmodel.user_defined_metadata: nn.user_metadata.update(json.loads(mlmodel.user_defined_metadata['kraken_meta'])) @@ -363,6 +365,14 @@ def aux_layers(self, **kwargs): def aux_layers(self, val: Dict[str, torch.nn.Module]): self._aux_layers.update(val) + @property + def use_legacy_polygons(self): + return self.user_metadata.get('legacy_polygons', True) + + @use_legacy_polygons.setter + def use_legacy_polygons(self, val: bool): + self.user_metadata['legacy_polygons'] = val + def save_model(self, path: str): """ Serializes the model into path. diff --git a/kraken/rpred.py b/kraken/rpred.py index 96dff4fdd..d41a11f67 100644 --- a/kraken/rpred.py +++ b/kraken/rpred.py @@ -24,6 +24,7 @@ from functools import partial from typing import (TYPE_CHECKING, Dict, Generator, List, Optional, Sequence, Tuple, Union) +import warnings from kraken.containers import BaselineOCRRecord, BBoxOCRRecord, ocr_record from kraken.lib.dataset import ImageInputTransforms @@ -52,7 +53,8 @@ def __init__(self, bounds: 'Segmentation', pad: int = 16, bidi_reordering: Union[bool, str] = True, - tags_ignore: Optional[List[Tuple[str, str]]] = None) -> Generator[ocr_record, None, None]: + tags_ignore: Optional[List[Tuple[str, str]]] = None, + no_legacy_polygons: bool = False) -> Generator[ocr_record, None, None]: """ Multi-model version of kraken.rpred.rpred. 
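
The `use_legacy_polygons` property added to `TorchVGSLModel` above defaults to `True` on deserialized models (so pre-existing models keep the extractor they were trained with) and to `False` on freshly constructed ones. A hedged usage sketch of the round trip, built only from the API in this diff; the file names are placeholders:

```python
# Hypothetical round-trip of the new polygon-extractor flag; paths are
# placeholders for real model files.
from kraken.lib import vgsl

net = vgsl.TorchVGSLModel.load_model('model.mlmodel')
# Models serialized before this change lack the metadata key and report True.
print(net.use_legacy_polygons)

# Opt the model in to the new extractor; the setter writes user_metadata,
# so the choice survives serialization.
net.use_legacy_polygons = False
net.save_model('model_newpoly.mlmodel')
```

Because the flag lives in `user_metadata`, it travels inside the existing `kraken_meta` blob of the CoreML file and needs no format change.
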
@@ -159,6 +161,7 @@ def __init__(self,
         self.pad = pad
         self.bounds = bounds
         self.tags_ignore = tags_ignore
+        self.no_legacy_polygons = no_legacy_polygons
 
     def _recognize_box_line(self, line):
         xmin, ymin, xmax, ymax = line.bbox
@@ -175,8 +178,10 @@ def _recognize_box_line(self, line):
 
         tag, net = self._resolve_tags_to_model(line.tags, self.nets)
 
+        use_legacy_polygons = self._choose_legacy_polygon_extractor(net)
+
         seg = dataclasses.replace(self.bounds, lines=[line])
-        box, coords = next(extract_polygons(self.im, seg))
+        box, coords = next(extract_polygons(self.im, seg, legacy=use_legacy_polygons))
         self.box = box
 
         # check if boxes are non-zero in any dimension
@@ -242,14 +247,17 @@ def _recognize_baseline_line(self, line):
 
         seg = dataclasses.replace(self.bounds, lines=[line])
 
+        tag, net = self._resolve_tags_to_model(line.tags, self.nets)
+
+        use_legacy_polygons = self._choose_legacy_polygon_extractor(net)
+
         try:
-            box, coords = next(extract_polygons(self.im, seg))
+            box, coords = next(extract_polygons(self.im, seg, legacy=use_legacy_polygons))
         except KrakenInputException as e:
             logger.warning(f'Extracting line failed: {e}')
             return BaselineOCRRecord('', [], [], line)
 
         self.box = box
-        tag, net = self._resolve_tags_to_model(line.tags, self.nets)
         # check if boxes are non-zero in any dimension
         if 0 in box.size:
             logger.warning(f'{line} with zero dimension. Emitting empty record.')
@@ -299,13 +307,26 @@ def __len__(self):
 
     def _scale_val(self, val, min_val, max_val):
         return int(round(min(max(((val*self.net_scale)-self.pad)*self.in_scale, min_val), max_val-1)))
 
+    def _choose_legacy_polygon_extractor(self, net) -> bool:
+        # grouping the checks here to display warnings only once
+        if net.nn.use_legacy_polygons:
+            if self.no_legacy_polygons:
+                warnings.warn('Enforcing use of the new polygon extractor for models trained with the old version. Accuracy may be affected.')
+                return False
+            else:
+                warnings.warn('Using legacy polygon extractor, as the model was not trained with the new method. Please retrain your model to get speed improvements.')
+                return True
+        return False
+
 
 def rpred(network: 'TorchSeqRecognizer',
           im: 'Image.Image',
           bounds: 'Segmentation',
           pad: int = 16,
-          bidi_reordering: Union[bool, str] = True) -> Generator[ocr_record, None, None]:
+          bidi_reordering: Union[bool, str] = True,
+          no_legacy_polygons: bool = False) -> Generator[ocr_record, None, None]:
     """
     Uses a TorchSeqRecognizer and a segmentation to recognize text
 
@@ -325,7 +346,7 @@ def rpred(network: 'TorchSeqRecognizer',
         An ocr_record containing the recognized text, absolute character
         positions, and confidence values for each character.
""" - return mm_rpred(defaultdict(lambda: network), im, bounds, pad, bidi_reordering) + return mm_rpred(defaultdict(lambda: network), im, bounds, pad, bidi_reordering, no_legacy_polygons=no_legacy_polygons) def _resolve_tags_to_model(tags: Optional[Sequence[Dict[str, str]]], diff --git a/tests/resources/170025120000003,0074-lite.xml b/tests/resources/170025120000003,0074-lite.xml new file mode 100644 index 000000000..504794e0f --- /dev/null +++ b/tests/resources/170025120000003,0074-lite.xml @@ -0,0 +1,89 @@ + + + + TRP + 2016-06-16T16:57:15.027+02:00 + 2018-07-04T17:25:44.389+02:00 + + + + + + + + + + + + + + + + + $pag:39 + + + + $pag:39 + + + + + + + + + $-nor su hijo, De todos sus bienes, con los pactos + + + + + + + y salvedades alli expressadas; Y fue acetada; + + + + + + + y assi mismo el dho$.dicho $ofi:Patron $ant:Miguel $ant:Carreras, + + + + + + + y el $ofi:Rndo$.Reverendo $ant:Miguel $ant:Carreras $ofi:pbro$.presbítero residente en la + + + + $-nor su hijo, De todos sus bienes, con los pactos +y salvedades alli expressadas; Y fue acetada; +y assi mismo el dho$.dicho $ofi:Patron $ant:Miguel $ant:Carreras, +y el $ofi:Rndo$.Reverendo $ant:Miguel $ant:Carreras $ofi:pbro$.presbítero residente en la +Parq.$^l$.Parroquial Igla$.Iglesia de dha$.dicha villa de $top:Canet Padre é, hijo +hizieron donacion a la dha$.dicha $ant:Anna $ant:Maria su +hija y hermana resp.$^e$.respectivamente por todos sus drôs$.derechos de le:$- +$-gitima Paterna, Materna y otros de ducientas$.doscientas +libras de moneda Bar$.barcelonesa; arca y vestidos corres:$- +$-pondientes, con promesa de pagar en esta +forma, ésto es arcas, ropas y joyas el dia de las +Bodas; cien libras del dia de la fecha, á, medio +año y las restantes cien libras del dho$.dicho dia de la +fecha á tres años prox.$^s$.proximos venturos bajo obli:$- +$-gacion de todos sus bienes; cuya dha$.dicha donacion +fue echa con el pacto revercional acos:$- +$-tumbrado; Y fue azetada por la dha$.dicha $ant:Anna +$ant:Maria por quien fue echa la diffinición cor:$- +$-respondiente de dhos$.dichos sus dros$.derechos a favor del dho$.dicho +su Padre y hermano resp.$^e$.respectivamente y salvose el de +futura sucession: Y en su consequencia hizo +la correspondiente constitucion dotal al +dho$.dicho $ant:Joseph $ant:Vancells su venidero esposo; y este +acetandola prometió en su caso restituir +bajo obligacion de todos sus bienes. 
+
+
+
+
diff --git a/tests/resources/overfit_newpoly.mlmodel b/tests/resources/overfit_newpoly.mlmodel
new file mode 100644
index 000000000..cb7b40aa9
Binary files /dev/null and b/tests/resources/overfit_newpoly.mlmodel differ
diff --git a/tests/test_newpolygons.py b/tests/test_newpolygons.py
new file mode 100644
index 000000000..ffa06d0fb
--- /dev/null
+++ b/tests/test_newpolygons.py
@@ -0,0 +1,449 @@
+# -*- coding: utf-8 -*-
+
+from contextlib import contextmanager
+import re
+import tempfile
+import unittest
+import warnings
+from pathlib import Path
+from traceback import print_exception
+from typing import List, Optional, Union
+from unittest.mock import Mock, patch
+
+from PIL import Image
+
+from click.testing import CliRunner
+
+from kraken.containers import (
+    BaselineLine,
+    BaselineOCRRecord,
+    BBoxLine,
+    BBoxOCRRecord,
+    Segmentation,
+)
+from kraken.lib import xml
+from kraken.lib import segmentation
+from kraken.lib.models import load_any
+from kraken.rpred import mm_rpred, rpred
+from kraken.kraken import cli as kraken_cli
+from kraken.ketos import cli as ketos_cli
+
+thisfile = Path(__file__).resolve().parent
+resources = thisfile / "resources"
+
+
+def mock_extract_polygons():
+    return Mock(side_effect=segmentation.extract_polygons)
+
+
+class TestNewPolygons(unittest.TestCase):
+    """
+    Tests for the new polygon extraction method.
+    """
+
+    def setUp(self):
+        self.im = Image.open(resources / "bw.png")
+        self.old_model_path = str(resources / "overfit.mlmodel")
+        self.old_model = load_any(self.old_model_path)
+        self.new_model_path = str(resources / "overfit_newpoly.mlmodel")
+        self.new_model = load_any(self.new_model_path)
+        self.segmented_img = str(resources / "170025120000003,0074-lite.xml")
+        self.runner = CliRunner()
+        self.color_img = resources / "input.tif"
+        self.arrow_data = str(resources / "merge_tests/base.arrow")
+        self.simple_bl_seg = Segmentation(
+            type="baselines",
+            imagename=resources / "bw.png",
+            lines=[
+                BaselineLine(
+                    id="foo",
+                    baseline=[[0, 10], [2543, 10]],
+                    boundary=[[0, 0], [2543, 0], [2543, 155], [0, 155]],
+                )
+            ],
+            text_direction="horizontal-lr",
+            script_detection=False,
+        )
+
+    ## RECIPES
+
+    @patch("kraken.rpred.extract_polygons", new_callable=mock_extract_polygons)
+    def _test_rpred(self, extractor_mock: Mock, *, model, force_no_legacy: bool = False, expect_legacy: bool):
+        """
+        Base recipe for testing rpred with a given model and polygon extraction method
+        """
+        pred = rpred(model, self.im, self.simple_bl_seg, True, no_legacy_polygons=force_no_legacy)
+        _ = next(pred)
+
+        extractor_mock.assert_called()
+        for cl in extractor_mock.mock_calls:
+            self.assertEqual(cl[2]["legacy"], expect_legacy)
+
+    @patch("kraken.rpred.extract_polygons", new_callable=mock_extract_polygons)
+    def _test_krakencli(self, extractor_mock: Mock, *, args, force_no_legacy: bool = False, expect_legacy: bool):
+        """
+        Base recipe for testing kraken_cli with a given polygon extraction method
+        """
+        if force_no_legacy:
+            args = ["--no-legacy-polygons"] + args
+
+        result = self.runner.invoke(kraken_cli, args)
+        print("kraken", *args)
+
+        if result.exception:
+            print_exception(result.exception)
+
+        self.assertEqual(result.exit_code, 0)
+        extractor_mock.assert_called()
+        for cl in extractor_mock.mock_calls:
+            self.assertEqual(cl[2]["legacy"], expect_legacy)
+
+    def _test_ketoscli(self, *, args, expect_legacy: bool, check_exit_code: Optional[Union[int, List[int]]] = 0, patching_dir="kraken.lib.dataset.recognition"):
+        """
+        Base recipe for testing ketos_cli with a given polygon extraction method
+        """
+        with patch(patching_dir + ".extract_polygons", new_callable=mock_extract_polygons) as extractor_mock:
+            result = self.runner.invoke(ketos_cli, args)
+
+        print("ketos", *args)
+        if result.exception:
+            print(result.output)
+            print_exception(result.exception)
+
+        if check_exit_code is not None:
+            if isinstance(check_exit_code, int):
+                check_exit_code = [check_exit_code]
+            self.assertIn(result.exit_code, check_exit_code, "Command failed")
+
+        extractor_mock.assert_called()
+        for cl in extractor_mock.mock_calls:
+            self.assertEqual(cl[2]["legacy"], expect_legacy)
+
+    ## TESTS
+
+    def test_rpred_from_old_model(self):
+        """
+        Test rpred with an old model, check that it uses the legacy polygon extraction method
+        """
+        self._test_rpred(model=self.old_model, force_no_legacy=False, expect_legacy=True)
+
+    def test_rpred_from_old_model_force_new(self):
+        """
+        Test rpred with an old model while disabling legacy polygons, check that it uses the new polygon extraction method
+        """
+        self._test_rpred(model=self.old_model, force_no_legacy=True, expect_legacy=False)
+
+    def test_rpred_from_new_model(self):
+        """
+        Test rpred with a new model, check that it uses the new polygon extraction method
+        """
+        self._test_rpred(model=self.new_model, force_no_legacy=False, expect_legacy=False)
+
+    def test_krakencli_ocr_old_model(self):
+        """
+        Test kraken_cli with an old model, check that it uses the legacy polygon extraction method
+        """
+        with tempfile.NamedTemporaryFile() as fp:
+            self._test_krakencli(
+                args=['-f', 'xml', '-i', self.segmented_img, fp.name, 'ocr', '-m', self.old_model_path],
+                force_no_legacy=False,
+                expect_legacy=True,
+            )
+
+    def test_krakencli_ocr_old_model_force_new(self):
+        """
+        Test kraken_cli with an old model and --no-legacy-polygons, check that it uses the new polygon extraction method
+        """
+        with tempfile.NamedTemporaryFile() as fp:
+            self._test_krakencli(
+                args=['-f', 'xml', '-i', self.segmented_img, fp.name, 'ocr', '-m', self.old_model_path],
+                force_no_legacy=True,
+                expect_legacy=False,
+            )
+
+    def test_krakencli_ocr_new_model(self):
+        """
+        Test kraken_cli with a new model, check that it uses the new polygon extraction method
+        """
+        with tempfile.NamedTemporaryFile() as fp:
+            self._test_krakencli(
+                args=['-f', 'xml', '-i', self.segmented_img, fp.name, 'ocr', '-m', self.new_model_path],
+                force_no_legacy=False,
+                expect_legacy=False,
+            )
+
+    def test_ketoscli_test_old_model(self):
+        """
+        Test `ketos test` with an old model, check that it uses the legacy polygon extraction method
+        """
+        self._test_ketoscli(
+            args=['test', '-m', self.old_model_path, '-f', 'xml', '--workers', '0', self.segmented_img],
+            expect_legacy=True,
+        )
+
+    def test_ketoscli_test_old_model_force_new(self):
+        """
+        Test `ketos test` with an old model and --no-legacy-polygons, check that it does not use the legacy polygon extraction method
+        """
+        self._test_ketoscli(
+            args=['test', '--no-legacy-polygons', '-m', self.old_model_path, '-f', 'xml', '--workers', '0', self.segmented_img],
+            expect_legacy=False,
+        )
+
+    def test_ketoscli_test_new_model(self):
+        """
+        Test `ketos test` with a new model, check that it uses the new polygon extraction method
+        """
+        self._test_ketoscli(
+            args=['test', '-m', self.new_model_path, '-f', 'xml', '--workers', '0', self.segmented_img],
+            expect_legacy=False,
+        )
+
+    def test_ketoscli_train_new_model(self):
+        """
+        Test `ketos train` training a new model, check that it uses the new polygon extraction method
+        """
+        with tempfile.TemporaryDirectory() as tempdir:
+            mfp = str(Path(tempdir) / "model")
+            fp = str(Path(tempdir) / "test.xml")
+
+            self._test_ketoscli(
+                args=['train', '-f', 'xml', '-N', '1', '-q', 'fixed', '-o', mfp, '--workers', '0', self.segmented_img],
+                expect_legacy=False,
+                check_exit_code=[0, 1],  # Model may not improve during training
+            )
+
+            self._test_krakencli(
+                args=['-f', 'xml', '-i', self.segmented_img, fp, 'ocr', '-m', mfp + "_0.mlmodel"],
+                expect_legacy=False,
+            )
+
+    def test_ketoscli_train_new_model_force_legacy(self):
+        """
+        Test `ketos train` training a new model, check that it uses the legacy polygon extraction method if forced
+        """
+        with tempfile.TemporaryDirectory() as tempdir:
+            mfp = str(Path(tempdir) / "model")
+            fp = str(Path(tempdir) / "test.xml")
+
+            self._test_ketoscli(
+                args=['train', '--legacy-polygons', '-f', 'xml', '-N', '1', '-q', 'fixed', '-o', mfp, '--workers', '0', self.segmented_img],
+                expect_legacy=True,
+                check_exit_code=[0, 1],  # Model may not improve during training
+            )
+
+            self._test_krakencli(
+                args=['-f', 'xml', '-i', self.segmented_img, fp, 'ocr', '-m', mfp + "_0.mlmodel"],
+                expect_legacy=True,
+            )
+
+    def test_ketoscli_train_old_model(self):
+        """
+        Test `ketos train` finetuning an old model, check that it uses the new polygon extraction method
+        """
+        with tempfile.TemporaryDirectory() as tempdir:
+            mfp = str(Path(tempdir) / "model")
+            fp = str(Path(tempdir) / "test.xml")
+
+            self._test_ketoscli(
+                args=['train', '-f', 'xml', '-N', '1', '-q', 'fixed', '-i', self.old_model_path, '--resize', 'add', '-o', mfp, '--workers', '0', self.segmented_img],
+                expect_legacy=False,
+                check_exit_code=[0, 1],  # Model may not improve during training
+            )
+            self._test_krakencli(
+                args=['-f', 'xml', '-i', self.segmented_img, fp, 'ocr', '-m', mfp + "_0.mlmodel"],
+                expect_legacy=False,
+            )
+
+    def test_ketoscli_train_old_model_force_legacy(self):
+        """
+        Test `ketos train` finetuning an old model, check that it uses the legacy polygon extraction method if forced
+        """
+        with tempfile.TemporaryDirectory() as tempdir:
+            mfp = str(Path(tempdir) / "model")
+            fp = str(Path(tempdir) / "test.xml")
+
+            self._test_ketoscli(
+                args=['train', '--legacy-polygons', '-f', 'xml', '-N', '1', '-q', 'fixed', '-i', self.old_model_path, '--resize', 'add', '-o', mfp, '--workers', '0', self.segmented_img],
+                expect_legacy=True,
+                check_exit_code=[0, 1],  # Model may not improve during training
+            )
+            self._test_krakencli(
+                args=['-f', 'xml', '-i', self.segmented_img, fp, 'ocr', '-m', mfp + "_0.mlmodel"],
+                expect_legacy=True,
+            )
+
+    @unittest.expectedFailure
+    def test_ketoscli_pretrain_new_model(self):
+        """
+        Test `ketos pretrain` training a new model, check that it uses the new polygon extraction method
+        """
+        with tempfile.TemporaryDirectory() as tempdir:
+            mfp = str(Path(tempdir) / "model")
+            fp = str(Path(tempdir) / "test.xml")
+
+            self._test_ketoscli(
+                args=['pretrain', '-f', 'xml', '-N', '1', '-q', 'fixed', '-o', mfp, '--workers', '0', self.segmented_img],
+                expect_legacy=False,
+                check_exit_code=[0, 1],  # Model may not improve during training
+            )
+            self._test_krakencli(
+                args=['-f', 'xml', '-i', self.segmented_img, fp, 'ocr', '-m', mfp + "_0.mlmodel"],
+                expect_legacy=False,
+            )
+
+    @unittest.expectedFailure
+    def test_ketoscli_pretrain_new_model_force_legacy(self):
+        """
+        Test `ketos pretrain` training a new model, check that it uses the legacy polygon extraction method if forced
+        """
+        with tempfile.TemporaryDirectory() as tempdir:
+            mfp = str(Path(tempdir) / "model")
+            fp = str(Path(tempdir) / "test.xml")
+
+            self._test_ketoscli(
+                args=['pretrain', '--legacy-polygons', '-f', 'xml', '-N', '1', '-q', 'fixed', '-o', mfp, '--workers', '0', self.segmented_img],
+                expect_legacy=True,
+                check_exit_code=[0, 1],  # Model may not improve during training
+            )
+
+            self._test_krakencli(
+                args=['-f', 'xml', '-i', self.segmented_img, fp, 'ocr', '-m', mfp + "_0.mlmodel"],
+                expect_legacy=True,
+            )
+
+    @unittest.expectedFailure
+    def test_ketoscli_pretrain_old_model(self):
+        """
+        Test `ketos pretrain` finetuning an old model, check that it uses the new polygon extraction method
+        """
+        with tempfile.TemporaryDirectory() as tempdir:
+            mfp = str(Path(tempdir) / "model")
+            fp = str(Path(tempdir) / "test.xml")
+
+            self._test_ketoscli(
+                args=['pretrain', '-f', 'xml', '-N', '1', '-q', 'fixed', '-i', self.old_model_path, '--resize', 'add', '-o', mfp, '--workers', '0', self.segmented_img],
+                expect_legacy=False,
+                check_exit_code=[0, 1],  # Model may not improve during training
+            )
+
+            self._test_krakencli(
+                args=['-f', 'xml', '-i', self.segmented_img, fp, 'ocr', '-m', mfp + "_0.mlmodel"],
+                expect_legacy=False,
+            )
+
+    def _assertWarnsWhenTrainingArrow(
+            self, model: str, *dset: str, from_model: Optional[str] = None, force_legacy: bool = False,
+            expect_warning_msgs: Optional[List[str]] = None, expect_not_warning_msgs: Optional[List[str]] = None):
+
+        args = ['-f', 'binary', '-N', '1', '-q', 'fixed', '-o', model, *dset]
+        if force_legacy:
+            args = ['--legacy-polygons'] + args
+        if from_model:
+            args = ['-i', from_model, '--resize', 'add'] + args
+
+        print("ketos", 'train', *args)
+        run = self.runner.invoke(ketos_cli, ['train'] + args)
+        output = re.sub(r'\w+\.py:\d+\n', '', run.output)
+        output = re.sub(r'\s+', ' ', output)
+        for warning_msg in expect_warning_msgs or []:
+            self.assertIn(warning_msg, output, f"Expected warning '{warning_msg}' not found in output")
+        for warning_msg in expect_not_warning_msgs or []:
+            self.assertNotIn(warning_msg, output, f"Unexpected warning '{warning_msg}' found in output")
+
+    def test_ketos_old_arrow_train_new(self):
+        """
+        Test `ketos train` on an old arrow dataset, check that it warns about the polygon extraction method only when the settings are inconsistent
+        """
+        with tempfile.TemporaryDirectory() as tempdir:
+            mfp = str(Path(tempdir) / "model")
+            mfp2 = str(Path(tempdir) / "model2")
+
+            self._assertWarnsWhenTrainingArrow(mfp, self.arrow_data, force_legacy=False, expect_warning_msgs=["WARNING Setting dataset legacy polygon status to True based on training set", "the new model will be flagged to use legacy"])
+            self._assertWarnsWhenTrainingArrow(mfp2, self.arrow_data, force_legacy=True, expect_not_warning_msgs=["WARNING Setting dataset legacy polygon status to True based on training set", "the new model will be flagged to use legacy"])
+
+    def test_ketos_new_arrow(self):
+        """
+        Test `ketos compile`, check that it uses the new polygon extraction method
+        """
+        with tempfile.TemporaryDirectory() as tempdir:
+            dset = str(Path(tempdir) / "dataset.arrow")
+            mfp = str(Path(tempdir) / "model")
+            mfp2 = str(Path(tempdir) / "model2")
+
+            self._test_ketoscli(
+                args=['compile', '-f', 'xml', '-o', dset, self.segmented_img],
+                expect_legacy=False,
+                patching_dir="kraken.lib.arrow_dataset",
+            )
+
+            self._assertWarnsWhenTrainingArrow(mfp, dset, force_legacy=False, expect_not_warning_msgs=["WARNING Setting dataset legacy polygon status to False based on training set", "the new model will be flagged to use legacy"])
+            self._assertWarnsWhenTrainingArrow(mfp2, dset, force_legacy=True, expect_warning_msgs=["WARNING Setting dataset legacy polygon status to False based on training set", "the new model will be flagged to use new"])
+
+    def test_ketos_new_arrow_force_legacy(self):
+        """
+        Test `ketos compile` with --legacy-polygons, check that it uses the old polygon extraction method
+        """
+        with tempfile.TemporaryDirectory() as tempdir:
+            dset = str(Path(tempdir) / "dataset.arrow")
+            mfp = str(Path(tempdir) / "model")
+            mfp2 = str(Path(tempdir) / "model2")
+
+            self._test_ketoscli(
+                args=['compile', '--legacy-polygons', '-f', 'xml', '-o', dset, self.segmented_img],
+                expect_legacy=True,
+                patching_dir="kraken.lib.arrow_dataset",
+            )
+
+            self._assertWarnsWhenTrainingArrow(mfp, dset, force_legacy=False, expect_warning_msgs=["WARNING Setting dataset legacy polygon status to True based on training set", "the new model will be flagged to use legacy"])
+            self._assertWarnsWhenTrainingArrow(mfp2, dset, force_legacy=True, expect_not_warning_msgs=["WARNING Setting dataset legacy polygon status to True based on training set", "the new model will be flagged to use legacy"])
+
+    def test_ketos_old_arrow_old_model(self):
+        """
+        Test `ketos train` finetuning an old model on an old arrow dataset, check that it warns about the polygon extraction method only when the settings are inconsistent
+        """
+        with tempfile.TemporaryDirectory() as tempdir:
+            mfp = str(Path(tempdir) / "model")
+            mfp2 = str(Path(tempdir) / "model2")
+
+            self._assertWarnsWhenTrainingArrow(mfp, self.arrow_data, from_model=self.old_model_path, force_legacy=False, expect_warning_msgs=["WARNING Setting dataset legacy polygon status to True based on training set"], expect_not_warning_msgs=["model will be flagged to use new"])
+            self._assertWarnsWhenTrainingArrow(mfp2, self.arrow_data, from_model=self.old_model_path, force_legacy=True, expect_not_warning_msgs=["WARNING Setting dataset legacy polygon status to True based on training set", "model will be flagged to use new"])
+
+    def test_ketos_new_arrow_old_model(self):
+        """
+        Test `ketos train` finetuning an old model on a new arrow dataset, check that it warns about the polygon extraction method only when the settings are inconsistent
+        """
+        with tempfile.TemporaryDirectory() as tempdir:
+            dset = str(Path(tempdir) / "dataset.arrow")
+            mfp = str(Path(tempdir) / "model")
+            mfp2 = str(Path(tempdir) / "model2")
+
+            self._test_ketoscli(
+                args=['compile', '-f', 'xml', '-o', dset, self.segmented_img],
+                expect_legacy=False,
+                patching_dir="kraken.lib.arrow_dataset",
+            )
+
+            self._assertWarnsWhenTrainingArrow(mfp, dset, from_model=self.old_model_path, force_legacy=False, expect_not_warning_msgs=["WARNING Setting dataset legacy polygon status to False based on training set"], expect_warning_msgs=["model will be flagged to use new"])
+            self._assertWarnsWhenTrainingArrow(mfp2, dset, from_model=self.old_model_path, force_legacy=True, expect_warning_msgs=["WARNING Setting dataset legacy polygon status to False based on training set"], expect_not_warning_msgs=["model will be flagged to use new"])
+
+    def test_ketos_mixed_arrow_train_new(self):
+        """
+        Test `ketos train` on a mixed arrow dataset, check that it warns about the polygon extraction method only when the settings are inconsistent
+        """
+        with tempfile.TemporaryDirectory() as tempdir:
+            dset = str(Path(tempdir) / "dataset.arrow")
+            mfp = str(Path(tempdir) / "model")
+
+            self._test_ketoscli(
+                args=['compile', '-f', 'xml', '-o', dset, self.segmented_img, self.arrow_data],
+                expect_legacy=False,
+                patching_dir="kraken.lib.arrow_dataset",
+            )
+
+            self._assertWarnsWhenTrainingArrow(mfp, dset, self.arrow_data, force_legacy=True, expect_warning_msgs=["WARNING Mixed legacy polygon", "WARNING Setting dataset legacy polygon status to False based on training set"], expect_not_warning_msgs=["model will be flagged to use legacy"])
\ No newline at end of file
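
For reviewers, a minimal usage sketch of the API surface this patch adds; it is not part of the diff. It reuses the bw.png/overfit.mlmodel fixtures and the one-line Segmentation from the test suite above, so the paths only exist inside a kraken checkout; the `rpred` keyword and the `net.nn.use_legacy_polygons` flag are the ones introduced here, while `record.prediction` is the standard ocr_record accessor.

from PIL import Image

from kraken.containers import BaselineLine, Segmentation
from kraken.lib.models import load_any
from kraken.rpred import rpred

# A model trained before this patch carries use_legacy_polygons=True
# in its metadata; models trained with the new extractor carry False.
net = load_any('tests/resources/overfit.mlmodel')
print(net.nn.use_legacy_polygons)

im = Image.open('tests/resources/bw.png')

# Single-line segmentation, mirroring the fixture in TestNewPolygons.setUp().
seg = Segmentation(type='baselines',
                   imagename='tests/resources/bw.png',
                   lines=[BaselineLine(id='foo',
                                       baseline=[[0, 10], [2543, 10]],
                                       boundary=[[0, 0], [2543, 0], [2543, 155], [0, 155]])],
                   text_direction='horizontal-lr',
                   script_detection=False)

# Default behaviour: the model's flag wins, so an old model falls back to the
# legacy extractor and a warning suggests retraining.
for record in rpred(net, im, seg):
    print(record.prediction)

# Opt-out: force the new extractor on an old model; kraken warns instead that
# accuracy may be affected. This is what --no-legacy-polygons does on the CLI.
for record in rpred(net, im, seg, no_legacy_polygons=True):
    print(record.prediction)

The same precedence rule (explicit flag beats model metadata, with a one-time warning either way) is what the mocked `legacy=` assertions in tests/test_newpolygons.py verify across rpred, the kraken CLI, and ketos.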