From 5b9e2a09396a2538ccec8f3f848bf58fcada6403 Mon Sep 17 00:00:00 2001
From: Benjamin Kiessling
Date: Fri, 27 Sep 2024 12:36:49 +0200
Subject: [PATCH 01/16] Compute pixel metrics only on regions

---
 kraken/lib/train.py | 53 +++++++++++++++++++++++++--------------------
 1 file changed, 30 insertions(+), 23 deletions(-)

diff --git a/kraken/lib/train.py b/kraken/lib/train.py
index 92081bec..77460764 100644
--- a/kraken/lib/train.py
+++ b/kraken/lib/train.py
@@ -901,38 +901,41 @@ def validation_step(self, batch, batch_idx):
         x, y = batch['image'], batch['target']
         pred, _ = self.nn.nn(x)
         # scale target to output size
-        y = F.interpolate(y, size=(pred.size(2), pred.size(3))).int()
-
-        self.val_px_accuracy.update(pred, y)
-        self.val_mean_accuracy.update(pred, y)
-        self.val_mean_iu.update(pred, y)
-        self.val_freq_iu.update(pred, y)
+        y = F.interpolate(y, size=(pred.size(2), pred.size(3)), mode='nearest').int()
+        # Get regions for IoU metrics
+        reg_idxs = sorted(self.nn.user_metadata['class_mapping']['regions'].values())
+        pred_reg = pred[:, reg_idxs, ...]
+        y_reg = y[:, reg_idxs, ...]
+        self.val_region_px_accuracy.update(pred_reg, y_reg)
+        self.val_region_mean_accuracy.update(pred_reg, y_reg)
+        self.val_region_mean_iu.update(pred_reg, y_reg)
+        self.val_region_freq_iu.update(pred_reg, y_reg)

     def on_validation_epoch_end(self):
         if not self.trainer.sanity_checking:
-            pixel_accuracy = self.val_px_accuracy.compute()
-            mean_accuracy = self.val_mean_accuracy.compute()
-            mean_iu = self.val_mean_iu.compute()
-            freq_iu = self.val_freq_iu.compute()
+            pixel_accuracy = self.val_region_px_accuracy.compute()
+            mean_accuracy = self.val_region_mean_accuracy.compute()
+            mean_iu = self.val_region_mean_iu.compute()
+            freq_iu = self.val_region_freq_iu.compute()

             if mean_iu > self.best_metric:
-                logger.debug(f'Updating best metric from {self.best_metric} ({self.best_epoch}) to {mean_iu} ({self.current_epoch})')
+                logger.debug(f'Updating best region metric from {self.best_metric} ({self.best_epoch}) to {mean_iu} ({self.current_epoch})')
                 self.best_epoch = self.current_epoch
                 self.best_metric = mean_iu

             logger.info(f'validation run: accuracy {pixel_accuracy} mean_acc {mean_accuracy} mean_iu {mean_iu} freq_iu {freq_iu}')

-            self.log('val_accuracy', pixel_accuracy, on_step=False, on_epoch=True, prog_bar=True, logger=True)
-            self.log('val_mean_acc', mean_accuracy, on_step=False, on_epoch=True, prog_bar=True, logger=True)
-            self.log('val_mean_iu', mean_iu, on_step=False, on_epoch=True, prog_bar=True, logger=True)
-            self.log('val_freq_iu', freq_iu, on_step=False, on_epoch=True, prog_bar=True, logger=True)
+            self.log('val_region_accuracy', pixel_accuracy, on_step=False, on_epoch=True, prog_bar=True, logger=True)
+            self.log('val_region_mean_acc', mean_accuracy, on_step=False, on_epoch=True, prog_bar=True, logger=True)
+            self.log('val_region_mean_iu', mean_iu, on_step=False, on_epoch=True, prog_bar=True, logger=True)
+            self.log('val_region_freq_iu', freq_iu, on_step=False, on_epoch=True, prog_bar=True, logger=True)
             self.log('val_metric', mean_iu, on_step=False, on_epoch=True, prog_bar=False, logger=True)

         # reset metrics even if sanity checking
-        self.val_px_accuracy.reset()
-        self.val_mean_accuracy.reset()
-        self.val_mean_iu.reset()
-        self.val_freq_iu.reset()
+        self.val_region_px_accuracy.reset()
+        self.val_region_mean_accuracy.reset()
+        self.val_region_mean_iu.reset()
+        self.val_region_freq_iu.reset()

     def setup(self, stage: Optional[str] = None):
         # finalize models in case of appending/loading
@@ -1055,10 +1058,14 @@ def setup(self, stage: 
Optional[str] = None): torch.set_num_threads(max(self.num_workers, 1)) # set up validation metrics after output classes have been determined - self.val_px_accuracy = MultilabelAccuracy(average='micro', num_labels=self.train_set.dataset.num_classes) - self.val_mean_accuracy = MultilabelAccuracy(average='macro', num_labels=self.train_set.dataset.num_classes) - self.val_mean_iu = MultilabelJaccardIndex(average='macro', num_labels=self.train_set.dataset.num_classes) - self.val_freq_iu = MultilabelJaccardIndex(average='weighted', num_labels=self.train_set.dataset.num_classes) + # baseline metrics + # region metrics + num_regions = len(self.val_set.dataset.class_mapping['regions']) + self.val_region_px_accuracy = MultilabelAccuracy(average='micro', num_labels=num_regions) + self.val_region_mean_accuracy = MultilabelAccuracy(average='macro', num_labels=num_regions) + self.val_region_mean_iu = MultilabelJaccardIndex(average='macro', num_labels=num_regions) + self.val_region_freq_iu = MultilabelJaccardIndex(average='weighted', num_labels=num_regions) + def train_dataloader(self): return DataLoader(self.train_set, From 8ff556ec7360c1c0421ab398547bdd1e21db96c8 Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Fri, 27 Sep 2024 18:01:30 +0200 Subject: [PATCH 02/16] refactor training module --- kraken/ketos/segmentation.py | 2 +- kraken/lib/train.py | 1197 ------------------------------ kraken/lib/train/__init__.py | 20 + kraken/lib/train/recognition.py | 552 ++++++++++++++ kraken/lib/train/segmentation.py | 480 ++++++++++++ kraken/lib/train/trainer.py | 166 +++++ kraken/lib/train/utils.py | 93 +++ 7 files changed, 1312 insertions(+), 1198 deletions(-) delete mode 100644 kraken/lib/train.py create mode 100644 kraken/lib/train/__init__.py create mode 100644 kraken/lib/train/recognition.py create mode 100644 kraken/lib/train/segmentation.py create mode 100644 kraken/lib/train/trainer.py create mode 100644 kraken/lib/train/utils.py diff --git a/kraken/ketos/segmentation.py b/kraken/ketos/segmentation.py index a98a4f66..e8da48d2 100644 --- a/kraken/ketos/segmentation.py +++ b/kraken/ketos/segmentation.py @@ -444,7 +444,7 @@ def segtest(ctx, model, evaluation_files, device, workers, threads, threshold, from torch.utils.data import DataLoader from kraken.lib.progress import KrakenProgressBar - from kraken.lib.train import BaselineSet, ImageInputTransforms + from kraken.lib.dataset import BaselineSet, ImageInputTransforms from kraken.lib.vgsl import TorchVGSLModel logger.info('Building test set from {} documents'.format(len(test_set) + len(evaluation_files))) diff --git a/kraken/lib/train.py b/kraken/lib/train.py deleted file mode 100644 index 77460764..00000000 --- a/kraken/lib/train.py +++ /dev/null @@ -1,1197 +0,0 @@ -# -# Copyright 2015 Benjamin Kiessling -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -# or implied. See the License for the specific language governing -# permissions and limitations under the License. 
-""" -Training loop interception helpers -""" -import logging -import re -import warnings -from typing import (TYPE_CHECKING, Any, Callable, Dict, Literal, Optional, - Sequence, Union) -from functools import partial - -import numpy as np -import lightning as L -import torch -import torch.nn.functional as F -from lightning.pytorch.callbacks import (BaseFinetuning, Callback, - EarlyStopping, LearningRateMonitor) -from torch.optim import lr_scheduler -from torch.utils.data import DataLoader, Subset, random_split -from torchmetrics.classification import (MultilabelAccuracy, - MultilabelJaccardIndex) -from torchmetrics.text import CharErrorRate, WordErrorRate - -from kraken.containers import Segmentation -from kraken.lib import default_specs, models, progress, vgsl -from kraken.lib.codec import PytorchCodec -from kraken.lib.dataset import (ArrowIPCRecognitionDataset, BaselineSet, - GroundTruthDataset, ImageInputTransforms, - PolygonGTDataset, collate_sequences) -from kraken.lib.exceptions import KrakenEncodeException, KrakenInputException -from kraken.lib.models import validate_hyper_parameters -from kraken.lib.util import make_printable, parse_gt_path -from kraken.lib.xml import XMLPage - -if TYPE_CHECKING: - from os import PathLike - -logger = logging.getLogger(__name__) - - -def _star_fun(fun, kwargs): - try: - return fun(**kwargs) - except FileNotFoundError as e: - logger.warning(f'{e.strerror}: {e.filename}. Skipping.') - except KrakenInputException as e: - logger.warning(str(e)) - return None - - -def _validation_worker_init_fn(worker_id): - """ Fix random seeds so that augmentation always produces the same - results when validating. Temporarily increase the logging level - for lightning because otherwise it will display a message - at info level about the seed being changed. 
""" - from lightning.pytorch import seed_everything - seed_everything(42) - - -class KrakenTrainer(L.Trainer): - def __init__(self, - enable_progress_bar: bool = True, - enable_summary: bool = True, - min_epochs: int = 5, - max_epochs: int = 100, - freeze_backbone=-1, - pl_logger: Union[L.pytorch.loggers.logger.Logger, str, None] = None, - log_dir: Optional['PathLike'] = None, - *args, - **kwargs): - kwargs['enable_checkpointing'] = False - kwargs['enable_progress_bar'] = enable_progress_bar - kwargs['min_epochs'] = min_epochs - kwargs['max_epochs'] = max_epochs - kwargs['callbacks'] = ([] if 'callbacks' not in kwargs else kwargs['callbacks']) - if not isinstance(kwargs['callbacks'], list): - kwargs['callbacks'] = [kwargs['callbacks']] - - if pl_logger: - if 'logger' in kwargs and isinstance(kwargs['logger'], L.pytorch.loggers.logger.Logger): - logger.debug('Experiment logger has been provided outside KrakenTrainer as `logger`') - elif isinstance(pl_logger, L.pytorch.loggers.logger.Logger): - logger.debug('Experiment logger has been provided outside KrakenTrainer as `pl_logger`') - kwargs['logger'] = pl_logger - elif pl_logger == 'tensorboard': - logger.debug('Creating default experiment logger') - kwargs['logger'] = L.pytorch.loggers.TensorBoardLogger(log_dir) - else: - logger.error('`pl_logger` was set, but %s is not an accepted value', pl_logger) - raise ValueError(f'{pl_logger} is not acceptable as logger') - kwargs['callbacks'].append(LearningRateMonitor(logging_interval='step')) - else: - kwargs['logger'] = False - - if enable_progress_bar: - progress_bar_cb = progress.KrakenTrainProgressBar(leave=True) - kwargs['callbacks'].append(progress_bar_cb) - - if enable_summary: - from lightning.pytorch.callbacks import RichModelSummary - summary_cb = RichModelSummary(max_depth=2) - kwargs['callbacks'].append(summary_cb) - kwargs['enable_model_summary'] = False - - if freeze_backbone > 0: - kwargs['callbacks'].append(KrakenFreezeBackbone(freeze_backbone)) - - kwargs['callbacks'].extend([KrakenSetOneChannelMode(), KrakenSaveModel()]) - super().__init__(*args, **kwargs) - self.automatic_optimization = False - - def fit(self, *args, **kwargs): - with warnings.catch_warnings(): - warnings.filterwarnings(action='ignore', category=UserWarning, - message='The dataloader,') - super().fit(*args, **kwargs) - - -class KrakenFreezeBackbone(BaseFinetuning): - """ - Callback freezing all but the last layer for fixed number of iterations. - """ - def __init__(self, unfreeze_at_iterations=10): - super().__init__() - self.unfreeze_at_iteration = unfreeze_at_iterations - - def freeze_before_training(self, pl_module): - pass - - def finetune_function(self, pl_module, current_epoch, optimizer): - pass - - def on_train_start(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None: - self.freeze(pl_module.net[:-1]) - - def on_train_batch_start(self, trainer: "L.Trainer", pl_module: "L.LightningModule", batch, batch_idx) -> None: - """ - Called for each training batch. 
- """ - if trainer.global_step == self.unfreeze_at_iteration: - for opt_idx, optimizer in enumerate(trainer.optimizers): - num_param_groups = len(optimizer.param_groups) - self.unfreeze_and_add_param_group(modules=pl_module.net[:-1], - optimizer=optimizer, - train_bn=True,) - current_param_groups = optimizer.param_groups - self._store(pl_module, opt_idx, num_param_groups, current_param_groups) - - def on_train_epoch_start(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None: - """Called when the epoch begins.""" - pass - - -class KrakenSetOneChannelMode(Callback): - """ - Callback that sets the one_channel_mode of the model after the first epoch. - """ - def on_train_epoch_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None: - # fill one_channel_mode after 1 iteration over training data set - if not trainer.sanity_checking and trainer.current_epoch == 0 and trainer.model.nn.model_type == 'recognition': - ds = getattr(pl_module, 'train_set', None) - if not ds and trainer.datamodule: - ds = trainer.datamodule.train_set - im_mode = ds.dataset.im_mode - if im_mode in ['1', 'L']: - logger.info(f'Setting model one_channel_mode to {im_mode}.') - trainer.model.nn.one_channel_mode = im_mode - - -class KrakenSaveModel(Callback): - """ - Kraken's own serialization callback instead of pytorch's. - """ - def on_validation_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None: - if not trainer.sanity_checking: - trainer.model.nn.hyper_params['completed_epochs'] += 1 - metric = float(trainer.logged_metrics['val_metric']) if 'val_metric' in trainer.logged_metrics else -1.0 - trainer.model.nn.user_metadata['accuracy'].append((trainer.global_step, metric)) - trainer.model.nn.user_metadata['metrics'].append((trainer.global_step, {k: float(v) for k, v in trainer.logged_metrics.items()})) - - logger.info('Saving to {}_{}.mlmodel'.format(trainer.model.output, trainer.current_epoch)) - trainer.model.nn.save_model(f'{trainer.model.output}_{trainer.current_epoch}.mlmodel') - trainer.model.best_model = f'{trainer.model.output}_{trainer.model.best_epoch}.mlmodel' - - -class RecognitionModel(L.LightningModule): - def __init__(self, - hyper_params: Dict[str, Any] = None, - output: str = 'model', - spec: str = default_specs.RECOGNITION_SPEC, - append: Optional[int] = None, - model: Optional[Union['PathLike', str]] = None, - reorder: Union[bool, str] = True, - training_data: Union[Sequence[Union['PathLike', str]], Sequence[Dict[str, Any]]] = None, - evaluation_data: Optional[Union[Sequence[Union['PathLike', str]], Sequence[Dict[str, Any]]]] = None, - partition: Optional[float] = 0.9, - binary_dataset_split: bool = False, - num_workers: int = 1, - load_hyper_parameters: bool = False, - force_binarization: bool = False, - format_type: Literal['path', 'alto', 'page', 'xml', 'binary'] = 'path', - codec: Optional[Dict] = None, - resize: Literal['fail', 'both', 'new', 'add', 'union'] = 'fail', - legacy_polygons: bool = False): - """ - A LightningModule encapsulating the training setup for a text - recognition model. - - Setup parameters (load, training_data, evaluation_data, ....) are - named, model hyperparameters (everything in - `kraken.lib.default_specs.RECOGNITION_HYPER_PARAMS`) are in in the - `hyper_params` argument. - - Args: - hyper_params (dict): Hyperparameter dictionary containing all fields - from - kraken.lib.default_specs.RECOGNITION_HYPER_PARAMS - **kwargs: Setup parameters, i.e. CLI parameters of the train() command. 
- """ - super().__init__() - self.legacy_polygons = legacy_polygons - hyper_params_ = default_specs.RECOGNITION_HYPER_PARAMS.copy() - if model: - logger.info(f'Loading existing model from {model} ') - self.nn = vgsl.TorchVGSLModel.load_model(model) - - if self.nn.model_type not in [None, 'recognition']: - raise ValueError(f'Model {model} is of type {self.nn.model_type} while `recognition` is expected.') - - if load_hyper_parameters: - hp = self.nn.hyper_params - else: - hp = {} - hyper_params_.update(hp) - else: - self.nn = None - - if hyper_params: - hyper_params_.update(hyper_params) - self.hyper_params = hyper_params_ - self.save_hyperparameters() - - self.reorder = reorder - self.append = append - self.model = model - self.num_workers = num_workers - if resize == "add": - resize = "union" - warnings.warn("'add' value for resize has been deprecated. Use 'union' instead.", DeprecationWarning) - elif resize == "both": - resize = "new" - warnings.warn("'both' value for resize has been deprecated. Use 'new' instead.", DeprecationWarning) - - self.resize = resize - self.format_type = format_type - self.output = output - - self.best_epoch = -1 - self.best_metric = 0.0 - self.best_model = None - - DatasetClass = GroundTruthDataset - valid_norm = True - if format_type in ['xml', 'page', 'alto']: - logger.info(f'Parsing {len(training_data)} XML files for training data') - training_data = [{'page': XMLPage(file, format_type).to_container()} for file in training_data] - if evaluation_data: - logger.info(f'Parsing {len(evaluation_data)} XML files for validation data') - evaluation_data = [{'page': XMLPage(file, format_type).to_container()} for file in evaluation_data] - if binary_dataset_split: - logger.warning('Internal binary dataset splits are enabled but using non-binary dataset files. Will be ignored.') - binary_dataset_split = False - DatasetClass = partial(PolygonGTDataset, legacy_polygons=legacy_polygons) - valid_norm = False - elif format_type == 'binary': - DatasetClass = ArrowIPCRecognitionDataset - valid_norm = False - logger.info(f'Got {len(training_data)} binary dataset files for training data') - training_data = [{'file': file} for file in training_data] - if evaluation_data: - logger.info(f'Got {len(evaluation_data)} binary dataset files for validation data') - evaluation_data = [{'file': file} for file in evaluation_data] - elif format_type == 'path': - if force_binarization: - logger.warning('Forced binarization enabled in `path` mode. Will be ignored.') - force_binarization = False - if binary_dataset_split: - logger.warning('Internal binary dataset splits are enabled but using non-binary dataset files. Will be ignored.') - binary_dataset_split = False - logger.info(f'Got {len(training_data)} line strip images for training data') - training_data = [{'line': parse_gt_path(im)} for im in training_data] - if evaluation_data: - logger.info(f'Got {len(evaluation_data)} line strip images for validation data') - evaluation_data = [{'line': parse_gt_path(im)} for im in evaluation_data] - valid_norm = True - # format_type is None. Determine training type from container class types - elif not format_type: - if training_data[0].type == 'baselines': - DatasetClass = partial(PolygonGTDataset, legacy_polygons=legacy_polygons) - valid_norm = False - else: - if force_binarization: - logger.warning('Forced binarization enabled with box lines. 
Will be ignored.') - force_binarization = False - if binary_dataset_split: - logger.warning('Internal binary dataset splits are enabled but using non-binary dataset files. Will be ignored.') - binary_dataset_split = False - samples = [] - for sample in training_data: - if isinstance(sample, Segmentation): - samples.append({'page': sample}) - else: - samples.append({'line': sample}) - training_data = samples - if evaluation_data: - samples = [] - for sample in evaluation_data: - if isinstance(sample, Segmentation): - samples.append({'page': sample}) - else: - samples.append({'line': sample}) - evaluation_data = samples - else: - raise ValueError(f'format_type {format_type} not in [alto, page, xml, path, binary].') - - spec = spec.strip() - if spec[0] != '[' or spec[-1] != ']': - raise ValueError(f'VGSL spec {spec} not bracketed') - self.spec = spec - # preparse input sizes from vgsl string to seed ground truth data set - # sizes and dimension ordering. - if not self.nn: - blocks = spec[1:-1].split(' ') - m = re.match(r'(\d+),(\d+),(\d+),(\d+)', blocks[0]) - if not m: - raise ValueError(f'Invalid input spec {blocks[0]}') - batch, height, width, channels = [int(x) for x in m.groups()] - else: - batch, channels, height, width = self.nn.input - - self.transforms = ImageInputTransforms(batch, - height, - width, - channels, - (self.hparams.hyper_params['pad'], 0), - valid_norm, - force_binarization) - - self.example_input_array = torch.Tensor(batch, - channels, - height if height else 32, - width if width else 400) - - if 'file_system' in torch.multiprocessing.get_all_sharing_strategies(): - logger.debug('Setting multiprocessing tensor sharing strategy to file_system') - torch.multiprocessing.set_sharing_strategy('file_system') - - val_set = None - if evaluation_data: - train_set = self._build_dataset(DatasetClass, training_data) - self.train_set = Subset(train_set, range(len(train_set))) - val_set = self._build_dataset(DatasetClass, evaluation_data) - self.val_set = Subset(val_set, range(len(val_set))) - elif binary_dataset_split: - train_set = self._build_dataset(DatasetClass, training_data, split_filter='train') - self.train_set = Subset(train_set, range(len(train_set))) - val_set = self._build_dataset(DatasetClass, training_data, split_filter='validation') - self.val_set = Subset(val_set, range(len(val_set))) - logger.info(f'Found {len(self.train_set)} (train) / {len(self.val_set)} (val) samples in pre-encoded dataset') - else: - train_set = self._build_dataset(DatasetClass, training_data) - train_len = int(len(train_set)*partition) - val_len = len(train_set) - train_len - logger.info(f'No explicit validation data provided. Splitting off ' - f'{val_len} (of {len(train_set)}) samples to validation ' - 'set. (Will disable alphabet mismatch detection.)') - self.train_set, self.val_set = random_split(train_set, (train_len, val_len)) - - if len(self.train_set) == 0 or len(self.val_set) == 0: - raise ValueError('No valid training data was provided to the train ' - 'command. Please add valid XML, line, or binary data.') - - if format_type == 'binary': - legacy_train_status = self.train_set.dataset.legacy_polygons_status - if self.val_set.dataset.legacy_polygons_status != legacy_train_status: - logger.warning('Train and validation set have different legacy ' - f'polygon status: {legacy_train_status} and ' - f'{self.val_set.dataset.legacy_polygons_status}. Train set ' - 'status prevails.') - if legacy_train_status == "mixed": - logger.warning('Mixed legacy polygon status in training dataset. 
Consider recompilation.') - legacy_train_status = False - if legacy_polygons != legacy_train_status: - logger.warning(f'Setting dataset legacy polygon status to {legacy_train_status} based on training set.') - self.legacy_polygons = legacy_train_status - - logger.info(f'Training set {len(self.train_set)} lines, validation set ' - f'{len(self.val_set)} lines, alphabet {len(train_set.alphabet)} ' - 'symbols') - alpha_diff_only_train = set(self.train_set.dataset.alphabet).difference(set(self.val_set.dataset.alphabet)) - alpha_diff_only_val = set(self.val_set.dataset.alphabet).difference(set(self.train_set.dataset.alphabet)) - if alpha_diff_only_train: - logger.warning(f'alphabet mismatch: chars in training set only: ' - f'{alpha_diff_only_train} (not included in accuracy test ' - 'during training)') - if alpha_diff_only_val: - logger.warning(f'alphabet mismatch: chars in validation set only: {alpha_diff_only_val} (not trained)') - logger.info('grapheme\tcount') - for k, v in sorted(train_set.alphabet.items(), key=lambda x: x[1], reverse=True): - char = make_printable(k) - if char == k: - char = '\t' + char - logger.info(f'{char}\t{v}') - - if codec: - logger.info('Instantiating codec') - self.codec = PytorchCodec(codec) - for k, v in self.codec.c2l.items(): - char = make_printable(k) - if char == k: - char = '\t' + char - logger.info(f'{char}\t{v}') - else: - self.codec = None - - logger.info('Encoding training set') - - self.val_cer = CharErrorRate() - self.val_wer = WordErrorRate() - - def _build_dataset(self, - DatasetClass, - training_data, - **kwargs): - dataset = DatasetClass(normalization=self.hparams.hyper_params['normalization'], - whitespace_normalization=self.hparams.hyper_params['normalize_whitespace'], - reorder=self.reorder, - im_transforms=self.transforms, - augmentation=self.hparams.hyper_params['augment'], - **kwargs) - - for sample in training_data: - try: - dataset.add(**sample) - except KrakenInputException as e: - logger.warning(str(e)) - if self.format_type == 'binary' and (self.hparams.hyper_params['normalization'] or - self.hparams.hyper_params['normalize_whitespace'] or - self.reorder): - logger.debug('Text transformations modifying alphabet selected. 
Rebuilding alphabet') - dataset.rebuild_alphabet() - - return dataset - - def forward(self, x, seq_lens=None): - return self.net(x, seq_lens) - - def training_step(self, batch, batch_idx): - input, target = batch['image'], batch['target'] - # sequence batch - if 'seq_lens' in batch: - seq_lens, label_lens = batch['seq_lens'], batch['target_lens'] - target = (target, label_lens) - o = self.net(input, seq_lens) - else: - o = self.net(input) - - seq_lens = o[1] - output = o[0] - target_lens = target[1] - target = target[0] - # height should be 1 by now - if output.size(2) != 1: - raise KrakenInputException('Expected dimension 3 to be 1, actual {}'.format(output.size(2))) - output = output.squeeze(2) - # NCW -> WNC - loss = self.nn.criterion(output.permute(2, 0, 1), # type: ignore - target, - seq_lens, - target_lens) - self.log('train_loss', loss, on_step=True, on_epoch=False, prog_bar=False, logger=True) - return loss - - def validation_step(self, batch, batch_idx): - pred = self.rec_nn.predict_string(batch['image'], batch['seq_lens']) - idx = 0 - decoded_targets = [] - for offset in batch['target_lens']: - decoded_targets.append(''.join([x[0] for x in self.val_codec.decode([(x, 0, 0, 0) for x in batch['target'][idx:idx+offset]])])) - idx += offset - self.val_cer.update(pred, decoded_targets) - self.val_wer.update(pred, decoded_targets) - - if self.logger and self.trainer.state.stage != 'sanity_check' and self.hparams.hyper_params["batch_size"] * batch_idx < 16: - for i in range(self.hparams.hyper_params["batch_size"]): - count = self.hparams.hyper_params["batch_size"] * batch_idx + i - if count < 16: - self.logger.experiment.add_image(f'Validation #{count}, target: {decoded_targets[i]}', - batch['image'][i], - self.global_step, - dataformats="CHW") - self.logger.experiment.add_text(f'Validation #{count}, target: {decoded_targets[i]}', - pred[i], - self.global_step) - - def on_validation_epoch_end(self): - if not self.trainer.sanity_checking: - accuracy = 1.0 - self.val_cer.compute() - word_accuracy = 1.0 - self.val_wer.compute() - - if accuracy > self.best_metric: - logger.debug(f'Updating best metric from {self.best_metric} ({self.best_epoch}) to {accuracy} ({self.current_epoch})') - self.best_epoch = self.current_epoch - self.best_metric = accuracy - logger.info(f'validation run: total chars {self.val_cer.total} errors {self.val_cer.errors} accuracy {accuracy}') - self.log('val_accuracy', accuracy, on_step=False, on_epoch=True, prog_bar=True, logger=True) - self.log('val_word_accuracy', word_accuracy, on_step=False, on_epoch=True, prog_bar=True, logger=True) - self.log('val_metric', accuracy, on_step=False, on_epoch=True, prog_bar=False, logger=True) - # reset metrics even if not sanity checking - self.val_cer.reset() - self.val_wer.reset() - - def setup(self, stage: Optional[str] = None): - # finalize models in case of appending/loading - if stage in [None, 'fit']: - - # Log a few sample images before the datasets are encoded. 
- # This is only possible for Arrow datasets, because the - # other dataset types can only be accessed after encoding - if self.logger and isinstance(self.train_set.dataset, ArrowIPCRecognitionDataset): - for i in range(min(len(self.train_set), 16)): - idx = np.random.randint(len(self.train_set)) - sample = self.train_set[idx] - self.logger.experiment.add_image(f'train_set sample #{i}: {sample["target"]}', sample['image']) - - if self.append: - self.train_set.dataset.encode(self.codec) - # now we can create a new model - self.spec = '[{} O1c{}]'.format(self.spec[1:-1], self.train_set.dataset.codec.max_label + 1) - logger.info(f'Appending {self.spec} to existing model {self.nn.spec} after {self.append}') - self.nn.append(self.append, self.spec) - self.nn.add_codec(self.train_set.dataset.codec) - logger.info(f'Assembled model spec: {self.nn.spec}') - elif self.model: - self.spec = self.nn.spec - - # prefer explicitly given codec over network codec if mode is 'new' - codec = self.codec if (self.codec and self.resize == 'new') else self.nn.codec - - codec.strict = True - - try: - self.train_set.dataset.encode(codec) - except KrakenEncodeException: - alpha_diff = set(self.train_set.dataset.alphabet).difference( - set(codec.c2l.keys()) - ) - if self.resize == 'fail': - raise KrakenInputException(f'Training data and model codec alphabets mismatch: {alpha_diff}') - elif self.resize == 'union': - logger.info(f'Resizing codec to include ' - f'{len(alpha_diff)} new code points') - # Construct two codecs: - # 1. training codec containing only the vocabulary in the training dataset - # 2. validation codec = training codec + validation set vocabulary - # This keep the codec in the model from being 'polluted' by non-trained characters. - train_codec = codec.add_labels(alpha_diff) - self.nn.add_codec(train_codec) - logger.info(f'Resizing last layer in network to {train_codec.max_label+1} outputs') - self.nn.resize_output(train_codec.max_label + 1) - self.train_set.dataset.encode(train_codec) - elif self.resize == 'new': - logger.info(f'Resizing network or given codec to ' - f'{len(self.train_set.dataset.alphabet)} ' - f'code sequences') - # same codec procedure as above, just with merging. - self.train_set.dataset.encode(None) - train_codec, del_labels = codec.merge(self.train_set.dataset.codec) - # Switch codec. 
- self.nn.add_codec(train_codec) - logger.info(f'Deleting {len(del_labels)} output classes from network ' - f'({len(codec)-len(del_labels)} retained)') - self.nn.resize_output(train_codec.max_label + 1, del_labels) - self.train_set.dataset.encode(train_codec) - else: - raise ValueError(f'invalid resize parameter value {self.resize}') - self.nn.codec.strict = False - self.spec = self.nn.spec - else: - self.train_set.dataset.encode(self.codec) - logger.info(f'Creating new model {self.spec} with {self.train_set.dataset.codec.max_label+1} outputs') - self.spec = '[{} O1c{}]'.format(self.spec[1:-1], self.train_set.dataset.codec.max_label + 1) - self.nn = vgsl.TorchVGSLModel(self.spec) - self.nn.use_legacy_polygons = self.legacy_polygons - # initialize weights - self.nn.init_weights() - self.nn.add_codec(self.train_set.dataset.codec) - - val_diff = set(self.val_set.dataset.alphabet).difference( - set(self.train_set.dataset.codec.c2l.keys()) - ) - logger.info(f'Adding {len(val_diff)} dummy labels to validation set codec.') - - val_codec = self.nn.codec.add_labels(val_diff) - self.val_set.dataset.encode(val_codec) - self.val_codec = val_codec - - if self.nn.one_channel_mode and self.train_set.dataset.im_mode != self.nn.one_channel_mode: - logger.warning(f'Neural network has been trained on mode {self.nn.one_channel_mode} images, ' - f'training set contains mode {self.train_set.dataset.im_mode} data. Consider setting `force_binarization`') - - if self.format_type != 'path' and self.nn.seg_type == 'bbox': - logger.warning('Neural network has been trained on bounding box image information but training set is polygonal.') - - self.nn.hyper_params = self.hparams.hyper_params - self.nn.model_type = 'recognition' - - if not self.nn.seg_type: - logger.info(f'Setting seg_type to {self.train_set.dataset.seg_type}.') - self.nn.seg_type = self.train_set.dataset.seg_type - - self.rec_nn = models.TorchSeqRecognizer(self.nn, train=None, device=None) - self.net = self.nn.nn - - torch.set_num_threads(max(self.num_workers, 1)) - - def train_dataloader(self): - return DataLoader(self.train_set, - batch_size=self.hparams.hyper_params['batch_size'], - num_workers=self.num_workers, - pin_memory=True, - shuffle=True, - collate_fn=collate_sequences) - - def val_dataloader(self): - return DataLoader(self.val_set, - shuffle=False, - batch_size=self.hparams.hyper_params['batch_size'], - num_workers=self.num_workers, - pin_memory=True, - collate_fn=collate_sequences, - worker_init_fn=_validation_worker_init_fn) - - def configure_callbacks(self): - callbacks = [] - if self.hparams.hyper_params['quit'] == 'early': - callbacks.append(EarlyStopping(monitor='val_accuracy', - mode='max', - patience=self.hparams.hyper_params['lag'], - stopping_threshold=1.0)) - - return callbacks - - # configuration of optimizers and learning rate schedulers - # -------------------------------------------------------- - # - # All schedulers are created internally with a frequency of step to enable - # batch-wise learning rate warmup. In lr_scheduler_step() calls to the - # scheduler are then only performed at the end of the epoch. 
- def configure_optimizers(self): - return _configure_optimizer_and_lr_scheduler(self.hparams.hyper_params, - self.nn.nn.parameters(), - len_train_set=len(self.train_set), - loss_tracking_mode='max') - - def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure): - # update params - optimizer.step(closure=optimizer_closure) - - # linear warmup between 0 and the initial learning rate `lrate` in `warmup` - # steps. - if self.hparams.hyper_params['warmup'] and self.trainer.global_step < self.hparams.hyper_params['warmup']: - lr_scale = min(1.0, float(self.trainer.global_step + 1) / self.hparams.hyper_params['warmup']) - for pg in optimizer.param_groups: - pg["lr"] = lr_scale * self.hparams.hyper_params['lrate'] - - def lr_scheduler_step(self, scheduler, metric): - if not self.hparams.hyper_params['warmup'] or self.trainer.global_step >= self.hparams.hyper_params['warmup']: - # step OneCycleLR each batch if not in warmup phase - if isinstance(scheduler, lr_scheduler.OneCycleLR): - scheduler.step() - # step every other scheduler epoch-wise - elif self.trainer.is_last_batch: - if metric is None: - scheduler.step() - else: - scheduler.step(metric) - - -class SegmentationModel(L.LightningModule): - def __init__(self, - hyper_params: Dict = None, - load_hyper_parameters: bool = False, - progress_callback: Callable[[str, int], Callable[[None], None]] = lambda string, length: lambda: None, - message: Callable[[str], None] = lambda *args, **kwargs: None, - output: str = 'model', - spec: str = default_specs.SEGMENTATION_SPEC, - model: Optional[Union['PathLike', str]] = None, - training_data: Union[Sequence[Union['PathLike', str]], Sequence[Segmentation]] = None, - evaluation_data: Optional[Union[Sequence[Union['PathLike', str]], Sequence[Segmentation]]] = None, - partition: Optional[float] = 0.9, - num_workers: int = 1, - force_binarization: bool = False, - format_type: Literal['path', 'alto', 'page', 'xml', None] = 'path', - suppress_regions: bool = False, - suppress_baselines: bool = False, - valid_regions: Optional[Sequence[str]] = None, - valid_baselines: Optional[Sequence[str]] = None, - merge_regions: Optional[Dict[str, str]] = None, - merge_baselines: Optional[Dict[str, str]] = None, - bounding_regions: Optional[Sequence[str]] = None, - resize: Literal['fail', 'both', 'new', 'add', 'union'] = 'fail', - topline: Union[bool, None] = False): - """ - A LightningModule encapsulating the training setup for a page - segmentation model. - - Setup parameters (load, training_data, evaluation_data, ....) are - named, model hyperparameters (everything in - `kraken.lib.default_specs.SEGMENTATION_HYPER_PARAMS`) are in in the - `hyper_params` argument. - - Args: - hyper_params (dict): Hyperparameter dictionary containing all fields - from - kraken.lib.default_specs.SEGMENTATION_HYPER_PARAMS - **kwargs: Setup parameters, i.e. CLI parameters of the segtrain() command. - """ - - super().__init__() - - self.best_epoch = -1 - self.best_metric = 0.0 - self.best_model = None - - self.model = model - self.num_workers = num_workers - - if resize == "add": - resize = "union" - warnings.warn("'add' value for resize has been deprecated. Use 'union' instead.", DeprecationWarning) - elif resize == "both": - resize = "new" - warnings.warn("'both' value for resize has been deprecated. 
Use 'new' instead.", DeprecationWarning) - self.resize = resize - - self.output = output - self.bounding_regions = bounding_regions - self.topline = topline - - hyper_params_ = default_specs.SEGMENTATION_HYPER_PARAMS.copy() - - if model: - logger.info(f'Loading existing model from {model}') - self.nn = vgsl.TorchVGSLModel.load_model(model) - - if self.nn.model_type not in [None, 'segmentation']: - raise ValueError(f'Model {model} is of type {self.nn.model_type} while `segmentation` is expected.') - - if load_hyper_parameters: - hp = self.nn.hyper_params - else: - hp = {} - hyper_params_.update(hp) - batch, channels, height, width = self.nn.input - else: - self.nn = None - - spec = spec.strip() - if spec[0] != '[' or spec[-1] != ']': - raise ValueError(f'VGSL spec "{spec}" not bracketed') - self.spec = spec - blocks = spec[1:-1].split(' ') - m = re.match(r'(\d+),(\d+),(\d+),(\d+)', blocks[0]) - if not m: - raise ValueError(f'Invalid input spec {blocks[0]}') - batch, height, width, channels = [int(x) for x in m.groups()] - - if hyper_params: - hyper_params_.update(hyper_params) - - validate_hyper_parameters(hyper_params_) - self.hyper_params = hyper_params_ - self.save_hyperparameters() - - if format_type in ['xml', 'page', 'alto']: - logger.info(f'Parsing {len(training_data)} XML files for training data') - training_data = [XMLPage(file, format_type).to_container() for file in training_data] - if evaluation_data: - logger.info(f'Parsing {len(evaluation_data)} XML files for validation data') - evaluation_data = [XMLPage(file, format_type).to_container() for file in evaluation_data] - elif not format_type: - pass - else: - raise ValueError(f'format_type {format_type} not in [alto, page, xml, None].') - - if not training_data: - raise ValueError('No training data provided. Please add some.') - - transforms = ImageInputTransforms(batch, - height, - width, - channels, - self.hparams.hyper_params['padding'], - valid_norm=False, - force_binarization=force_binarization) - - self.example_input_array = torch.Tensor(batch, - channels, - height if height else 400, - width if width else 300) - - # set multiprocessing tensor sharing strategy - if 'file_system' in torch.multiprocessing.get_all_sharing_strategies(): - logger.debug('Setting multiprocessing tensor sharing strategy to file_system') - torch.multiprocessing.set_sharing_strategy('file_system') - - if not valid_regions: - valid_regions = None - if not valid_baselines: - valid_baselines = None - - if suppress_regions: - valid_regions = [] - merge_regions = None - if suppress_baselines: - valid_baselines = [] - merge_baselines = None - - train_set = BaselineSet(line_width=self.hparams.hyper_params['line_width'], - im_transforms=transforms, - augmentation=self.hparams.hyper_params['augment'], - valid_baselines=valid_baselines, - merge_baselines=merge_baselines, - valid_regions=valid_regions, - merge_regions=merge_regions) - - for page in training_data: - train_set.add(page) - - if evaluation_data: - val_set = BaselineSet(line_width=self.hparams.hyper_params['line_width'], - im_transforms=transforms, - augmentation=False, - valid_baselines=valid_baselines, - merge_baselines=merge_baselines, - valid_regions=valid_regions, - merge_regions=merge_regions) - - for page in evaluation_data: - val_set.add(page) - - train_set = Subset(train_set, range(len(train_set))) - val_set = Subset(val_set, range(len(val_set))) - else: - train_len = int(len(train_set)*partition) - val_len = len(train_set) - train_len - logger.info(f'No explicit validation data provided. 
Splitting off ' - f'{val_len} (of {len(train_set)}) samples to validation ' - 'set.') - train_set, val_set = random_split(train_set, (train_len, val_len)) - - if len(train_set) == 0: - raise ValueError('No valid training data provided. Please add some.') - - if len(val_set) == 0: - raise ValueError('No valid validation data provided. Please add some.') - - # overwrite class mapping in validation set - val_set.dataset.num_classes = train_set.dataset.num_classes - val_set.dataset.class_mapping = train_set.dataset.class_mapping - - self.train_set = train_set - self.val_set = val_set - - def forward(self, x): - return self.nn.nn(x) - - def training_step(self, batch, batch_idx): - input, target = batch['image'], batch['target'] - output, _ = self.nn.nn(input) - output = F.interpolate(output, size=(target.size(2), target.size(3))) - loss = self.nn.criterion(output, target) - self.log('train_loss', loss, on_step=True, on_epoch=False, prog_bar=False, logger=True) - return loss - - def validation_step(self, batch, batch_idx): - x, y = batch['image'], batch['target'] - pred, _ = self.nn.nn(x) - # scale target to output size - y = F.interpolate(y, size=(pred.size(2), pred.size(3)), mode='nearest').int() - # Get regions for IoU metrics - reg_idxs = sorted(self.nn.user_metadata['class_mapping']['regions'].values()) - pred_reg = [:, reg_idxs, ...] - y_reg = y[:, reg_idxs, ...] - self.val_region_px_accuracy.update(pred_reg, y_reg) - self.val_region_mean_accuracy.update(pred_reg, y_reg) - self.val_region_mean_iu.update(pred_reg, y_reg) - self.val_region_freq_iu.update(pred_reg, y_reg) - - def on_validation_epoch_end(self): - if not self.trainer.sanity_checking: - pixel_accuracy = self.val_region_px_accuracy.compute() - mean_accuracy = self.val_region_mean_accuracy.compute() - mean_iu = self.val_region_mean_iu.compute() - freq_iu = self.val_region_freq_iu.compute() - - if mean_iu > self.best_metric: - logger.debug(f'Updating best region metric from {self.best_metric} ({self.best_epoch}) to {mean_iu} ({self.current_epoch})') - self.best_epoch = self.current_epoch - self.best_metric = mean_iu - - logger.info(f'validation run: accuracy {pixel_accuracy} mean_acc {mean_accuracy} mean_iu {mean_iu} freq_iu {freq_iu}') - - self.log('val_region_accuracy', pixel_accuracy, on_step=False, on_epoch=True, prog_bar=True, logger=True) - self.log('val_region_mean_acc', mean_accuracy, on_step=False, on_epoch=True, prog_bar=True, logger=True) - self.log('val_region_mean_iu', mean_iu, on_step=False, on_epoch=True, prog_bar=True, logger=True) - self.log('val_region_freq_iu', freq_iu, on_step=False, on_epoch=True, prog_bar=True, logger=True) - self.log('val_metric', mean_iu, on_step=False, on_epoch=True, prog_bar=False, logger=True) - - # reset metrics even if sanity checking - self.val_region_px_accuracy.reset() - self.val_region_mean_accuracy.reset() - self.val_region_mean_iu.reset() - self.val_region_freq_iu.reset() - - def setup(self, stage: Optional[str] = None): - # finalize models in case of appending/loading - if stage in [None, 'fit']: - if not self.model: - self.spec = f'[{self.spec[1:-1]} O2l{self.train_set.dataset.num_classes}]' - logger.info(f'Creating model {self.spec} with {self.train_set.dataset.num_classes} outputs') - nn = vgsl.TorchVGSLModel(self.spec) - if self.bounding_regions is not None: - nn.user_metadata['bounding_regions'] = self.bounding_regions - nn.user_metadata['topline'] = self.topline - self.nn = nn - else: - if self.train_set.dataset.class_mapping['baselines'].keys() != 
self.nn.user_metadata['class_mapping']['baselines'].keys() or \ - self.train_set.dataset.class_mapping['regions'].keys() != self.nn.user_metadata['class_mapping']['regions'].keys(): - - bl_diff = set(self.train_set.dataset.class_mapping['baselines'].keys()).symmetric_difference( - set(self.nn.user_metadata['class_mapping']['baselines'].keys())) - regions_diff = set(self.train_set.dataset.class_mapping['regions'].keys()).symmetric_difference( - set(self.nn.user_metadata['class_mapping']['regions'].keys())) - - if self.resize == 'fail': - raise ValueError(f'Training data and model class mapping differ (bl: {bl_diff}, regions: {regions_diff}') - elif self.resize == 'union': - new_bls = self.train_set.dataset.class_mapping['baselines'].keys() - self.nn.user_metadata['class_mapping']['baselines'].keys() - new_regions = self.train_set.dataset.class_mapping['regions'].keys() - self.nn.user_metadata['class_mapping']['regions'].keys() - cls_idx = max(max(self.nn.user_metadata['class_mapping']['baselines'].values()) if self.nn.user_metadata['class_mapping']['baselines'] else -1, # noqa - max(self.nn.user_metadata['class_mapping']['regions'].values()) if self.nn.user_metadata['class_mapping']['regions'] else -1) # noqa - logger.info(f'Adding {len(new_bls) + len(new_regions)} missing types to network output layer.') - self.nn.resize_output(cls_idx + len(new_bls) + len(new_regions) + 1) - for c in new_bls: - cls_idx += 1 - self.nn.user_metadata['class_mapping']['baselines'][c] = cls_idx - for c in new_regions: - cls_idx += 1 - self.nn.user_metadata['class_mapping']['regions'][c] = cls_idx - elif self.resize == 'new': - logger.info('Fitting network exactly to training set.') - new_bls = self.train_set.dataset.class_mapping['baselines'].keys() - self.nn.user_metadata['class_mapping']['baselines'].keys() - new_regions = self.train_set.dataset.class_mapping['regions'].keys() - self.nn.user_metadata['class_mapping']['regions'].keys() - del_bls = self.nn.user_metadata['class_mapping']['baselines'].keys() - self.train_set.dataset.class_mapping['baselines'].keys() - del_regions = self.nn.user_metadata['class_mapping']['regions'].keys() - self.train_set.dataset.class_mapping['regions'].keys() - - logger.info(f'Adding {len(new_bls) + len(new_regions)} missing ' - f'types and removing {len(del_bls) + len(del_regions)} to network output layer ') - cls_idx = max(max(self.nn.user_metadata['class_mapping']['baselines'].values()) if self.nn.user_metadata['class_mapping']['baselines'] else -1, # noqa - max(self.nn.user_metadata['class_mapping']['regions'].values()) if self.nn.user_metadata['class_mapping']['regions'] else -1) # noqa - - del_indices = [self.nn.user_metadata['class_mapping']['baselines'][x] for x in del_bls] - del_indices.extend(self.nn.user_metadata['class_mapping']['regions'][x] for x in del_regions) - self.nn.resize_output(cls_idx + len(new_bls) + len(new_regions) - - len(del_bls) - len(del_regions) + 1, del_indices) - - # delete old baseline/region types - cls_idx = min(min(self.nn.user_metadata['class_mapping']['baselines'].values()) if self.nn.user_metadata['class_mapping']['baselines'] else np.inf, # noqa - min(self.nn.user_metadata['class_mapping']['regions'].values()) if self.nn.user_metadata['class_mapping']['regions'] else np.inf) # noqa - - bls = {} - for k, v in sorted(self.nn.user_metadata['class_mapping']['baselines'].items(), key=lambda item: item[1]): - if k not in del_bls: - bls[k] = cls_idx - cls_idx += 1 - - regions = {} - for k, v in 
sorted(self.nn.user_metadata['class_mapping']['regions'].items(), key=lambda item: item[1]): - if k not in del_regions: - regions[k] = cls_idx - cls_idx += 1 - - self.nn.user_metadata['class_mapping']['baselines'] = bls - self.nn.user_metadata['class_mapping']['regions'] = regions - - # add new baseline/region types - cls_idx -= 1 - for c in new_bls: - cls_idx += 1 - self.nn.user_metadata['class_mapping']['baselines'][c] = cls_idx - for c in new_regions: - cls_idx += 1 - self.nn.user_metadata['class_mapping']['regions'][c] = cls_idx - else: - raise ValueError(f'invalid resize parameter value {self.resize}') - # backfill train_set/val_set mapping if key-equal as the actual - # numbering in the train_set might be different - self.train_set.dataset.class_mapping = self.nn.user_metadata['class_mapping'] - self.val_set.dataset.class_mapping = self.nn.user_metadata['class_mapping'] - - # updates model's hyper params with user-defined ones - self.nn.hyper_params = self.hparams.hyper_params - - # change topline/baseline switch - loc = {None: 'centerline', - True: 'topline', - False: 'baseline'} - - if 'topline' not in self.nn.user_metadata: - logger.warning(f'Setting baseline location to {loc[self.topline]} from unset model.') - elif self.nn.user_metadata['topline'] != self.topline: - from_loc = loc[self.nn.user_metadata['topline']] - logger.warning(f'Changing baseline location from {from_loc} to {loc[self.topline]}.') - self.nn.user_metadata['topline'] = self.topline - - logger.info('Training line types:') - for k, v in self.train_set.dataset.class_mapping['baselines'].items(): - logger.info(f' {k}\t{v}\t{self.train_set.dataset.class_stats["baselines"][k]}') - logger.info('Training region types:') - for k, v in self.train_set.dataset.class_mapping['regions'].items(): - logger.info(f' {k}\t{v}\t{self.train_set.dataset.class_stats["regions"][k]}') - - if len(self.train_set) == 0: - raise ValueError('No valid training data was provided to the train command. 
Please add valid XML data.') - - # set model type metadata field and dump class_mapping - self.nn.model_type = 'segmentation' - self.nn.user_metadata['class_mapping'] = self.val_set.dataset.class_mapping - - # for model size/trainable parameter output - self.net = self.nn.nn - - torch.set_num_threads(max(self.num_workers, 1)) - - # set up validation metrics after output classes have been determined - # baseline metrics - # region metrics - num_regions = len(self.val_set.dataset.class_mapping['regions']) - self.val_region_px_accuracy = MultilabelAccuracy(average='micro', num_labels=num_regions) - self.val_region_mean_accuracy = MultilabelAccuracy(average='macro', num_labels=num_regions) - self.val_region_mean_iu = MultilabelJaccardIndex(average='macro', num_labels=num_regions) - self.val_region_freq_iu = MultilabelJaccardIndex(average='weighted', num_labels=num_regions) - - - def train_dataloader(self): - return DataLoader(self.train_set, - batch_size=1, - num_workers=self.num_workers, - shuffle=True, - pin_memory=True) - - def val_dataloader(self): - return DataLoader(self.val_set, - shuffle=False, - batch_size=1, - num_workers=self.num_workers, - pin_memory=True) - - def configure_callbacks(self): - callbacks = [] - if self.hparams.hyper_params['quit'] == 'early': - callbacks.append(EarlyStopping(monitor='val_mean_iu', - mode='max', - patience=self.hparams.hyper_params['lag'], - stopping_threshold=1.0)) - - return callbacks - - # configuration of optimizers and learning rate schedulers - # -------------------------------------------------------- - # - # All schedulers are created internally with a frequency of step to enable - # batch-wise learning rate warmup. In lr_scheduler_step() calls to the - # scheduler are then only performed at the end of the epoch. - def configure_optimizers(self): - return _configure_optimizer_and_lr_scheduler(self.hparams.hyper_params, - self.nn.nn.parameters(), - len_train_set=len(self.train_set), - loss_tracking_mode='max') - - def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure): - # update params - optimizer.step(closure=optimizer_closure) - - # linear warmup between 0 and the initial learning rate `lrate` in `warmup` - # steps. 
- if self.hparams.hyper_params['warmup'] and self.trainer.global_step < self.hparams.hyper_params['warmup']: - lr_scale = min(1.0, float(self.trainer.global_step + 1) / self.hparams.hyper_params['warmup']) - for pg in optimizer.param_groups: - pg["lr"] = lr_scale * self.hparams.hyper_params['lrate'] - - def lr_scheduler_step(self, scheduler, metric): - if not self.hparams.hyper_params['warmup'] or self.trainer.global_step >= self.hparams.hyper_params['warmup']: - # step OneCycleLR each batch if not in warmup phase - if isinstance(scheduler, lr_scheduler.OneCycleLR): - scheduler.step() - # step every other scheduler epoch-wise - elif self.trainer.is_last_batch: - if metric is None: - scheduler.step() - else: - scheduler.step(metric) - - -def _configure_optimizer_and_lr_scheduler(hparams, params, len_train_set=None, loss_tracking_mode='max'): - optimizer = hparams.get("optimizer") - lrate = hparams.get("lrate") - momentum = hparams.get("momentum") - weight_decay = hparams.get("weight_decay") - schedule = hparams.get("schedule") - gamma = hparams.get("gamma") - cos_t_max = hparams.get("cos_t_max") - cos_min_lr = hparams.get("cos_min_lr") - step_size = hparams.get("step_size") - rop_factor = hparams.get("rop_factor") - rop_patience = hparams.get("rop_patience") - epochs = hparams.get("epochs") - completed_epochs = hparams.get("completed_epochs") - - # XXX: Warmup is not configured here because it needs to be manually done in optimizer_step() - logger.debug(f'Constructing {optimizer} optimizer (lr: {lrate}, momentum: {momentum})') - if optimizer == 'Adam': - optim = torch.optim.Adam(params, lr=lrate, weight_decay=weight_decay) - else: - optim = getattr(torch.optim, optimizer)(params, - lr=lrate, - momentum=momentum, - weight_decay=weight_decay) - lr_sched = {} - if schedule == 'exponential': - lr_sched = {'scheduler': lr_scheduler.ExponentialLR(optim, gamma, last_epoch=completed_epochs-1), - 'interval': 'step'} - elif schedule == 'cosine': - lr_sched = {'scheduler': lr_scheduler.CosineAnnealingLR(optim, - cos_t_max, - cos_min_lr, - last_epoch=completed_epochs-1), - 'interval': 'step'} - elif schedule == 'step': - lr_sched = {'scheduler': lr_scheduler.StepLR(optim, step_size, gamma, last_epoch=completed_epochs-1), - 'interval': 'step'} - elif schedule == 'reduceonplateau': - lr_sched = {'scheduler': lr_scheduler.ReduceLROnPlateau(optim, - mode=loss_tracking_mode, - factor=rop_factor, - patience=rop_patience), - 'interval': 'step'} - elif schedule == '1cycle': - if epochs <= 0: - raise ValueError('1cycle learning rate scheduler selected but ' - 'number of epochs is less than 0 ' - f'({epochs}).') - last_epoch = completed_epochs*len_train_set if completed_epochs else -1 - lr_sched = {'scheduler': lr_scheduler.OneCycleLR(optim, - max_lr=lrate, - epochs=epochs, - steps_per_epoch=len_train_set, - last_epoch=last_epoch), - 'interval': 'step'} - elif schedule != 'constant': - raise ValueError(f'Unsupported learning rate scheduler {schedule}.') - - ret = {'optimizer': optim} - if lr_sched: - ret['lr_scheduler'] = lr_sched - - if schedule == 'reduceonplateau': - lr_sched['monitor'] = 'val_metric' - lr_sched['strict'] = False - lr_sched['reduce_on_plateau'] = True - - return ret diff --git a/kraken/lib/train/__init__.py b/kraken/lib/train/__init__.py new file mode 100644 index 00000000..84069daa --- /dev/null +++ b/kraken/lib/train/__init__.py @@ -0,0 +1,20 @@ +# +# Copyright 2023 Benjamin Kiessling +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except 
in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +""" +Tools for segmentation and recognition training +""" +from .trainer import KrakenTrainer +from .recognition import RecognitionModel +from .segmentation import SegmentationModel diff --git a/kraken/lib/train/recognition.py b/kraken/lib/train/recognition.py new file mode 100644 index 00000000..be99a8aa --- /dev/null +++ b/kraken/lib/train/recognition.py @@ -0,0 +1,552 @@ +# +# Copyright 2015 Benjamin Kiessling +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +""" +Text recognition model +""" +import re +import torch +import logging +import warnings +import numpy as np +import lightning as L +from typing import (TYPE_CHECKING, Any, Dict, Literal, Optional, + Sequence, Union) +from functools import partial + +from lightning.pytorch.callbacks import EarlyStopping + +from torch.optim import lr_scheduler +from torch.utils.data import DataLoader, Subset, random_split +from torchmetrics.text import CharErrorRate, WordErrorRate + +from kraken.containers import Segmentation +from kraken.lib import default_specs, models, vgsl +from kraken.lib.codec import PytorchCodec +from kraken.lib.dataset import (ArrowIPCRecognitionDataset, GroundTruthDataset, + ImageInputTransforms, PolygonGTDataset, + collate_sequences) +from kraken.lib.exceptions import KrakenEncodeException, KrakenInputException +from kraken.lib.util import make_printable, parse_gt_path +from kraken.lib.xml import XMLPage + +from .utils import _configure_optimizer_and_lr_scheduler + +if TYPE_CHECKING: + from os import PathLike + +logger = logging.getLogger(__name__) + + +class RecognitionModel(L.LightningModule): + """ + A LightningModule encapsulating the training setup for a text + recognition model. + + Setup parameters (load, training_data, evaluation_data, ....) are + named, model hyperparameters (everything in + `kraken.lib.default_specs.RECOGNITION_HYPER_PARAMS`) are in in the + `hyper_params` argument. + + Args: + hyper_params (dict): Hyperparameter dictionary containing all fields + from + kraken.lib.default_specs.RECOGNITION_HYPER_PARAMS + **kwargs: Setup parameters, i.e. CLI parameters of the train() command. 
+ """ + def __init__(self, + hyper_params: Dict[str, Any] = None, + output: str = 'model', + spec: str = default_specs.RECOGNITION_SPEC, + append: Optional[int] = None, + model: Optional[Union['PathLike', str]] = None, + reorder: Union[bool, str] = True, + training_data: Union[Sequence[Union['PathLike', str]], Sequence[Dict[str, Any]]] = None, + evaluation_data: Optional[Union[Sequence[Union['PathLike', str]], Sequence[Dict[str, Any]]]] = None, + partition: Optional[float] = 0.9, + binary_dataset_split: bool = False, + num_workers: int = 1, + load_hyper_parameters: bool = False, + force_binarization: bool = False, + format_type: Literal['path', 'alto', 'page', 'xml', 'binary'] = 'path', + codec: Optional[Dict] = None, + resize: Literal['fail', 'both', 'new', 'add', 'union'] = 'fail', + legacy_polygons: bool = False): + super().__init__() + self.legacy_polygons = legacy_polygons + hyper_params_ = default_specs.RECOGNITION_HYPER_PARAMS.copy() + if model: + logger.info(f'Loading existing model from {model} ') + self.nn = vgsl.TorchVGSLModel.load_model(model) + + if self.nn.model_type not in [None, 'recognition']: + raise ValueError(f'Model {model} is of type {self.nn.model_type} while `recognition` is expected.') + + if load_hyper_parameters: + hp = self.nn.hyper_params + else: + hp = {} + hyper_params_.update(hp) + else: + self.nn = None + + if hyper_params: + hyper_params_.update(hyper_params) + self.hyper_params = hyper_params_ + self.save_hyperparameters() + + self.reorder = reorder + self.append = append + self.model = model + self.num_workers = num_workers + if resize == "add": + resize = "union" + warnings.warn("'add' value for resize has been deprecated. Use 'union' instead.", DeprecationWarning) + elif resize == "both": + resize = "new" + warnings.warn("'both' value for resize has been deprecated. Use 'new' instead.", DeprecationWarning) + + self.resize = resize + self.format_type = format_type + self.output = output + + self.best_epoch = -1 + self.best_metric = 0.0 + self.best_model = None + + DatasetClass = GroundTruthDataset + valid_norm = True + if format_type in ['xml', 'page', 'alto']: + logger.info(f'Parsing {len(training_data)} XML files for training data') + training_data = [{'page': XMLPage(file, format_type).to_container()} for file in training_data] + if evaluation_data: + logger.info(f'Parsing {len(evaluation_data)} XML files for validation data') + evaluation_data = [{'page': XMLPage(file, format_type).to_container()} for file in evaluation_data] + if binary_dataset_split: + logger.warning('Internal binary dataset splits are enabled but using non-binary dataset files. Will be ignored.') + binary_dataset_split = False + DatasetClass = partial(PolygonGTDataset, legacy_polygons=legacy_polygons) + valid_norm = False + elif format_type == 'binary': + DatasetClass = ArrowIPCRecognitionDataset + valid_norm = False + logger.info(f'Got {len(training_data)} binary dataset files for training data') + training_data = [{'file': file} for file in training_data] + if evaluation_data: + logger.info(f'Got {len(evaluation_data)} binary dataset files for validation data') + evaluation_data = [{'file': file} for file in evaluation_data] + elif format_type == 'path': + if force_binarization: + logger.warning('Forced binarization enabled in `path` mode. Will be ignored.') + force_binarization = False + if binary_dataset_split: + logger.warning('Internal binary dataset splits are enabled but using non-binary dataset files. 
Will be ignored.') + binary_dataset_split = False + logger.info(f'Got {len(training_data)} line strip images for training data') + training_data = [{'line': parse_gt_path(im)} for im in training_data] + if evaluation_data: + logger.info(f'Got {len(evaluation_data)} line strip images for validation data') + evaluation_data = [{'line': parse_gt_path(im)} for im in evaluation_data] + valid_norm = True + # format_type is None. Determine training type from container class types + elif not format_type: + if training_data[0].type == 'baselines': + DatasetClass = partial(PolygonGTDataset, legacy_polygons=legacy_polygons) + valid_norm = False + else: + if force_binarization: + logger.warning('Forced binarization enabled with box lines. Will be ignored.') + force_binarization = False + if binary_dataset_split: + logger.warning('Internal binary dataset splits are enabled but using non-binary dataset files. Will be ignored.') + binary_dataset_split = False + samples = [] + for sample in training_data: + if isinstance(sample, Segmentation): + samples.append({'page': sample}) + else: + samples.append({'line': sample}) + training_data = samples + if evaluation_data: + samples = [] + for sample in evaluation_data: + if isinstance(sample, Segmentation): + samples.append({'page': sample}) + else: + samples.append({'line': sample}) + evaluation_data = samples + else: + raise ValueError(f'format_type {format_type} not in [alto, page, xml, path, binary].') + + spec = spec.strip() + if spec[0] != '[' or spec[-1] != ']': + raise ValueError(f'VGSL spec {spec} not bracketed') + self.spec = spec + # preparse input sizes from vgsl string to seed ground truth data set + # sizes and dimension ordering. + if not self.nn: + blocks = spec[1:-1].split(' ') + m = re.match(r'(\d+),(\d+),(\d+),(\d+)', blocks[0]) + if not m: + raise ValueError(f'Invalid input spec {blocks[0]}') + batch, height, width, channels = [int(x) for x in m.groups()] + else: + batch, channels, height, width = self.nn.input + + self.transforms = ImageInputTransforms(batch, + height, + width, + channels, + (self.hparams.hyper_params['pad'], 0), + valid_norm, + force_binarization) + + self.example_input_array = torch.Tensor(batch, + channels, + height if height else 32, + width if width else 400) + + if 'file_system' in torch.multiprocessing.get_all_sharing_strategies(): + logger.debug('Setting multiprocessing tensor sharing strategy to file_system') + torch.multiprocessing.set_sharing_strategy('file_system') + + val_set = None + if evaluation_data: + train_set = self._build_dataset(DatasetClass, training_data) + self.train_set = Subset(train_set, range(len(train_set))) + val_set = self._build_dataset(DatasetClass, evaluation_data) + self.val_set = Subset(val_set, range(len(val_set))) + elif binary_dataset_split: + train_set = self._build_dataset(DatasetClass, training_data, split_filter='train') + self.train_set = Subset(train_set, range(len(train_set))) + val_set = self._build_dataset(DatasetClass, training_data, split_filter='validation') + self.val_set = Subset(val_set, range(len(val_set))) + logger.info(f'Found {len(self.train_set)} (train) / {len(self.val_set)} (val) samples in pre-encoded dataset') + else: + train_set = self._build_dataset(DatasetClass, training_data) + train_len = int(len(train_set)*partition) + val_len = len(train_set) - train_len + logger.info(f'No explicit validation data provided. Splitting off ' + f'{val_len} (of {len(train_set)}) samples to validation ' + 'set. 
(Will disable alphabet mismatch detection.)') + self.train_set, self.val_set = random_split(train_set, (train_len, val_len)) + + if len(self.train_set) == 0 or len(self.val_set) == 0: + raise ValueError('No valid training data was provided to the train ' + 'command. Please add valid XML, line, or binary data.') + + if format_type == 'binary': + legacy_train_status = self.train_set.dataset.legacy_polygons_status + if self.val_set.dataset.legacy_polygons_status != legacy_train_status: + logger.warning('Train and validation set have different legacy ' + f'polygon status: {legacy_train_status} and ' + f'{self.val_set.dataset.legacy_polygons_status}. Train set ' + 'status prevails.') + if legacy_train_status == "mixed": + logger.warning('Mixed legacy polygon status in training dataset. Consider recompilation.') + legacy_train_status = False + if legacy_polygons != legacy_train_status: + logger.warning(f'Setting dataset legacy polygon status to {legacy_train_status} based on training set.') + self.legacy_polygons = legacy_train_status + + logger.info(f'Training set {len(self.train_set)} lines, validation set ' + f'{len(self.val_set)} lines, alphabet {len(train_set.alphabet)} ' + 'symbols') + alpha_diff_only_train = set(self.train_set.dataset.alphabet).difference(set(self.val_set.dataset.alphabet)) + alpha_diff_only_val = set(self.val_set.dataset.alphabet).difference(set(self.train_set.dataset.alphabet)) + if alpha_diff_only_train: + logger.warning(f'alphabet mismatch: chars in training set only: ' + f'{alpha_diff_only_train} (not included in accuracy test ' + 'during training)') + if alpha_diff_only_val: + logger.warning(f'alphabet mismatch: chars in validation set only: {alpha_diff_only_val} (not trained)') + logger.info('grapheme\tcount') + for k, v in sorted(train_set.alphabet.items(), key=lambda x: x[1], reverse=True): + char = make_printable(k) + if char == k: + char = '\t' + char + logger.info(f'{char}\t{v}') + + if codec: + logger.info('Instantiating codec') + self.codec = PytorchCodec(codec) + for k, v in self.codec.c2l.items(): + char = make_printable(k) + if char == k: + char = '\t' + char + logger.info(f'{char}\t{v}') + else: + self.codec = None + + logger.info('Encoding training set') + + self.val_cer = CharErrorRate() + self.val_wer = WordErrorRate() + + def _build_dataset(self, + DatasetClass, + training_data, + **kwargs): + dataset = DatasetClass(normalization=self.hparams.hyper_params['normalization'], + whitespace_normalization=self.hparams.hyper_params['normalize_whitespace'], + reorder=self.reorder, + im_transforms=self.transforms, + augmentation=self.hparams.hyper_params['augment'], + **kwargs) + + for sample in training_data: + try: + dataset.add(**sample) + except KrakenInputException as e: + logger.warning(str(e)) + if self.format_type == 'binary' and (self.hparams.hyper_params['normalization'] or + self.hparams.hyper_params['normalize_whitespace'] or + self.reorder): + logger.debug('Text transformations modifying alphabet selected. 
Rebuilding alphabet') + dataset.rebuild_alphabet() + + return dataset + + def forward(self, x, seq_lens=None): + return self.net(x, seq_lens) + + def training_step(self, batch, batch_idx): + input, target = batch['image'], batch['target'] + # sequence batch + if 'seq_lens' in batch: + seq_lens, label_lens = batch['seq_lens'], batch['target_lens'] + target = (target, label_lens) + o = self.net(input, seq_lens) + else: + o = self.net(input) + + seq_lens = o[1] + output = o[0] + target_lens = target[1] + target = target[0] + # height should be 1 by now + if output.size(2) != 1: + raise KrakenInputException('Expected dimension 3 to be 1, actual {}'.format(output.size(2))) + output = output.squeeze(2) + # NCW -> WNC + loss = self.nn.criterion(output.permute(2, 0, 1), # type: ignore + target, + seq_lens, + target_lens) + self.log('train_loss', loss, on_step=True, on_epoch=False, prog_bar=False, logger=True) + return loss + + def validation_step(self, batch, batch_idx): + pred = self.rec_nn.predict_string(batch['image'], batch['seq_lens']) + idx = 0 + decoded_targets = [] + for offset in batch['target_lens']: + decoded_targets.append(''.join([x[0] for x in self.val_codec.decode([(x, 0, 0, 0) for x in batch['target'][idx:idx+offset]])])) + idx += offset + self.val_cer.update(pred, decoded_targets) + self.val_wer.update(pred, decoded_targets) + + if self.logger and self.trainer.state.stage != 'sanity_check' and self.hparams.hyper_params["batch_size"] * batch_idx < 16: + for i in range(self.hparams.hyper_params["batch_size"]): + count = self.hparams.hyper_params["batch_size"] * batch_idx + i + if count < 16: + self.logger.experiment.add_image(f'Validation #{count}, target: {decoded_targets[i]}', + batch['image'][i], + self.global_step, + dataformats="CHW") + self.logger.experiment.add_text(f'Validation #{count}, target: {decoded_targets[i]}', + pred[i], + self.global_step) + + def on_validation_epoch_end(self): + if not self.trainer.sanity_checking: + accuracy = 1.0 - self.val_cer.compute() + word_accuracy = 1.0 - self.val_wer.compute() + + if accuracy > self.best_metric: + logger.debug(f'Updating best metric from {self.best_metric} ({self.best_epoch}) to {accuracy} ({self.current_epoch})') + self.best_epoch = self.current_epoch + self.best_metric = accuracy + logger.info(f'validation run: total chars {self.val_cer.total} errors {self.val_cer.errors} accuracy {accuracy}') + self.log('val_accuracy', accuracy, on_step=False, on_epoch=True, prog_bar=True, logger=True) + self.log('val_word_accuracy', word_accuracy, on_step=False, on_epoch=True, prog_bar=True, logger=True) + self.log('val_metric', accuracy, on_step=False, on_epoch=True, prog_bar=False, logger=True) + # reset metrics even if not sanity checking + self.val_cer.reset() + self.val_wer.reset() + + def setup(self, stage: Optional[str] = None): + # finalize models in case of appending/loading + if stage in [None, 'fit']: + + # Log a few sample images before the datasets are encoded. 
+ # This is only possible for Arrow datasets, because the + # other dataset types can only be accessed after encoding + if self.logger and isinstance(self.train_set.dataset, ArrowIPCRecognitionDataset): + for i in range(min(len(self.train_set), 16)): + idx = np.random.randint(len(self.train_set)) + sample = self.train_set[idx] + self.logger.experiment.add_image(f'train_set sample #{i}: {sample["target"]}', sample['image']) + + if self.append: + self.train_set.dataset.encode(self.codec) + # now we can create a new model + self.spec = '[{} O1c{}]'.format(self.spec[1:-1], self.train_set.dataset.codec.max_label + 1) + logger.info(f'Appending {self.spec} to existing model {self.nn.spec} after {self.append}') + self.nn.append(self.append, self.spec) + self.nn.add_codec(self.train_set.dataset.codec) + logger.info(f'Assembled model spec: {self.nn.spec}') + elif self.model: + self.spec = self.nn.spec + + # prefer explicitly given codec over network codec if mode is 'new' + codec = self.codec if (self.codec and self.resize == 'new') else self.nn.codec + + codec.strict = True + + try: + self.train_set.dataset.encode(codec) + except KrakenEncodeException: + alpha_diff = set(self.train_set.dataset.alphabet).difference( + set(codec.c2l.keys()) + ) + if self.resize == 'fail': + raise KrakenInputException(f'Training data and model codec alphabets mismatch: {alpha_diff}') + elif self.resize == 'union': + logger.info(f'Resizing codec to include ' + f'{len(alpha_diff)} new code points') + # Construct two codecs: + # 1. training codec containing only the vocabulary in the training dataset + # 2. validation codec = training codec + validation set vocabulary + # This keep the codec in the model from being 'polluted' by non-trained characters. + train_codec = codec.add_labels(alpha_diff) + self.nn.add_codec(train_codec) + logger.info(f'Resizing last layer in network to {train_codec.max_label+1} outputs') + self.nn.resize_output(train_codec.max_label + 1) + self.train_set.dataset.encode(train_codec) + elif self.resize == 'new': + logger.info(f'Resizing network or given codec to ' + f'{len(self.train_set.dataset.alphabet)} ' + f'code sequences') + # same codec procedure as above, just with merging. + self.train_set.dataset.encode(None) + train_codec, del_labels = codec.merge(self.train_set.dataset.codec) + # Switch codec. 
+ self.nn.add_codec(train_codec) + logger.info(f'Deleting {len(del_labels)} output classes from network ' + f'({len(codec)-len(del_labels)} retained)') + self.nn.resize_output(train_codec.max_label + 1, del_labels) + self.train_set.dataset.encode(train_codec) + else: + raise ValueError(f'invalid resize parameter value {self.resize}') + self.nn.codec.strict = False + self.spec = self.nn.spec + else: + self.train_set.dataset.encode(self.codec) + logger.info(f'Creating new model {self.spec} with {self.train_set.dataset.codec.max_label+1} outputs') + self.spec = '[{} O1c{}]'.format(self.spec[1:-1], self.train_set.dataset.codec.max_label + 1) + self.nn = vgsl.TorchVGSLModel(self.spec) + self.nn.use_legacy_polygons = self.legacy_polygons + # initialize weights + self.nn.init_weights() + self.nn.add_codec(self.train_set.dataset.codec) + + val_diff = set(self.val_set.dataset.alphabet).difference( + set(self.train_set.dataset.codec.c2l.keys()) + ) + logger.info(f'Adding {len(val_diff)} dummy labels to validation set codec.') + + val_codec = self.nn.codec.add_labels(val_diff) + self.val_set.dataset.encode(val_codec) + self.val_codec = val_codec + + if self.nn.one_channel_mode and self.train_set.dataset.im_mode != self.nn.one_channel_mode: + logger.warning(f'Neural network has been trained on mode {self.nn.one_channel_mode} images, ' + f'training set contains mode {self.train_set.dataset.im_mode} data. Consider setting `force_binarization`') + + if self.format_type != 'path' and self.nn.seg_type == 'bbox': + logger.warning('Neural network has been trained on bounding box image information but training set is polygonal.') + + self.nn.hyper_params = self.hparams.hyper_params + self.nn.model_type = 'recognition' + + if not self.nn.seg_type: + logger.info(f'Setting seg_type to {self.train_set.dataset.seg_type}.') + self.nn.seg_type = self.train_set.dataset.seg_type + + self.rec_nn = models.TorchSeqRecognizer(self.nn, train=None, device=None) + self.net = self.nn.nn + + torch.set_num_threads(max(self.num_workers, 1)) + + def train_dataloader(self): + return DataLoader(self.train_set, + batch_size=self.hparams.hyper_params['batch_size'], + num_workers=self.num_workers, + pin_memory=True, + shuffle=True, + collate_fn=collate_sequences) + + def val_dataloader(self): + return DataLoader(self.val_set, + shuffle=False, + batch_size=self.hparams.hyper_params['batch_size'], + num_workers=self.num_workers, + pin_memory=True, + collate_fn=collate_sequences) + + def configure_callbacks(self): + callbacks = [] + if self.hparams.hyper_params['quit'] == 'early': + callbacks.append(EarlyStopping(monitor='val_accuracy', + mode='max', + patience=self.hparams.hyper_params['lag'], + stopping_threshold=1.0)) + + return callbacks + + # configuration of optimizers and learning rate schedulers + # -------------------------------------------------------- + # + # All schedulers are created internally with a frequency of step to enable + # batch-wise learning rate warmup. In lr_scheduler_step() calls to the + # scheduler are then only performed at the end of the epoch. + def configure_optimizers(self): + return _configure_optimizer_and_lr_scheduler(self.hparams.hyper_params, + self.nn.nn.parameters(), + len_train_set=len(self.train_set), + loss_tracking_mode='max') + + def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure): + # update params + optimizer.step(closure=optimizer_closure) + + # linear warmup between 0 and the initial learning rate `lrate` in `warmup` + # steps. 
+ if self.hparams.hyper_params['warmup'] and self.trainer.global_step < self.hparams.hyper_params['warmup']: + lr_scale = min(1.0, float(self.trainer.global_step + 1) / self.hparams.hyper_params['warmup']) + for pg in optimizer.param_groups: + pg["lr"] = lr_scale * self.hparams.hyper_params['lrate'] + + def lr_scheduler_step(self, scheduler, metric): + if not self.hparams.hyper_params['warmup'] or self.trainer.global_step >= self.hparams.hyper_params['warmup']: + # step OneCycleLR each batch if not in warmup phase + if isinstance(scheduler, lr_scheduler.OneCycleLR): + scheduler.step() + # step every other scheduler epoch-wise + elif self.trainer.is_last_batch: + if metric is None: + scheduler.step() + else: + scheduler.step(metric) diff --git a/kraken/lib/train/segmentation.py b/kraken/lib/train/segmentation.py new file mode 100644 index 00000000..cfaa3e2f --- /dev/null +++ b/kraken/lib/train/segmentation.py @@ -0,0 +1,480 @@ +# +# Copyright 2015 Benjamin Kiessling +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +""" +Training loop interception helpers +""" +import re +import logging +import warnings + +import numpy as np +import lightning as L +import torch +import torch.nn.functional as F + +from typing import (TYPE_CHECKING, Callable, Dict, Literal, Optional, Sequence, + Union) + +from torch.optim import lr_scheduler +from torch.utils.data import DataLoader, Subset, random_split +from lightning.pytorch.callbacks import EarlyStopping +from torchmetrics.classification import (MultilabelAccuracy, + MultilabelJaccardIndex) + +from kraken.containers import Segmentation +from kraken.lib import default_specs, vgsl +from kraken.lib.dataset import BaselineSet, ImageInputTransforms + +from kraken.lib.xml import XMLPage +from kraken.lib.models import validate_hyper_parameters +from kraken.lib.segmentation import vectorize_lines + +from .utils import _configure_optimizer_and_lr_scheduler + +if TYPE_CHECKING: + from os import PathLike + +logger = logging.getLogger(__name__) + + +class SegmentationModel(L.LightningModule): + """ + A LightningModule encapsulating the training setup for a page + segmentation model. + + Setup parameters (load, training_data, evaluation_data, ....) are + named, model hyperparameters (everything in + `kraken.lib.default_specs.SEGMENTATION_HYPER_PARAMS`) are in in the + `hyper_params` argument. + + Args: + hyper_params (dict): Hyperparameter dictionary containing all fields + from + kraken.lib.default_specs.SEGMENTATION_HYPER_PARAMS + **kwargs: Setup parameters, i.e. CLI parameters of the segtrain() command. 
+ """ + def __init__(self, + hyper_params: Dict = None, + load_hyper_parameters: bool = False, + progress_callback: Callable[[str, int], Callable[[None], None]] = lambda string, length: lambda: None, + message: Callable[[str], None] = lambda *args, **kwargs: None, + output: str = 'model', + spec: str = default_specs.SEGMENTATION_SPEC, + model: Optional[Union['PathLike', str]] = None, + training_data: Union[Sequence[Union['PathLike', str]], Sequence[Segmentation]] = None, + evaluation_data: Optional[Union[Sequence[Union['PathLike', str]], Sequence[Segmentation]]] = None, + partition: Optional[float] = 0.9, + num_workers: int = 1, + force_binarization: bool = False, + format_type: Literal['path', 'alto', 'page', 'xml', None] = 'path', + suppress_regions: bool = False, + suppress_baselines: bool = False, + valid_regions: Optional[Sequence[str]] = None, + valid_baselines: Optional[Sequence[str]] = None, + merge_regions: Optional[Dict[str, str]] = None, + merge_baselines: Optional[Dict[str, str]] = None, + bounding_regions: Optional[Sequence[str]] = None, + resize: Literal['fail', 'both', 'new', 'add', 'union'] = 'fail', + topline: Union[bool, None] = False): + super().__init__() + + self.best_epoch = -1 + self.best_metric = 0.0 + self.best_model = None + + self.model = model + self.num_workers = num_workers + + if resize == "add": + resize = "union" + warnings.warn("'add' value for resize has been deprecated. Use 'union' instead.", DeprecationWarning) + elif resize == "both": + resize = "new" + warnings.warn("'both' value for resize has been deprecated. Use 'new' instead.", DeprecationWarning) + self.resize = resize + + self.output = output + self.bounding_regions = bounding_regions + self.topline = topline + + hyper_params_ = default_specs.SEGMENTATION_HYPER_PARAMS.copy() + + if model: + logger.info(f'Loading existing model from {model}') + self.nn = vgsl.TorchVGSLModel.load_model(model) + + if self.nn.model_type not in [None, 'segmentation']: + raise ValueError(f'Model {model} is of type {self.nn.model_type} while `segmentation` is expected.') + + if load_hyper_parameters: + hp = self.nn.hyper_params + else: + hp = {} + hyper_params_.update(hp) + batch, channels, height, width = self.nn.input + else: + self.nn = None + + spec = spec.strip() + if spec[0] != '[' or spec[-1] != ']': + raise ValueError(f'VGSL spec "{spec}" not bracketed') + self.spec = spec + blocks = spec[1:-1].split(' ') + m = re.match(r'(\d+),(\d+),(\d+),(\d+)', blocks[0]) + if not m: + raise ValueError(f'Invalid input spec {blocks[0]}') + batch, height, width, channels = [int(x) for x in m.groups()] + + if hyper_params: + hyper_params_.update(hyper_params) + + validate_hyper_parameters(hyper_params_) + self.hyper_params = hyper_params_ + self.save_hyperparameters() + + if format_type in ['xml', 'page', 'alto']: + logger.info(f'Parsing {len(training_data)} XML files for training data') + training_data = [XMLPage(file, format_type).to_container() for file in training_data] + if evaluation_data: + logger.info(f'Parsing {len(evaluation_data)} XML files for validation data') + evaluation_data = [XMLPage(file, format_type).to_container() for file in evaluation_data] + elif not format_type: + pass + else: + raise ValueError(f'format_type {format_type} not in [alto, page, xml, None].') + + if not training_data: + raise ValueError('No training data provided. 
Please add some.') + + transforms = ImageInputTransforms(batch, + height, + width, + channels, + self.hparams.hyper_params['padding'], + valid_norm=False, + force_binarization=force_binarization) + + self.example_input_array = torch.Tensor(batch, + channels, + height if height else 400, + width if width else 300) + + # set multiprocessing tensor sharing strategy + if 'file_system' in torch.multiprocessing.get_all_sharing_strategies(): + logger.debug('Setting multiprocessing tensor sharing strategy to file_system') + torch.multiprocessing.set_sharing_strategy('file_system') + + if not valid_regions: + valid_regions = None + if not valid_baselines: + valid_baselines = None + + if suppress_regions: + valid_regions = [] + merge_regions = None + if suppress_baselines: + valid_baselines = [] + merge_baselines = None + + train_set = BaselineSet(line_width=self.hparams.hyper_params['line_width'], + im_transforms=transforms, + augmentation=self.hparams.hyper_params['augment'], + valid_baselines=valid_baselines, + merge_baselines=merge_baselines, + valid_regions=valid_regions, + merge_regions=merge_regions) + + for page in training_data: + train_set.add(page) + + if evaluation_data: + val_set = BaselineSet(line_width=self.hparams.hyper_params['line_width'], + im_transforms=transforms, + augmentation=False, + valid_baselines=valid_baselines, + merge_baselines=merge_baselines, + valid_regions=valid_regions, + merge_regions=merge_regions) + + for page in evaluation_data: + val_set.add(page) + + train_set = Subset(train_set, range(len(train_set))) + val_set = Subset(val_set, range(len(val_set))) + else: + train_len = int(len(train_set)*partition) + val_len = len(train_set) - train_len + logger.info(f'No explicit validation data provided. Splitting off ' + f'{val_len} (of {len(train_set)}) samples to validation ' + 'set.') + train_set, val_set = random_split(train_set, (train_len, val_len)) + + if len(train_set) == 0: + raise ValueError('No valid training data provided. Please add some.') + + if len(val_set) == 0: + raise ValueError('No valid validation data provided. Please add some.') + + # overwrite class mapping in validation set + val_set.dataset.num_classes = train_set.dataset.num_classes + val_set.dataset.class_mapping = train_set.dataset.class_mapping + + self.train_set = train_set + self.val_set = val_set + + def forward(self, x): + return self.nn.nn(x) + + def training_step(self, batch, batch_idx): + input, target = batch['image'], batch['target'] + output, _ = self.nn.nn(input) + output = F.interpolate(output, size=(target.size(2), target.size(3))) + loss = self.nn.criterion(output, target) + self.log('train_loss', loss, on_step=True, on_epoch=False, prog_bar=False, logger=True) + return loss + + def validation_step(self, batch, batch_idx): + x, y = batch['image'], batch['target'] + pred, _ = self.nn.nn(x) + # scale target to output size + y = F.interpolate(y, size=(pred.size(2), pred.size(3)), mode='nearest').int() + # Get regions for IoU metrics + reg_idxs = sorted(self.nn.user_metadata['class_mapping']['regions'].values()) + pred_reg = pred[:, reg_idxs, ...] + y_reg = y[:, reg_idxs, ...] 
+ self.val_region_px_accuracy.update(pred_reg, y_reg) + self.val_region_mean_accuracy.update(pred_reg, y_reg) + self.val_region_mean_iu.update(pred_reg, y_reg) + self.val_region_freq_iu.update(pred_reg, y_reg) + # vectorize lines + st_sep = self.nn.user_metadata['class_mapping']['aux']['_start_separator'] + end_sep = self.nn.user_metadata['class_mapping']['aux']['_end_separator'] + line_idxs = sorted(self.nn.user_metadata['class_mapping']['lines'].values()) + for line_idx in line_idxs: + pred_bl = vectorize_lines(pred[:, [st_sep, end_sep, line_idx], ...], text_direction='horizontal') + + def on_validation_epoch_end(self): + if not self.trainer.sanity_checking: + pixel_accuracy = self.val_region_px_accuracy.compute() + mean_accuracy = self.val_region_mean_accuracy.compute() + mean_iu = self.val_region_mean_iu.compute() + freq_iu = self.val_region_freq_iu.compute() + + if mean_iu > self.best_metric: + logger.debug(f'Updating best region metric from {self.best_metric} ({self.best_epoch}) to {mean_iu} ({self.current_epoch})') + self.best_epoch = self.current_epoch + self.best_metric = mean_iu + + logger.info(f'validation run: accuracy {pixel_accuracy} mean_acc {mean_accuracy} mean_iu {mean_iu} freq_iu {freq_iu}') + + self.log('val_region_accuracy', pixel_accuracy, on_step=False, on_epoch=True, prog_bar=True, logger=True) + self.log('val_region_mean_acc', mean_accuracy, on_step=False, on_epoch=True, prog_bar=True, logger=True) + self.log('val_region_mean_iu', mean_iu, on_step=False, on_epoch=True, prog_bar=True, logger=True) + self.log('val_region_freq_iu', freq_iu, on_step=False, on_epoch=True, prog_bar=True, logger=True) + self.log('val_metric', mean_iu, on_step=False, on_epoch=True, prog_bar=False, logger=True) + + # reset metrics even if sanity checking + self.val_region_px_accuracy.reset() + self.val_region_mean_accuracy.reset() + self.val_region_mean_iu.reset() + self.val_region_freq_iu.reset() + + def setup(self, stage: Optional[str] = None): + # finalize models in case of appending/loading + if stage in [None, 'fit']: + if not self.model: + self.spec = f'[{self.spec[1:-1]} O2l{self.train_set.dataset.num_classes}]' + logger.info(f'Creating model {self.spec} with {self.train_set.dataset.num_classes} outputs') + nn = vgsl.TorchVGSLModel(self.spec) + if self.bounding_regions is not None: + nn.user_metadata['bounding_regions'] = self.bounding_regions + nn.user_metadata['topline'] = self.topline + self.nn = nn + else: + if self.train_set.dataset.class_mapping['baselines'].keys() != self.nn.user_metadata['class_mapping']['baselines'].keys() or \ + self.train_set.dataset.class_mapping['regions'].keys() != self.nn.user_metadata['class_mapping']['regions'].keys(): + + bl_diff = set(self.train_set.dataset.class_mapping['baselines'].keys()).symmetric_difference( + set(self.nn.user_metadata['class_mapping']['baselines'].keys())) + regions_diff = set(self.train_set.dataset.class_mapping['regions'].keys()).symmetric_difference( + set(self.nn.user_metadata['class_mapping']['regions'].keys())) + + if self.resize == 'fail': + raise ValueError(f'Training data and model class mapping differ (bl: {bl_diff}, regions: {regions_diff}') + elif self.resize == 'union': + new_bls = self.train_set.dataset.class_mapping['baselines'].keys() - self.nn.user_metadata['class_mapping']['baselines'].keys() + new_regions = self.train_set.dataset.class_mapping['regions'].keys() - self.nn.user_metadata['class_mapping']['regions'].keys() + cls_idx = max(max(self.nn.user_metadata['class_mapping']['baselines'].values()) if 
self.nn.user_metadata['class_mapping']['baselines'] else -1, # noqa + max(self.nn.user_metadata['class_mapping']['regions'].values()) if self.nn.user_metadata['class_mapping']['regions'] else -1) # noqa + logger.info(f'Adding {len(new_bls) + len(new_regions)} missing types to network output layer.') + self.nn.resize_output(cls_idx + len(new_bls) + len(new_regions) + 1) + for c in new_bls: + cls_idx += 1 + self.nn.user_metadata['class_mapping']['baselines'][c] = cls_idx + for c in new_regions: + cls_idx += 1 + self.nn.user_metadata['class_mapping']['regions'][c] = cls_idx + elif self.resize == 'new': + logger.info('Fitting network exactly to training set.') + new_bls = self.train_set.dataset.class_mapping['baselines'].keys() - self.nn.user_metadata['class_mapping']['baselines'].keys() + new_regions = self.train_set.dataset.class_mapping['regions'].keys() - self.nn.user_metadata['class_mapping']['regions'].keys() + del_bls = self.nn.user_metadata['class_mapping']['baselines'].keys() - self.train_set.dataset.class_mapping['baselines'].keys() + del_regions = self.nn.user_metadata['class_mapping']['regions'].keys() - self.train_set.dataset.class_mapping['regions'].keys() + + logger.info(f'Adding {len(new_bls) + len(new_regions)} missing ' + f'types and removing {len(del_bls) + len(del_regions)} to network output layer ') + cls_idx = max(max(self.nn.user_metadata['class_mapping']['baselines'].values()) if self.nn.user_metadata['class_mapping']['baselines'] else -1, # noqa + max(self.nn.user_metadata['class_mapping']['regions'].values()) if self.nn.user_metadata['class_mapping']['regions'] else -1) # noqa + + del_indices = [self.nn.user_metadata['class_mapping']['baselines'][x] for x in del_bls] + del_indices.extend(self.nn.user_metadata['class_mapping']['regions'][x] for x in del_regions) + self.nn.resize_output(cls_idx + len(new_bls) + len(new_regions) - + len(del_bls) - len(del_regions) + 1, del_indices) + + # delete old baseline/region types + cls_idx = min(min(self.nn.user_metadata['class_mapping']['baselines'].values()) if self.nn.user_metadata['class_mapping']['baselines'] else np.inf, # noqa + min(self.nn.user_metadata['class_mapping']['regions'].values()) if self.nn.user_metadata['class_mapping']['regions'] else np.inf) # noqa + + bls = {} + for k, v in sorted(self.nn.user_metadata['class_mapping']['baselines'].items(), key=lambda item: item[1]): + if k not in del_bls: + bls[k] = cls_idx + cls_idx += 1 + + regions = {} + for k, v in sorted(self.nn.user_metadata['class_mapping']['regions'].items(), key=lambda item: item[1]): + if k not in del_regions: + regions[k] = cls_idx + cls_idx += 1 + + self.nn.user_metadata['class_mapping']['baselines'] = bls + self.nn.user_metadata['class_mapping']['regions'] = regions + + # add new baseline/region types + cls_idx -= 1 + for c in new_bls: + cls_idx += 1 + self.nn.user_metadata['class_mapping']['baselines'][c] = cls_idx + for c in new_regions: + cls_idx += 1 + self.nn.user_metadata['class_mapping']['regions'][c] = cls_idx + else: + raise ValueError(f'invalid resize parameter value {self.resize}') + # backfill train_set/val_set mapping if key-equal as the actual + # numbering in the train_set might be different + self.train_set.dataset.class_mapping = self.nn.user_metadata['class_mapping'] + self.val_set.dataset.class_mapping = self.nn.user_metadata['class_mapping'] + + # updates model's hyper params with user-defined ones + self.nn.hyper_params = self.hparams.hyper_params + + # change topline/baseline switch + loc = {None: 'centerline', + True: 
'topline', + False: 'baseline'} + + if 'topline' not in self.nn.user_metadata: + logger.warning(f'Setting baseline location to {loc[self.topline]} from unset model.') + elif self.nn.user_metadata['topline'] != self.topline: + from_loc = loc[self.nn.user_metadata['topline']] + logger.warning(f'Changing baseline location from {from_loc} to {loc[self.topline]}.') + self.nn.user_metadata['topline'] = self.topline + + logger.info('Training line types:') + for k, v in self.train_set.dataset.class_mapping['baselines'].items(): + logger.info(f' {k}\t{v}\t{self.train_set.dataset.class_stats["baselines"][k]}') + logger.info('Training region types:') + for k, v in self.train_set.dataset.class_mapping['regions'].items(): + logger.info(f' {k}\t{v}\t{self.train_set.dataset.class_stats["regions"][k]}') + + if len(self.train_set) == 0: + raise ValueError('No valid training data was provided to the train command. Please add valid XML data.') + + # set model type metadata field and dump class_mapping + self.nn.model_type = 'segmentation' + self.nn.user_metadata['class_mapping'] = self.val_set.dataset.class_mapping + + # for model size/trainable parameter output + self.net = self.nn.nn + + torch.set_num_threads(max(self.num_workers, 1)) + + # set up validation metrics after output classes have been determined + # baseline metrics + # region metrics + num_regions = len(self.val_set.dataset.class_mapping['regions']) + self.val_region_px_accuracy = MultilabelAccuracy(average='micro', num_labels=num_regions) + self.val_region_mean_accuracy = MultilabelAccuracy(average='macro', num_labels=num_regions) + self.val_region_mean_iu = MultilabelJaccardIndex(average='macro', num_labels=num_regions) + self.val_region_freq_iu = MultilabelJaccardIndex(average='weighted', num_labels=num_regions) + + def train_dataloader(self): + return DataLoader(self.train_set, + batch_size=1, + num_workers=self.num_workers, + shuffle=True, + pin_memory=True) + + def val_dataloader(self): + return DataLoader(self.val_set, + shuffle=False, + batch_size=1, + num_workers=self.num_workers, + pin_memory=True) + + def configure_callbacks(self): + callbacks = [] + if self.hparams.hyper_params['quit'] == 'early': + callbacks.append(EarlyStopping(monitor='val_mean_iu', + mode='max', + patience=self.hparams.hyper_params['lag'], + stopping_threshold=1.0)) + + return callbacks + + # configuration of optimizers and learning rate schedulers + # -------------------------------------------------------- + # + # All schedulers are created internally with a frequency of step to enable + # batch-wise learning rate warmup. In lr_scheduler_step() calls to the + # scheduler are then only performed at the end of the epoch. + def configure_optimizers(self): + return _configure_optimizer_and_lr_scheduler(self.hparams.hyper_params, + self.nn.nn.parameters(), + len_train_set=len(self.train_set), + loss_tracking_mode='max') + + def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure): + # update params + optimizer.step(closure=optimizer_closure) + + # linear warmup between 0 and the initial learning rate `lrate` in `warmup` + # steps. 
+ if self.hparams.hyper_params['warmup'] and self.trainer.global_step < self.hparams.hyper_params['warmup']: + lr_scale = min(1.0, float(self.trainer.global_step + 1) / self.hparams.hyper_params['warmup']) + for pg in optimizer.param_groups: + pg["lr"] = lr_scale * self.hparams.hyper_params['lrate'] + + def lr_scheduler_step(self, scheduler, metric): + if not self.hparams.hyper_params['warmup'] or self.trainer.global_step >= self.hparams.hyper_params['warmup']: + # step OneCycleLR each batch if not in warmup phase + if isinstance(scheduler, lr_scheduler.OneCycleLR): + scheduler.step() + # step every other scheduler epoch-wise + elif self.trainer.is_last_batch: + if metric is None: + scheduler.step() + else: + scheduler.step(metric) diff --git a/kraken/lib/train/trainer.py b/kraken/lib/train/trainer.py new file mode 100644 index 00000000..42f0996b --- /dev/null +++ b/kraken/lib/train/trainer.py @@ -0,0 +1,166 @@ +# +# Copyright 2015 Benjamin Kiessling +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +""" +Training loop interception helpers +""" +import logging +import warnings +import lightning as L + +from typing import TYPE_CHECKING, Union, Optional +from lightning.pytorch.callbacks import (BaseFinetuning, Callback, + LearningRateMonitor) + +from kraken.lib import progress + +if TYPE_CHECKING: + from os import PathLike + +logger = logging.getLogger(__name__) + + +def _validation_worker_init_fn(worker_id): + """ Fix random seeds so that augmentation always produces the same + results when validating. Temporarily increase the logging level + for lightning because otherwise it will display a message + at info level about the seed being changed. 
""" + from lightning.pytorch import seed_everything + seed_everything(42) + + +class KrakenTrainer(L.Trainer): + def __init__(self, + enable_progress_bar: bool = True, + enable_summary: bool = True, + min_epochs: int = 5, + max_epochs: int = 100, + freeze_backbone=-1, + pl_logger: Union[L.pytorch.loggers.logger.Logger, str, None] = None, + log_dir: Optional['PathLike'] = None, + *args, + **kwargs): + kwargs['enable_checkpointing'] = False + kwargs['enable_progress_bar'] = enable_progress_bar + kwargs['min_epochs'] = min_epochs + kwargs['max_epochs'] = max_epochs + kwargs['callbacks'] = ([] if 'callbacks' not in kwargs else kwargs['callbacks']) + if not isinstance(kwargs['callbacks'], list): + kwargs['callbacks'] = [kwargs['callbacks']] + + if pl_logger: + if 'logger' in kwargs and isinstance(kwargs['logger'], L.pytorch.loggers.logger.Logger): + logger.debug('Experiment logger has been provided outside KrakenTrainer as `logger`') + elif isinstance(pl_logger, L.pytorch.loggers.logger.Logger): + logger.debug('Experiment logger has been provided outside KrakenTrainer as `pl_logger`') + kwargs['logger'] = pl_logger + elif pl_logger == 'tensorboard': + logger.debug('Creating default experiment logger') + kwargs['logger'] = L.pytorch.loggers.TensorBoardLogger(log_dir) + else: + logger.error('`pl_logger` was set, but %s is not an accepted value', pl_logger) + raise ValueError(f'{pl_logger} is not acceptable as logger') + kwargs['callbacks'].append(LearningRateMonitor(logging_interval='step')) + else: + kwargs['logger'] = False + + if enable_progress_bar: + progress_bar_cb = progress.KrakenTrainProgressBar(leave=True) + kwargs['callbacks'].append(progress_bar_cb) + + if enable_summary: + from lightning.pytorch.callbacks import RichModelSummary + summary_cb = RichModelSummary(max_depth=2) + kwargs['callbacks'].append(summary_cb) + kwargs['enable_model_summary'] = False + + if freeze_backbone > 0: + kwargs['callbacks'].append(KrakenFreezeBackbone(freeze_backbone)) + + kwargs['callbacks'].extend([KrakenSetOneChannelMode(), KrakenSaveModel()]) + super().__init__(*args, **kwargs) + self.automatic_optimization = False + + def fit(self, *args, **kwargs): + with warnings.catch_warnings(): + warnings.filterwarnings(action='ignore', category=UserWarning, + message='The dataloader,') + super().fit(*args, **kwargs) + + +class KrakenFreezeBackbone(BaseFinetuning): + """ + Callback freezing all but the last layer for fixed number of iterations. + """ + def __init__(self, unfreeze_at_iterations=10): + super().__init__() + self.unfreeze_at_iteration = unfreeze_at_iterations + + def freeze_before_training(self, pl_module): + pass + + def finetune_function(self, pl_module, current_epoch, optimizer): + pass + + def on_train_start(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None: + self.freeze(pl_module.net[:-1]) + + def on_train_batch_start(self, trainer: "L.Trainer", pl_module: "L.LightningModule", batch, batch_idx) -> None: + """ + Called for each training batch. 
+ """ + if trainer.global_step == self.unfreeze_at_iteration: + for opt_idx, optimizer in enumerate(trainer.optimizers): + num_param_groups = len(optimizer.param_groups) + self.unfreeze_and_add_param_group(modules=pl_module.net[:-1], + optimizer=optimizer, + train_bn=True,) + current_param_groups = optimizer.param_groups + self._store(pl_module, opt_idx, num_param_groups, current_param_groups) + + def on_train_epoch_start(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None: + """Called when the epoch begins.""" + pass + + +class KrakenSetOneChannelMode(Callback): + """ + Callback that sets the one_channel_mode of the model after the first epoch. + """ + def on_train_epoch_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None: + # fill one_channel_mode after 1 iteration over training data set + if not trainer.sanity_checking and trainer.current_epoch == 0 and trainer.model.nn.model_type == 'recognition': + ds = getattr(pl_module, 'train_set', None) + if not ds and trainer.datamodule: + ds = trainer.datamodule.train_set + im_mode = ds.dataset.im_mode + if im_mode in ['1', 'L']: + logger.info(f'Setting model one_channel_mode to {im_mode}.') + trainer.model.nn.one_channel_mode = im_mode + + +class KrakenSaveModel(Callback): + """ + Kraken's own serialization callback instead of pytorch's. + """ + def on_validation_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None: + if not trainer.sanity_checking: + trainer.model.nn.hyper_params['completed_epochs'] += 1 + metric = float(trainer.logged_metrics['val_metric']) if 'val_metric' in trainer.logged_metrics else -1.0 + trainer.model.nn.user_metadata['accuracy'].append((trainer.global_step, metric)) + trainer.model.nn.user_metadata['metrics'].append((trainer.global_step, {k: float(v) for k, v in trainer.logged_metrics.items()})) + + logger.info('Saving to {}_{}.mlmodel'.format(trainer.model.output, trainer.current_epoch)) + trainer.model.nn.save_model(f'{trainer.model.output}_{trainer.current_epoch}.mlmodel') + trainer.model.best_model = f'{trainer.model.output}_{trainer.model.best_epoch}.mlmodel' diff --git a/kraken/lib/train/utils.py b/kraken/lib/train/utils.py new file mode 100644 index 00000000..c1cd6c10 --- /dev/null +++ b/kraken/lib/train/utils.py @@ -0,0 +1,93 @@ +# +# Copyright 2015 Benjamin Kiessling +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. 
+""" +Model configuration helpers +""" +import logging + +import torch +from torch.optim import lr_scheduler + +logger = logging.getLogger(__name__) + + +def _configure_optimizer_and_lr_scheduler(hparams, params, len_train_set=None, loss_tracking_mode='max'): + optimizer = hparams.get("optimizer") + lrate = hparams.get("lrate") + momentum = hparams.get("momentum") + weight_decay = hparams.get("weight_decay") + schedule = hparams.get("schedule") + gamma = hparams.get("gamma") + cos_t_max = hparams.get("cos_t_max") + cos_min_lr = hparams.get("cos_min_lr") + step_size = hparams.get("step_size") + rop_factor = hparams.get("rop_factor") + rop_patience = hparams.get("rop_patience") + epochs = hparams.get("epochs") + completed_epochs = hparams.get("completed_epochs") + + # XXX: Warmup is not configured here because it needs to be manually done in optimizer_step() + logger.debug(f'Constructing {optimizer} optimizer (lr: {lrate}, momentum: {momentum})') + if optimizer == 'Adam': + optim = torch.optim.Adam(params, lr=lrate, weight_decay=weight_decay) + else: + optim = getattr(torch.optim, optimizer)(params, + lr=lrate, + momentum=momentum, + weight_decay=weight_decay) + lr_sched = {} + if schedule == 'exponential': + lr_sched = {'scheduler': lr_scheduler.ExponentialLR(optim, gamma, last_epoch=completed_epochs-1), + 'interval': 'step'} + elif schedule == 'cosine': + lr_sched = {'scheduler': lr_scheduler.CosineAnnealingLR(optim, + cos_t_max, + cos_min_lr, + last_epoch=completed_epochs-1), + 'interval': 'step'} + elif schedule == 'step': + lr_sched = {'scheduler': lr_scheduler.StepLR(optim, step_size, gamma, last_epoch=completed_epochs-1), + 'interval': 'step'} + elif schedule == 'reduceonplateau': + lr_sched = {'scheduler': lr_scheduler.ReduceLROnPlateau(optim, + mode=loss_tracking_mode, + factor=rop_factor, + patience=rop_patience), + 'interval': 'step'} + elif schedule == '1cycle': + if epochs <= 0: + raise ValueError('1cycle learning rate scheduler selected but ' + 'number of epochs is less than 0 ' + f'({epochs}).') + last_epoch = completed_epochs*len_train_set if completed_epochs else -1 + lr_sched = {'scheduler': lr_scheduler.OneCycleLR(optim, + max_lr=lrate, + epochs=epochs, + steps_per_epoch=len_train_set, + last_epoch=last_epoch), + 'interval': 'step'} + elif schedule != 'constant': + raise ValueError(f'Unsupported learning rate scheduler {schedule}.') + + ret = {'optimizer': optim} + if lr_sched: + ret['lr_scheduler'] = lr_sched + + if schedule == 'reduceonplateau': + lr_sched['monitor'] = 'val_metric' + lr_sched['strict'] = False + lr_sched['reduce_on_plateau'] = True + + return ret From 05fd70e93e374556b417cf16cb7f4087145f507a Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Fri, 27 Sep 2024 21:18:00 +0200 Subject: [PATCH 03/16] wip line vectorization/matching/distance in validation --- kraken/lib/dataset/segmentation.py | 62 +++++++++++++++--------------- kraken/lib/segmentation.py | 57 +++++++++++++++++++++++++++ kraken/lib/train/segmentation.py | 26 +++++++++---- 3 files changed, 107 insertions(+), 38 deletions(-) diff --git a/kraken/lib/dataset/segmentation.py b/kraken/lib/dataset/segmentation.py index a9f96222..0a288d5a 100644 --- a/kraken/lib/dataset/segmentation.py +++ b/kraken/lib/dataset/segmentation.py @@ -30,7 +30,7 @@ from torch.utils.data import Dataset from torchvision import transforms -from kraken.lib.segmentation import scale_regions +from kraken.lib.segmentation import scale_regions, to_curve if TYPE_CHECKING: from kraken.containers import Segmentation @@ -46,6 
+46,25 @@ class BaselineSet(Dataset): """ Dataset for training a baseline/region segmentation model. + + Args: + line_width: Height of the baseline in the scaled input. + padding: Tuple of ints containing the left/right, top/bottom + padding of the input images. + target_size: Target size of the image as a (height, width) tuple. + augmentation: Enable/disable augmentation. + valid_baselines: Sequence of valid baseline identifiers. If `None` + all are valid. + merge_baselines: Sequence of baseline identifiers to merge. Note + that merging occurs after entities not in valid_* + have been discarded. + valid_regions: Sequence of valid region identifiers. If `None` all + are valid. + merge_regions: Sequence of region identifiers to merge. Note that + merging occurs after entities not in valid_* have + been discarded. + return_curves: Whether to return fitted Bézier curves in addition to + the pixel heatmaps. Used during validation. """ def __init__(self, line_width: int = 4, @@ -55,27 +74,8 @@ def __init__(self, valid_baselines: Sequence[str] = None, merge_baselines: Dict[str, Sequence[str]] = None, valid_regions: Sequence[str] = None, - merge_regions: Dict[str, Sequence[str]] = None): - """ - Creates a dataset for a text-line and region segmentation model. - - Args: - line_width: Height of the baseline in the scaled input. - padding: Tuple of ints containing the left/right, top/bottom - padding of the input images. - target_size: Target size of the image as a (height, width) tuple. - augmentation: Enable/disable augmentation. - valid_baselines: Sequence of valid baseline identifiers. If `None` - all are valid. - merge_baselines: Sequence of baseline identifiers to merge. Note - that merging occurs after entities not in valid_* - have been discarded. - valid_regions: Sequence of valid region identifiers. If `None` all - are valid. - merge_regions: Sequence of region identifiers to merge. Note that - merging occurs after entities not in valid_* have - been discarded. - """ + merge_regions: Dict[str, Sequence[str]] = None, + return_curves: bool = False): super().__init__() self.imgs = [] self.im_mode = '1' @@ -91,6 +91,7 @@ def __init__(self, self.mreg_dict = merge_regions if merge_regions is not None else {} self.valid_baselines = valid_baselines self.valid_regions = valid_regions + self.return_curves = return_curves self.aug = None if augmentation: @@ -162,16 +163,14 @@ def __getitem__(self, idx): try: logger.debug(f'Attempting to load {im}') im = Image.open(im) - im, target = self.transform(im, target) - return {'image': im, 'target': target} + return self.transform(im, target) except Exception: self.failed_samples.add(idx) idx = np.random.randint(0, len(self.imgs)) logger.debug(traceback.format_exc()) logger.info(f'Failed. 
Replacing with sample {idx}') return self[idx] - im, target = self.transform(im, target) - return {'image': im, 'target': target} + return self.transform(im, target) @staticmethod def _get_ortho_line(lineseg, point, line_width, offset): @@ -194,6 +193,7 @@ def transform(self, image, target): start_sep_cls = self.class_mapping['aux']['_start_separator'] end_sep_cls = self.class_mapping['aux']['_end_separator'] + curves = defaultdict(list) for key, lines in target['baselines'].items(): try: cls_idx = self.class_mapping['baselines'][key] @@ -202,9 +202,8 @@ def transform(self, image, target): continue for line in lines: # buffer out line to desired width - line = [k for k, g in groupby(line)] - line = np.array(line)*scale - shp_line = geom.LineString(line) + line = np.array([k for k, g in groupby(line)]) + shp_line = geom.LineString(line*scale) split_offset = min(5, shp_line.length/2) line_pol = np.array(shp_line.buffer(self.line_width/2, cap_style=2).boundary.coords, dtype=int) rr, cc = polygon(line_pol[:, 1], line_pol[:, 0], shape=image.shape[1:]) @@ -223,6 +222,9 @@ def transform(self, image, target): rr_s, cc_s = polygon(end_sep[:, 1], end_sep[:, 0], shape=image.shape[1:]) t[end_sep_cls, rr_s, cc_s] = 1 t[end_sep_cls, rr, cc] = 0 + # Bézier curve fitting + if self.return_curves: + curves[key].append(to_curve(line, orig_size)) for key, regions in target['regions'].items(): try: cls_idx = self.class_mapping['regions'][key] @@ -240,7 +242,7 @@ def transform(self, image, target): o = self.aug(image=image, mask=target) image = torch.tensor(o['image']).permute(2, 0, 1) target = torch.tensor(o['mask']).permute(2, 0, 1) - return image, target + return {'image': image, 'target': target, 'curves': dict(curves) if self.return_curves else None} def __len__(self): return len(self.imgs) diff --git a/kraken/lib/segmentation.py b/kraken/lib/segmentation.py index 0eb7f0bd..c929546e 100644 --- a/kraken/lib/segmentation.py +++ b/kraken/lib/segmentation.py @@ -31,6 +31,7 @@ from scipy.signal import convolve2d from scipy.spatial.distance import pdist, squareform from shapely.ops import nearest_points, unary_union +from shapely.geometry import LineString from shapely.validation import explain_validity from skimage import draw, filters from skimage.filters import sobel @@ -40,6 +41,8 @@ from skimage.morphology import skeletonize from skimage.transform import (AffineTransform, PiecewiseAffineTransform, warp) +from scipy.special import comb + from kraken.lib import default_specs from kraken.lib.exceptions import KrakenInputException @@ -1385,3 +1388,57 @@ def extract_polygons(im: Image.Image, logger.error('bbox {} is outside of image bounds {}'.format(box, im.size)) raise KrakenInputException('Line outside of image bounds') yield im.crop(box).rotate(angle, expand=True), box + +### +# Bézier curve fitting +### + + +def Mtk(n, t, k): + return t**k * (1-t)**(n-k) * comb(n, k) + + +def BezierCoeff(ts): + return [[Mtk(3, t, k) for k in range(4)] for t in ts] + + +def bezier_fit(bl): + x = bl[:, 0] + y = bl[:, 1] + dy = y[1:] - y[:-1] + dx = x[1:] - x[:-1] + dt = (dx ** 2 + dy ** 2)**0.5 + t = dt/dt.sum() + t = np.hstack(([0], t)) + t = t.cumsum() + + Pseudoinverse = np.linalg.pinv(BezierCoeff(t)) # (9,4) -> (4,9) + + control_points = Pseudoinverse.dot(bl) # (4,9)*(9,2) -> (4,2) + medi_ctp = control_points[1:-1, :] + return medi_ctp + + +def to_curve(baseline: torch.FloatTensor, + im_size: Tuple[int, int], + min_points: int = 8) -> torch.FloatTensor: + """ + Fits a polyline as a quadratic Bézier curve. 
+ + Args: + baseline: tensor of shape (S, 2) with coordinates in x, y format. + im_size: image size (W, H) used for control point normalization. + min_points: Minimal number of points in the baseline. If the input + baseline contains less than `min_points` additional points + will be interpolated at regular intervals along the line. + + Returns: + Tensor of shape (8,) + """ + baseline = np.array(baseline) + if len(baseline) < min_points: + ls = LineString(baseline) + baseline = np.stack([np.array(ls.interpolate(x, normalized=True).coords)[0] for x in np.linspace(0, 1, 8)]) + curve = np.concatenate(([baseline[0]], bezier_fit(baseline), [baseline[-1]]))/im_size + curve = curve.flatten() + return torch.from_numpy(curve) diff --git a/kraken/lib/train/segmentation.py b/kraken/lib/train/segmentation.py index cfaa3e2f..7cd1f403 100644 --- a/kraken/lib/train/segmentation.py +++ b/kraken/lib/train/segmentation.py @@ -28,6 +28,7 @@ Union) from torch.optim import lr_scheduler +from scipy.optimize import linear_sum_assignment from torch.utils.data import DataLoader, Subset, random_split from lightning.pytorch.callbacks import EarlyStopping from torchmetrics.classification import (MultilabelAccuracy, @@ -39,7 +40,7 @@ from kraken.lib.xml import XMLPage from kraken.lib.models import validate_hyper_parameters -from kraken.lib.segmentation import vectorize_lines +from kraken.lib.segmentation import vectorize_lines, to_curve from .utils import _configure_optimizer_and_lr_scheduler @@ -194,7 +195,8 @@ def __init__(self, valid_baselines=valid_baselines, merge_baselines=merge_baselines, valid_regions=valid_regions, - merge_regions=merge_regions) + merge_regions=merge_regions, + return_curves=True) for page in training_data: train_set.add(page) @@ -206,7 +208,8 @@ def __init__(self, valid_baselines=valid_baselines, merge_baselines=merge_baselines, valid_regions=valid_regions, - merge_regions=merge_regions) + merge_regions=merge_regions, + return_curves=True) for page in evaluation_data: val_set.add(page) @@ -246,7 +249,7 @@ def training_step(self, batch, batch_idx): return loss def validation_step(self, batch, batch_idx): - x, y = batch['image'], batch['target'] + x, y, y_curves = batch['image'], batch['target'], batch['curves'] pred, _ = self.nn.nn(x) # scale target to output size y = F.interpolate(y, size=(pred.size(2), pred.size(3)), mode='nearest').int() @@ -258,12 +261,17 @@ def validation_step(self, batch, batch_idx): self.val_region_mean_accuracy.update(pred_reg, y_reg) self.val_region_mean_iu.update(pred_reg, y_reg) self.val_region_freq_iu.update(pred_reg, y_reg) - # vectorize lines st_sep = self.nn.user_metadata['class_mapping']['aux']['_start_separator'] end_sep = self.nn.user_metadata['class_mapping']['aux']['_end_separator'] - line_idxs = sorted(self.nn.user_metadata['class_mapping']['lines'].values()) - for line_idx in line_idxs: - pred_bl = vectorize_lines(pred[:, [st_sep, end_sep, line_idx], ...], text_direction='horizontal') + + + # vectorize and match lines + for line_cls, line_idx in self.nn.user_metadata['class_mapping']['lines'].items(): + pred_curves = torch.stack([to_curve(pred_bl, pred.shape[:2][-1]) for pred_bl in vectorize_lines(pred[:, [st_sep, end_sep, line_idx], ...], + text_direction='horizontal')]) + cost_curves = torch.cdist(pred_curves, y_curves[line_cls], p=1).view(len(pred_curves), -1).cpu() + row_ind, col_ind = linear_sum_assignment(cost_curves) + self.val_line_dist.update(cost_curves[row_ind, col_ind]) def on_validation_epoch_end(self): if not self.trainer.sanity_checking: 
@@ -271,6 +279,7 @@ def on_validation_epoch_end(self):
             mean_accuracy = self.val_region_mean_accuracy.compute()
             mean_iu = self.val_region_mean_iu.compute()
             freq_iu = self.val_region_freq_iu.compute()
+            mean_line_dist = self.val_line_dist.compute()
 
             if mean_iu > self.best_metric:
                 logger.debug(f'Updating best region metric from {self.best_metric} ({self.best_epoch}) to {mean_iu} ({self.current_epoch})')
@@ -283,6 +292,7 @@ def on_validation_epoch_end(self):
             self.log('val_region_mean_acc', mean_accuracy, on_step=False, on_epoch=True, prog_bar=True, logger=True)
             self.log('val_region_mean_iu', mean_iu, on_step=False, on_epoch=True, prog_bar=True, logger=True)
             self.log('val_region_freq_iu', freq_iu, on_step=False, on_epoch=True, prog_bar=True, logger=True)
+            self.log('val_mean_line_dist', mean_line_dist, on_step=False, on_epoch=True, prog_bar=True, logger=True)
             self.log('val_metric', mean_iu, on_step=False, on_epoch=True, prog_bar=False, logger=True)
 
         # reset metrics even if sanity checking

From d53b4faf2f3ab8e3baf68e60a5cd8add04ab25fd Mon Sep 17 00:00:00 2001
From: Benjamin Kiessling
Date: Fri, 27 Sep 2024 21:40:45 +0200
Subject: [PATCH 04/16] fix transcription tests

---
 tests/test_transcribe.py | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/tests/test_transcribe.py b/tests/test_transcribe.py
index 04477981..724dc04c 100644
--- a/tests/test_transcribe.py
+++ b/tests/test_transcribe.py
@@ -8,6 +8,7 @@
 from lxml import etree
 from PIL import Image
 
+from kraken.containers import Segmentation, BBoxLine
 from kraken.transcribe import TranscriptionInterface
 
 thisfile = Path(__file__).resolve().parent
@@ -24,8 +25,16 @@ def test_transcription_generation(self):
         Tests creation of transcription interfaces with segmentation.
         """
         tr = TranscriptionInterface()
-        with open(resources / 'segmentation.json') as fp:
-            seg = json.load(fp)
+
+
+        seg = Segmentation(type='bbox',
+                           imagename = resources / 'bw.png',
+                           lines=[BBoxLine(id='foo',
+                                           bbox=[200, 10, 400, 156])],
+                           text_direction='horizontal-lr',
+                           script_detection=False
+                           )
+
         with Image.open(resources / 'input.jpg') as im:
             tr.add_page(im, seg)
         fp = BytesIO()

From 30c41ec84a456e730377abc28bf84051413c9213 Mon Sep 17 00:00:00 2001
From: Benjamin Kiessling
Date: Sun, 29 Sep 2024 04:04:50 +0200
Subject: [PATCH 05/16] data types, fallback, ...
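
Baselines are cast to float32 before Bézier fitting, per-class curve lists are
stacked into tensors in the dataset transform, `to_curve` keeps the dtype of
its tensor input, and the validation step now iterates over the `baselines`
class mapping, divides matched curve distances by 8.0, and records a maximal
penalty when the vectorizer returns no lines for a class.

Quick sanity check for the 8.0 normalisation (illustrative only, not part of
the change): curve descriptors are eight control-point coordinates scaled by
the image size, so the L1 distance between two descriptors inside the image is
bounded by 8.

    import torch

    a, b = torch.zeros(1, 8), torch.ones(1, 8)
    assert torch.cdist(a, b, p=1).item() == 8.0
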
---
 kraken/lib/dataset/segmentation.py |  6 ++++--
 kraken/lib/segmentation.py         | 10 ++++++----
 kraken/lib/train/segmentation.py   | 17 ++++++++++-------
 3 files changed, 20 insertions(+), 13 deletions(-)

diff --git a/kraken/lib/dataset/segmentation.py b/kraken/lib/dataset/segmentation.py
index 0a288d5a..3f89d346 100644
--- a/kraken/lib/dataset/segmentation.py
+++ b/kraken/lib/dataset/segmentation.py
@@ -202,7 +202,7 @@ def transform(self, image, target):
                 continue
             for line in lines:
                 # buffer out line to desired width
-                line = np.array([k for k, g in groupby(line)])
+                line = np.array([k for k, g in groupby(line)], dtype=np.float32)
                 shp_line = geom.LineString(line*scale)
                 split_offset = min(5, shp_line.length/2)
                 line_pol = np.array(shp_line.buffer(self.line_width/2, cap_style=2).boundary.coords, dtype=int)
@@ -224,7 +224,9 @@ def transform(self, image, target):
                 t[end_sep_cls, rr, cc] = 0
                 # Bézier curve fitting
                 if self.return_curves:
-                    curves[key].append(to_curve(line, orig_size))
+                    curves[key].append(to_curve(torch.from_numpy(line), orig_size))
+        for k, v in curves.items():
+            curves[k] = torch.stack(v)
         for key, regions in target['regions'].items():
             try:
                 cls_idx = self.class_mapping['regions'][key]
diff --git a/kraken/lib/segmentation.py b/kraken/lib/segmentation.py
index c929546e..69cc490d 100644
--- a/kraken/lib/segmentation.py
+++ b/kraken/lib/segmentation.py
@@ -52,6 +52,7 @@
 
 _T_pil_or_np = TypeVar('_T_pil_or_np', Image.Image, np.ndarray)
+_T_tensor_or_np = TypeVar('_T_tensor_or_np', torch.Tensor, np.ndarray)
 
 logger = logging.getLogger('kraken')
 
@@ -1435,10 +1436,11 @@ def to_curve(baseline: torch.FloatTensor,
     Returns:
         Tensor of shape (8,)
     """
-    baseline = np.array(baseline)
     if len(baseline) < min_points:
         ls = LineString(baseline)
-        baseline = np.stack([np.array(ls.interpolate(x, normalized=True).coords)[0] for x in np.linspace(0, 1, 8)])
-    curve = np.concatenate(([baseline[0]], bezier_fit(baseline), [baseline[-1]]))/im_size
+        baseline = torch.stack([torch.tensor(ls.interpolate(x, normalized=True).coords)[0] for x in np.linspace(0, 1, 8)])
+    baseline = baseline.numpy()
+    curve = np.concatenate(([baseline[0]], bezier_fit(baseline), [baseline[-1]]))
+    curve = curve/im_size
     curve = curve.flatten()
-    return torch.from_numpy(curve)
+    return torch.from_numpy(curve.astype(baseline.dtype))
diff --git a/kraken/lib/train/segmentation.py b/kraken/lib/train/segmentation.py
index 7cd1f403..08c1692c 100644
--- a/kraken/lib/train/segmentation.py
+++ b/kraken/lib/train/segmentation.py
@@ -264,14 +264,17 @@ def validation_step(self, batch, batch_idx):
         st_sep = self.nn.user_metadata['class_mapping']['aux']['_start_separator']
         end_sep = self.nn.user_metadata['class_mapping']['aux']['_end_separator']
-
 
         # vectorize and match lines
-        for line_cls, line_idx in self.nn.user_metadata['class_mapping']['lines'].items():
-            pred_curves = torch.stack([to_curve(pred_bl, pred.shape[:2][-1]) for pred_bl in vectorize_lines(pred[:, [st_sep, end_sep, line_idx], ...],
-                                                                                                            text_direction='horizontal')])
-            cost_curves = torch.cdist(pred_curves, y_curves[line_cls], p=1).view(len(pred_curves), -1).cpu()
-            row_ind, col_ind = linear_sum_assignment(cost_curves)
-            self.val_line_dist.update(cost_curves[row_ind, col_ind])
+        for line_cls, line_idx in self.nn.user_metadata['class_mapping']['baselines'].items():
+            pred_bl = vectorize_lines(pred[0, [st_sep, end_sep, line_idx], ...].numpy(), text_direction='horizontal')
+            pred_curves = [to_curve(bl, pred.shape[2:][::-1]) for bl in pred_bl]
+            if pred_curves:
+                pred_curves = torch.stack(pred_curves)
+                cost_curves = torch.cdist(pred_curves, y_curves[line_cls][0], p=1).cpu()
+                row_ind, col_ind = linear_sum_assignment(cost_curves)
+                self.val_line_dist.update(cost_curves[row_ind, col_ind]/8.0)
+            else:
+                self.val_line_dist.update(torch.ones(len(y_curves[line_cls][0])))
 
     def on_validation_epoch_end(self):
         if not self.trainer.sanity_checking:

From fee7f13062ced5bec3c8e61dd4e70581ee1c438b Mon Sep 17 00:00:00 2001
From: Benjamin Kiessling
Date: Sun, 29 Sep 2024 14:54:27 +0200
Subject: [PATCH 06/16] add cardinality penalty term to baseline val metric

---
 kraken/lib/train/segmentation.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/kraken/lib/train/segmentation.py b/kraken/lib/train/segmentation.py
index 08c1692c..6325e7b7 100644
--- a/kraken/lib/train/segmentation.py
+++ b/kraken/lib/train/segmentation.py
@@ -271,8 +271,15 @@ def validation_step(self, batch, batch_idx):
             if pred_curves:
                 pred_curves = torch.stack(pred_curves)
                 cost_curves = torch.cdist(pred_curves, y_curves[line_cls][0], p=1).cpu()
-                row_ind, col_ind = linear_sum_assignment(cost_curves)
-                self.val_line_dist.update(cost_curves[row_ind, col_ind]/8.0)
+                costs = cost_curves = [linear_sum_assignment(cost_curves)]
+                # num of predictions differs from target -> take n best
+                # predictions and add error penalty term for the rest.
+                if diff := abs(len(pred_curves) - len(y_curves[line_cls][0])):
+                    costs = np.sort(costs)[:len(y_curves[line_cls][0])]
+                    penalty = np.full(diff, 8.0)
+                    costs = np.concatenate([costs, penalty])
+                self.val_line_dist.update(costs/8.0)
+            # no line output
             else:
                 self.val_line_dist.update(torch.ones(len(y_curves[line_cls][0])))

From a97cf348087d8369deabc04d229454441fd3c08a Mon Sep 17 00:00:00 2001
From: Benjamin Kiessling
Date: Sun, 29 Sep 2024 16:07:40 +0200
Subject: [PATCH 07/16] Create metric on model

---
 kraken/lib/train/segmentation.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/kraken/lib/train/segmentation.py b/kraken/lib/train/segmentation.py
index 6325e7b7..dcb9d876 100644
--- a/kraken/lib/train/segmentation.py
+++ b/kraken/lib/train/segmentation.py
@@ -31,6 +31,7 @@
 from scipy.optimize import linear_sum_assignment
 from torch.utils.data import DataLoader, Subset, random_split
 from lightning.pytorch.callbacks import EarlyStopping
+from torchmetrics.aggregation import MeanMetric
 from torchmetrics.classification import (MultilabelAccuracy,
                                          MultilabelJaccardIndex)
@@ -278,10 +279,10 @@ def validation_step(self, batch, batch_idx):
                     costs = np.sort(costs)[:len(y_curves[line_cls][0])]
                     penalty = np.full(diff, 8.0)
                     costs = np.concatenate([costs, penalty])
-                self.val_line_dist.update(costs/8.0)
+                self.val_line_mean_dist.update(costs/8.0)
             # no line output
             else:
-                self.val_line_dist.update(torch.ones(len(y_curves[line_cls][0])))
+                self.val_line_mean_dist.update(torch.ones(len(y_curves[line_cls][0])))
 
     def on_validation_epoch_end(self):
         if not self.trainer.sanity_checking:
@@ -289,7 +290,7 @@ def on_validation_epoch_end(self):
             mean_accuracy = self.val_region_mean_accuracy.compute()
             mean_iu = self.val_region_mean_iu.compute()
             freq_iu = self.val_region_freq_iu.compute()
-            mean_line_dist = self.val_line_dist.compute()
+            line_mean_dist = self.val_line_mean_dist.compute()
 
             if mean_iu > self.best_metric:
                 logger.debug(f'Updating best region metric from {self.best_metric} ({self.best_epoch}) to {mean_iu} ({self.current_epoch})')
@@ -302,7 +303,7 @@ def on_validation_epoch_end(self):
             self.log('val_region_mean_acc', mean_accuracy, on_step=False, on_epoch=True, prog_bar=True, logger=True)
             self.log('val_region_mean_iu', mean_iu, on_step=False, on_epoch=True, prog_bar=True, logger=True)
             self.log('val_region_freq_iu', freq_iu, on_step=False, on_epoch=True, prog_bar=True, logger=True)
-            self.log('val_mean_line_dist', mean_line_dist, on_step=False, on_epoch=True, prog_bar=True, logger=True)
+            self.log('val_line_mean_dist', line_mean_dist, on_step=False, on_epoch=True, prog_bar=True, logger=True)
             self.log('val_metric', mean_iu, on_step=False, on_epoch=True, prog_bar=False, logger=True)
 
         # reset metrics even if sanity checking
@@ -432,7 +433,8 @@ def setup(self, stage: Optional[str] = None):
         torch.set_num_threads(max(self.num_workers, 1))
 
         # set up validation metrics after output classes have been determined
-        # baseline metrics
+        # baseline metric
+        self.val_line_mean_dist = MeanMetric()
         # region metrics
         num_regions = len(self.val_set.dataset.class_mapping['regions'])
         self.val_region_px_accuracy = MultilabelAccuracy(average='micro', num_labels=num_regions)

From 1b577c22c6d976043a22f9864e2b41eb4c7f0248 Mon Sep 17 00:00:00 2001
From: Benjamin Kiessling
Date: Mon, 30 Sep 2024 00:55:45 +0200
Subject: [PATCH 08/16] copy tensor to cpu

---
 kraken/lib/train/segmentation.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/kraken/lib/train/segmentation.py b/kraken/lib/train/segmentation.py
index dcb9d876..a6b5d2a0 100644
--- a/kraken/lib/train/segmentation.py
+++ b/kraken/lib/train/segmentation.py
@@ -267,7 +267,7 @@ def validation_step(self, batch, batch_idx):
 
         # vectorize and match lines
         for line_cls, line_idx in self.nn.user_metadata['class_mapping']['baselines'].items():
-            pred_bl = vectorize_lines(pred[0, [st_sep, end_sep, line_idx], ...].numpy(), text_direction='horizontal')
+            pred_bl = vectorize_lines(pred[0, [st_sep, end_sep, line_idx], ...].cpu().numpy(), text_direction='horizontal')
             pred_curves = [to_curve(bl, pred.shape[2:][::-1]) for bl in pred_bl]
             if pred_curves:
                 pred_curves = torch.stack(pred_curves)

From cda9b181d8c44bd0525bfe0fbfcb1e534e2c1826 Mon Sep 17 00:00:00 2001
From: Benjamin Kiessling
Date: Mon, 30 Sep 2024 12:01:43 +0200
Subject: [PATCH 09/16] put costs on correct device

---
 kraken/lib/train/segmentation.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/kraken/lib/train/segmentation.py b/kraken/lib/train/segmentation.py
index a6b5d2a0..b68fd79a 100644
--- a/kraken/lib/train/segmentation.py
+++ b/kraken/lib/train/segmentation.py
@@ -279,10 +279,11 @@ def validation_step(self, batch, batch_idx):
                     costs = np.sort(costs)[:len(y_curves[line_cls][0])]
                     penalty = np.full(diff, 8.0)
                     costs = np.concatenate([costs, penalty])
-                self.val_line_mean_dist.update(costs/8.0)
+                costs = costs/8.0
             # no line output
             else:
-                self.val_line_mean_dist.update(torch.ones(len(y_curves[line_cls][0])))
+                costs = torch.ones(len(y_curves[line_cls][0]))
+            self.val_line_mean_dist.update(costs.to(self.val_line_mean_dist.device))
 
     def on_validation_epoch_end(self):
         if not self.trainer.sanity_checking:

From 86cd61ef78eea01fa9b0754d5ad45138c3634aae Mon Sep 17 00:00:00 2001
From: Benjamin Kiessling
Date: Mon, 30 Sep 2024 12:12:00 +0200
Subject: [PATCH 10/16] deal with no pred no target case

---
 kraken/lib/train/segmentation.py | 35 ++++++++++++++++++--------------
 1 file changed, 20 insertions(+), 15 deletions(-)

diff --git a/kraken/lib/train/segmentation.py b/kraken/lib/train/segmentation.py
index b68fd79a..6f071ab4 100644
--- a/kraken/lib/train/segmentation.py
+++ b/kraken/lib/train/segmentation.py
@@ -269,21 +269,26 @@ def validation_step(self, batch, batch_idx):
         for line_cls, line_idx in self.nn.user_metadata['class_mapping']['baselines'].items():
             pred_bl = vectorize_lines(pred[0, [st_sep, end_sep, line_idx], ...].cpu().numpy(), text_direction='horizontal')
             pred_curves = [to_curve(bl, pred.shape[2:][::-1]) for bl in pred_bl]
-            if pred_curves:
-                pred_curves = torch.stack(pred_curves)
-                cost_curves = torch.cdist(pred_curves, y_curves[line_cls][0], p=1).cpu()
-                costs = cost_curves = [linear_sum_assignment(cost_curves)]
-                # num of predictions differs from target -> take n best
-                # predictions and add error penalty term for the rest.
-                if diff := abs(len(pred_curves) - len(y_curves[line_cls][0])):
-                    costs = np.sort(costs)[:len(y_curves[line_cls][0])]
-                    penalty = np.full(diff, 8.0)
-                    costs = np.concatenate([costs, penalty])
-                costs = costs/8.0
-            # no line output
-            else:
-                costs = torch.ones(len(y_curves[line_cls][0]))
-            self.val_line_mean_dist.update(costs.to(self.val_line_mean_dist.device))
+            if line_cls in y_curves:
+                target_curves = y_curves[line_cls][0]
+                if pred_curves:
+                    pred_curves = torch.stack(pred_curves)
+                    cost_curves = torch.cdist(pred_curves, target_curves, p=1).cpu()
+                    costs = cost_curves = [linear_sum_assignment(cost_curves)]
+                    # num of predictions differs from target -> take n best
+                    # predictions and add error penalty term for the rest.
+                    if diff := abs(len(pred_curves) - len(target_curves)):
+                        costs = np.sort(costs)[:len(target_curves)]
+                        penalty = np.full(diff, 8.0)
+                        costs = np.concatenate([costs, penalty])
+                    costs = costs/8.0
+                # no line output
+                else:
+                    costs = torch.ones(len(target_curves))
+                self.val_line_mean_dist.update(costs.to(self.val_line_mean_dist.device))
+            elif pred_curves:
+                costs = torch.ones(len(pred_curves))
+                self.val_line_mean_dist.update(costs.to(self.val_line_mean_dist.device))
 
     def on_validation_epoch_end(self):
         if not self.trainer.sanity_checking:

From f2199c8951d01fa722b1d46cc1472cb844dc9174 Mon Sep 17 00:00:00 2001
From: Benjamin Kiessling
Date: Mon, 30 Sep 2024 12:47:11 +0200
Subject: [PATCH 11/16] Make sure preds are compatible with vectorization

---
 kraken/lib/train/segmentation.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/kraken/lib/train/segmentation.py b/kraken/lib/train/segmentation.py
index 6f071ab4..39b86b13 100644
--- a/kraken/lib/train/segmentation.py
+++ b/kraken/lib/train/segmentation.py
@@ -265,9 +265,13 @@ def validation_step(self, batch, batch_idx):
         st_sep = self.nn.user_metadata['class_mapping']['aux']['_start_separator']
         end_sep = self.nn.user_metadata['class_mapping']['aux']['_end_separator']
 
+        # cast pred/targets to float32 and move to CPU
+        pred = pred.cpu().float()
+        y_curves = y_curves.cpu()
+
         # vectorize and match lines
         for line_cls, line_idx in self.nn.user_metadata['class_mapping']['baselines'].items():
-            pred_bl = vectorize_lines(pred[0, [st_sep, end_sep, line_idx], ...].cpu().numpy(), text_direction='horizontal')
+            pred_bl = vectorize_lines(pred[0, [st_sep, end_sep, line_idx], ...].numpy(), text_direction='horizontal')
             pred_curves = [to_curve(bl, pred.shape[2:][::-1]) for bl in pred_bl]

From 6056146701def08e82a969b6a3c32964bacec4f6 Mon Sep 17 00:00:00 2001
From: Benjamin Kiessling
Date: Mon, 30 Sep 2024 12:54:13 +0200
Subject: [PATCH 12/16] dict-moving

---
 kraken/lib/train/segmentation.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/kraken/lib/train/segmentation.py b/kraken/lib/train/segmentation.py
index 39b86b13..4dfa610e 100644
--- a/kraken/lib/train/segmentation.py
+++ b/kraken/lib/train/segmentation.py
@@ -267,7 +267,8 @@ def validation_step(self, batch, batch_idx):
 
         # cast pred/targets to float32 and move to CPU
         pred = pred.cpu().float()
-        y_curves = y_curves.cpu()
+        for k, v in y_curves.items():
+            y_curves[k] = v.cpu()
 
         # vectorize and match lines
         for line_cls, line_idx in self.nn.user_metadata['class_mapping']['baselines'].items():

From 3561357e9c80f4699acea9cfd207441dab83f3bf Mon Sep 17 00:00:00 2001
From: Benjamin Kiessling
Date: Mon, 30 Sep 2024 14:26:34 +0200
Subject: [PATCH 13/16] typo in min cost selection

---
 kraken/lib/train/segmentation.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/kraken/lib/train/segmentation.py b/kraken/lib/train/segmentation.py
index 4dfa610e..48b3e30c 100644
--- a/kraken/lib/train/segmentation.py
+++ b/kraken/lib/train/segmentation.py
@@ -279,7 +279,7 @@ def validation_step(self, batch, batch_idx):
                 if pred_curves:
                     pred_curves = torch.stack(pred_curves)
                     cost_curves = torch.cdist(pred_curves, target_curves, p=1).cpu()
-                    costs = cost_curves = [linear_sum_assignment(cost_curves)]
+                    costs = cost_curves[linear_sum_assignment(cost_curves)]
                     # num of predictions differs from target -> take n best
                     # predictions and add error penalty term for the rest.

From 4d370d02de8552d6e8f0358dc0f405d777a3877b Mon Sep 17 00:00:00 2001
From: Benjamin Kiessling
Date: Mon, 30 Sep 2024 15:53:54 +0200
Subject: [PATCH 14/16] ndarray tensor mixup

---
 kraken/lib/train/segmentation.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/kraken/lib/train/segmentation.py b/kraken/lib/train/segmentation.py
index 48b3e30c..5ce68c99 100644
--- a/kraken/lib/train/segmentation.py
+++ b/kraken/lib/train/segmentation.py
@@ -283,9 +283,9 @@ def validation_step(self, batch, batch_idx):
                     # num of predictions differs from target -> take n best
                     # predictions and add error penalty term for the rest.
                     if diff := abs(len(pred_curves) - len(target_curves)):
-                        costs = np.sort(costs)[:len(target_curves)]
-                        penalty = np.full(diff, 8.0)
-                        costs = np.concatenate([costs, penalty])
+                        costs = torch.sort(costs)[:len(target_curves)]
+                        penalty = torch.full((diff,), 8.0)
+                        costs = torch.cat([costs, penalty])
                     costs = costs/8.0
                 # no line output
                 else:

From 4845ac001be55cfb94cf94002f6d7aa76e59548c Mon Sep 17 00:00:00 2001
From: Benjamin Kiessling
Date: Mon, 30 Sep 2024 17:02:51 +0200
Subject: [PATCH 15/16] torch.sort works differently than np.sort

---
 kraken/lib/train/segmentation.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/kraken/lib/train/segmentation.py b/kraken/lib/train/segmentation.py
index 5ce68c99..8b0486e7 100644
--- a/kraken/lib/train/segmentation.py
+++ b/kraken/lib/train/segmentation.py
@@ -283,7 +283,8 @@ def validation_step(self, batch, batch_idx):
                     # num of predictions differs from target -> take n best
                     # predictions and add error penalty term for the rest.
                     if diff := abs(len(pred_curves) - len(target_curves)):
-                        costs = torch.sort(costs)[:len(target_curves)]
+                        costs, _ = torch.sort(costs)
+                        costs = costs[:len(target_curves)]
                         penalty = torch.full((diff,), 8.0)
                         costs = torch.cat([costs, penalty])

From 29314c6448721bac474dc06a97f5092114ed9336 Mon Sep 17 00:00:00 2001
From: Benjamin Kiessling
Date: Mon, 30 Sep 2024 18:18:08 +0200
Subject: [PATCH 16/16] cast vectorizer output to tensor

---
 kraken/lib/train/segmentation.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/kraken/lib/train/segmentation.py b/kraken/lib/train/segmentation.py
index 8b0486e7..5ef8cd76 100644
--- a/kraken/lib/train/segmentation.py
+++ b/kraken/lib/train/segmentation.py
@@ -272,7 +272,7 @@ def validation_step(self, batch, batch_idx):
 
         # vectorize and match lines
         for line_cls, line_idx in self.nn.user_metadata['class_mapping']['baselines'].items():
-            pred_bl = vectorize_lines(pred[0, [st_sep, end_sep, line_idx], ...].numpy(), text_direction='horizontal')
+            pred_bl = [torch.tensor(x) for x in vectorize_lines(pred[0, [st_sep, end_sep, line_idx], ...].numpy(), text_direction='horizontal')]
             pred_curves = [to_curve(bl, pred.shape[2:][::-1]) for bl in pred_bl]
             if line_cls in y_curves:
                 target_curves = y_curves[line_cls][0]
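
For reference, the per-class line metric the series converges on can be read as
one standalone function: predicted and target baselines are reduced to eight
normalised cubic Bézier control-point coordinates, matched with the Hungarian
algorithm on an L1 cost matrix, and every unmatched or missing line is charged
the maximal distance of 8.0 before normalising to [0, 1]. The sketch below is
illustrative only (the name `curve_set_distance` and the simplified handling of
the no-target case are not part of the patches):

    import torch
    from scipy.optimize import linear_sum_assignment

    def curve_set_distance(pred_curves: torch.Tensor, target_curves: torch.Tensor) -> torch.Tensor:
        # pred_curves: (N, 8), target_curves: (M, 8) Bézier descriptors scaled to [0, 1]
        if len(pred_curves) == 0:
            # nothing detected: every target line counts as a full error
            return torch.ones(len(target_curves))
        cost = torch.cdist(pred_curves, target_curves, p=1).cpu()
        row_ind, col_ind = linear_sum_assignment(cost)
        costs = cost[row_ind, col_ind]
        # cardinality mismatch: keep the best matches and add a maximal
        # penalty of 8.0 for every surplus prediction or missed target
        if diff := abs(len(pred_curves) - len(target_curves)):
            costs, _ = torch.sort(costs)
            costs = torch.cat([costs[:len(target_curves)], torch.full((diff,), 8.0)])
        return costs / 8.0

The values produced this way feed the running `MeanMetric` created in setup(),
so `val_line_mean_dist` reports the average normalised curve distance per line,
with 1.0 corresponding to a completely missed or spurious line.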