From 05fd70e93e374556b417cf16cb7f4087145f507a Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Fri, 27 Sep 2024 21:18:00 +0200 Subject: [PATCH] wip line vectorization/matching/distance in validation --- kraken/lib/dataset/segmentation.py | 62 +++++++++++++++--------------- kraken/lib/segmentation.py | 57 +++++++++++++++++++++++++++ kraken/lib/train/segmentation.py | 26 +++++++++---- 3 files changed, 107 insertions(+), 38 deletions(-) diff --git a/kraken/lib/dataset/segmentation.py b/kraken/lib/dataset/segmentation.py index a9f962226..0a288d5ac 100644 --- a/kraken/lib/dataset/segmentation.py +++ b/kraken/lib/dataset/segmentation.py @@ -30,7 +30,7 @@ from torch.utils.data import Dataset from torchvision import transforms -from kraken.lib.segmentation import scale_regions +from kraken.lib.segmentation import scale_regions, to_curve if TYPE_CHECKING: from kraken.containers import Segmentation @@ -46,6 +46,25 @@ class BaselineSet(Dataset): """ Dataset for training a baseline/region segmentation model. + + Args: + line_width: Height of the baseline in the scaled input. + padding: Tuple of ints containing the left/right, top/bottom + padding of the input images. + target_size: Target size of the image as a (height, width) tuple. + augmentation: Enable/disable augmentation. + valid_baselines: Sequence of valid baseline identifiers. If `None` + all are valid. + merge_baselines: Sequence of baseline identifiers to merge. Note + that merging occurs after entities not in valid_* + have been discarded. + valid_regions: Sequence of valid region identifiers. If `None` all + are valid. + merge_regions: Sequence of region identifiers to merge. Note that + merging occurs after entities not in valid_* have + been discarded. + return_curves: Whether to return fitted Bézier curves in addition to + the pixel heatmaps. Used during validation. """ def __init__(self, line_width: int = 4, @@ -55,27 +74,8 @@ def __init__(self, valid_baselines: Sequence[str] = None, merge_baselines: Dict[str, Sequence[str]] = None, valid_regions: Sequence[str] = None, - merge_regions: Dict[str, Sequence[str]] = None): - """ - Creates a dataset for a text-line and region segmentation model. - - Args: - line_width: Height of the baseline in the scaled input. - padding: Tuple of ints containing the left/right, top/bottom - padding of the input images. - target_size: Target size of the image as a (height, width) tuple. - augmentation: Enable/disable augmentation. - valid_baselines: Sequence of valid baseline identifiers. If `None` - all are valid. - merge_baselines: Sequence of baseline identifiers to merge. Note - that merging occurs after entities not in valid_* - have been discarded. - valid_regions: Sequence of valid region identifiers. If `None` all - are valid. - merge_regions: Sequence of region identifiers to merge. Note that - merging occurs after entities not in valid_* have - been discarded. - """ + merge_regions: Dict[str, Sequence[str]] = None, + return_curves: bool = False): super().__init__() self.imgs = [] self.im_mode = '1' @@ -91,6 +91,7 @@ def __init__(self, self.mreg_dict = merge_regions if merge_regions is not None else {} self.valid_baselines = valid_baselines self.valid_regions = valid_regions + self.return_curves = return_curves self.aug = None if augmentation: @@ -162,16 +163,14 @@ def __getitem__(self, idx): try: logger.debug(f'Attempting to load {im}') im = Image.open(im) - im, target = self.transform(im, target) - return {'image': im, 'target': target} + return self.transform(im, target) except Exception: self.failed_samples.add(idx) idx = np.random.randint(0, len(self.imgs)) logger.debug(traceback.format_exc()) logger.info(f'Failed. Replacing with sample {idx}') return self[idx] - im, target = self.transform(im, target) - return {'image': im, 'target': target} + return self.transform(im, target) @staticmethod def _get_ortho_line(lineseg, point, line_width, offset): @@ -194,6 +193,7 @@ def transform(self, image, target): start_sep_cls = self.class_mapping['aux']['_start_separator'] end_sep_cls = self.class_mapping['aux']['_end_separator'] + curves = defaultdict(list) for key, lines in target['baselines'].items(): try: cls_idx = self.class_mapping['baselines'][key] @@ -202,9 +202,8 @@ def transform(self, image, target): continue for line in lines: # buffer out line to desired width - line = [k for k, g in groupby(line)] - line = np.array(line)*scale - shp_line = geom.LineString(line) + line = np.array([k for k, g in groupby(line)]) + shp_line = geom.LineString(line*scale) split_offset = min(5, shp_line.length/2) line_pol = np.array(shp_line.buffer(self.line_width/2, cap_style=2).boundary.coords, dtype=int) rr, cc = polygon(line_pol[:, 1], line_pol[:, 0], shape=image.shape[1:]) @@ -223,6 +222,9 @@ def transform(self, image, target): rr_s, cc_s = polygon(end_sep[:, 1], end_sep[:, 0], shape=image.shape[1:]) t[end_sep_cls, rr_s, cc_s] = 1 t[end_sep_cls, rr, cc] = 0 + # Bézier curve fitting + if self.return_curves: + curves[key].append(to_curve(line, orig_size)) for key, regions in target['regions'].items(): try: cls_idx = self.class_mapping['regions'][key] @@ -240,7 +242,7 @@ def transform(self, image, target): o = self.aug(image=image, mask=target) image = torch.tensor(o['image']).permute(2, 0, 1) target = torch.tensor(o['mask']).permute(2, 0, 1) - return image, target + return {'image': image, 'target': target, 'curves': dict(curves) if self.return_curves else None} def __len__(self): return len(self.imgs) diff --git a/kraken/lib/segmentation.py b/kraken/lib/segmentation.py index 0eb7f0bdc..c929546e4 100644 --- a/kraken/lib/segmentation.py +++ b/kraken/lib/segmentation.py @@ -31,6 +31,7 @@ from scipy.signal import convolve2d from scipy.spatial.distance import pdist, squareform from shapely.ops import nearest_points, unary_union +from shapely.geometry import LineString from shapely.validation import explain_validity from skimage import draw, filters from skimage.filters import sobel @@ -40,6 +41,8 @@ from skimage.morphology import skeletonize from skimage.transform import (AffineTransform, PiecewiseAffineTransform, warp) +from scipy.special import comb + from kraken.lib import default_specs from kraken.lib.exceptions import KrakenInputException @@ -1385,3 +1388,57 @@ def extract_polygons(im: Image.Image, logger.error('bbox {} is outside of image bounds {}'.format(box, im.size)) raise KrakenInputException('Line outside of image bounds') yield im.crop(box).rotate(angle, expand=True), box + +### +# Bézier curve fitting +### + + +def Mtk(n, t, k): + return t**k * (1-t)**(n-k) * comb(n, k) + + +def BezierCoeff(ts): + return [[Mtk(3, t, k) for k in range(4)] for t in ts] + + +def bezier_fit(bl): + x = bl[:, 0] + y = bl[:, 1] + dy = y[1:] - y[:-1] + dx = x[1:] - x[:-1] + dt = (dx ** 2 + dy ** 2)**0.5 + t = dt/dt.sum() + t = np.hstack(([0], t)) + t = t.cumsum() + + Pseudoinverse = np.linalg.pinv(BezierCoeff(t)) # (9,4) -> (4,9) + + control_points = Pseudoinverse.dot(bl) # (4,9)*(9,2) -> (4,2) + medi_ctp = control_points[1:-1, :] + return medi_ctp + + +def to_curve(baseline: torch.FloatTensor, + im_size: Tuple[int, int], + min_points: int = 8) -> torch.FloatTensor: + """ + Fits a polyline as a quadratic Bézier curve. + + Args: + baseline: tensor of shape (S, 2) with coordinates in x, y format. + im_size: image size (W, H) used for control point normalization. + min_points: Minimal number of points in the baseline. If the input + baseline contains less than `min_points` additional points + will be interpolated at regular intervals along the line. + + Returns: + Tensor of shape (8,) + """ + baseline = np.array(baseline) + if len(baseline) < min_points: + ls = LineString(baseline) + baseline = np.stack([np.array(ls.interpolate(x, normalized=True).coords)[0] for x in np.linspace(0, 1, 8)]) + curve = np.concatenate(([baseline[0]], bezier_fit(baseline), [baseline[-1]]))/im_size + curve = curve.flatten() + return torch.from_numpy(curve) diff --git a/kraken/lib/train/segmentation.py b/kraken/lib/train/segmentation.py index cfaa3e2fa..7cd1f4031 100644 --- a/kraken/lib/train/segmentation.py +++ b/kraken/lib/train/segmentation.py @@ -28,6 +28,7 @@ Union) from torch.optim import lr_scheduler +from scipy.optimize import linear_sum_assignment from torch.utils.data import DataLoader, Subset, random_split from lightning.pytorch.callbacks import EarlyStopping from torchmetrics.classification import (MultilabelAccuracy, @@ -39,7 +40,7 @@ from kraken.lib.xml import XMLPage from kraken.lib.models import validate_hyper_parameters -from kraken.lib.segmentation import vectorize_lines +from kraken.lib.segmentation import vectorize_lines, to_curve from .utils import _configure_optimizer_and_lr_scheduler @@ -194,7 +195,8 @@ def __init__(self, valid_baselines=valid_baselines, merge_baselines=merge_baselines, valid_regions=valid_regions, - merge_regions=merge_regions) + merge_regions=merge_regions, + return_curves=True) for page in training_data: train_set.add(page) @@ -206,7 +208,8 @@ def __init__(self, valid_baselines=valid_baselines, merge_baselines=merge_baselines, valid_regions=valid_regions, - merge_regions=merge_regions) + merge_regions=merge_regions, + return_curves=True) for page in evaluation_data: val_set.add(page) @@ -246,7 +249,7 @@ def training_step(self, batch, batch_idx): return loss def validation_step(self, batch, batch_idx): - x, y = batch['image'], batch['target'] + x, y, y_curves = batch['image'], batch['target'], batch['curves'] pred, _ = self.nn.nn(x) # scale target to output size y = F.interpolate(y, size=(pred.size(2), pred.size(3)), mode='nearest').int() @@ -258,12 +261,17 @@ def validation_step(self, batch, batch_idx): self.val_region_mean_accuracy.update(pred_reg, y_reg) self.val_region_mean_iu.update(pred_reg, y_reg) self.val_region_freq_iu.update(pred_reg, y_reg) - # vectorize lines st_sep = self.nn.user_metadata['class_mapping']['aux']['_start_separator'] end_sep = self.nn.user_metadata['class_mapping']['aux']['_end_separator'] - line_idxs = sorted(self.nn.user_metadata['class_mapping']['lines'].values()) - for line_idx in line_idxs: - pred_bl = vectorize_lines(pred[:, [st_sep, end_sep, line_idx], ...], text_direction='horizontal') + + + # vectorize and match lines + for line_cls, line_idx in self.nn.user_metadata['class_mapping']['lines'].items(): + pred_curves = torch.stack([to_curve(pred_bl, pred.shape[:2][-1]) for pred_bl in vectorize_lines(pred[:, [st_sep, end_sep, line_idx], ...], + text_direction='horizontal')]) + cost_curves = torch.cdist(pred_curves, y_curves[line_cls], p=1).view(len(pred_curves), -1).cpu() + row_ind, col_ind = linear_sum_assignment(cost_curves) + self.val_line_dist.update(cost_curves[row_ind, col_ind]) def on_validation_epoch_end(self): if not self.trainer.sanity_checking: @@ -271,6 +279,7 @@ def on_validation_epoch_end(self): mean_accuracy = self.val_region_mean_accuracy.compute() mean_iu = self.val_region_mean_iu.compute() freq_iu = self.val_region_freq_iu.compute() + mean_line_dist = self.val_line_dist.compute() if mean_iu > self.best_metric: logger.debug(f'Updating best region metric from {self.best_metric} ({self.best_epoch}) to {mean_iu} ({self.current_epoch})') @@ -283,6 +292,7 @@ def on_validation_epoch_end(self): self.log('val_region_mean_acc', mean_accuracy, on_step=False, on_epoch=True, prog_bar=True, logger=True) self.log('val_region_mean_iu', mean_iu, on_step=False, on_epoch=True, prog_bar=True, logger=True) self.log('val_region_freq_iu', freq_iu, on_step=False, on_epoch=True, prog_bar=True, logger=True) + self.log('val_mean_line_dist', mean_line_dist, on_step=False, on_epoch=True, prog_bar=True, logger=True) self.log('val_metric', mean_iu, on_step=False, on_epoch=True, prog_bar=False, logger=True) # reset metrics even if sanity checking