Skip to content

Commit

Permalink
wip line vectorization/matching/distance in validation
Browse files Browse the repository at this point in the history
  • Loading branch information
mittagessen committed Sep 27, 2024
1 parent 8ff556e commit 05fd70e
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 38 deletions.
62 changes: 32 additions & 30 deletions kraken/lib/dataset/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from torch.utils.data import Dataset
from torchvision import transforms

from kraken.lib.segmentation import scale_regions
from kraken.lib.segmentation import scale_regions, to_curve

if TYPE_CHECKING:
from kraken.containers import Segmentation
Expand All @@ -46,6 +46,25 @@
class BaselineSet(Dataset):
"""
Dataset for training a baseline/region segmentation model.
Args:
line_width: Height of the baseline in the scaled input.
padding: Tuple of ints containing the left/right, top/bottom
padding of the input images.
target_size: Target size of the image as a (height, width) tuple.
augmentation: Enable/disable augmentation.
valid_baselines: Sequence of valid baseline identifiers. If `None`
all are valid.
merge_baselines: Sequence of baseline identifiers to merge. Note
that merging occurs after entities not in valid_*
have been discarded.
valid_regions: Sequence of valid region identifiers. If `None` all
are valid.
merge_regions: Sequence of region identifiers to merge. Note that
merging occurs after entities not in valid_* have
been discarded.
return_curves: Whether to return fitted Bézier curves in addition to
the pixel heatmaps. Used during validation.
"""
def __init__(self,
line_width: int = 4,
Expand All @@ -55,27 +74,8 @@ def __init__(self,
valid_baselines: Sequence[str] = None,
merge_baselines: Dict[str, Sequence[str]] = None,
valid_regions: Sequence[str] = None,
merge_regions: Dict[str, Sequence[str]] = None):
"""
Creates a dataset for a text-line and region segmentation model.
Args:
line_width: Height of the baseline in the scaled input.
padding: Tuple of ints containing the left/right, top/bottom
padding of the input images.
target_size: Target size of the image as a (height, width) tuple.
augmentation: Enable/disable augmentation.
valid_baselines: Sequence of valid baseline identifiers. If `None`
all are valid.
merge_baselines: Sequence of baseline identifiers to merge. Note
that merging occurs after entities not in valid_*
have been discarded.
valid_regions: Sequence of valid region identifiers. If `None` all
are valid.
merge_regions: Sequence of region identifiers to merge. Note that
merging occurs after entities not in valid_* have
been discarded.
"""
merge_regions: Dict[str, Sequence[str]] = None,
return_curves: bool = False):
super().__init__()
self.imgs = []
self.im_mode = '1'
Expand All @@ -91,6 +91,7 @@ def __init__(self,
self.mreg_dict = merge_regions if merge_regions is not None else {}
self.valid_baselines = valid_baselines
self.valid_regions = valid_regions
self.return_curves = return_curves

self.aug = None
if augmentation:
Expand Down Expand Up @@ -162,16 +163,14 @@ def __getitem__(self, idx):
try:
logger.debug(f'Attempting to load {im}')
im = Image.open(im)
im, target = self.transform(im, target)
return {'image': im, 'target': target}
return self.transform(im, target)
except Exception:
self.failed_samples.add(idx)
idx = np.random.randint(0, len(self.imgs))
logger.debug(traceback.format_exc())
logger.info(f'Failed. Replacing with sample {idx}')
return self[idx]
im, target = self.transform(im, target)
return {'image': im, 'target': target}
return self.transform(im, target)

@staticmethod
def _get_ortho_line(lineseg, point, line_width, offset):
Expand All @@ -194,6 +193,7 @@ def transform(self, image, target):
start_sep_cls = self.class_mapping['aux']['_start_separator']
end_sep_cls = self.class_mapping['aux']['_end_separator']

curves = defaultdict(list)
for key, lines in target['baselines'].items():
try:
cls_idx = self.class_mapping['baselines'][key]
Expand All @@ -202,9 +202,8 @@ def transform(self, image, target):
continue
for line in lines:
# buffer out line to desired width
line = [k for k, g in groupby(line)]
line = np.array(line)*scale
shp_line = geom.LineString(line)
line = np.array([k for k, g in groupby(line)])
shp_line = geom.LineString(line*scale)
split_offset = min(5, shp_line.length/2)
line_pol = np.array(shp_line.buffer(self.line_width/2, cap_style=2).boundary.coords, dtype=int)
rr, cc = polygon(line_pol[:, 1], line_pol[:, 0], shape=image.shape[1:])
Expand All @@ -223,6 +222,9 @@ def transform(self, image, target):
rr_s, cc_s = polygon(end_sep[:, 1], end_sep[:, 0], shape=image.shape[1:])
t[end_sep_cls, rr_s, cc_s] = 1
t[end_sep_cls, rr, cc] = 0
# Bézier curve fitting
if self.return_curves:
curves[key].append(to_curve(line, orig_size))
for key, regions in target['regions'].items():
try:
cls_idx = self.class_mapping['regions'][key]
Expand All @@ -240,7 +242,7 @@ def transform(self, image, target):
o = self.aug(image=image, mask=target)
image = torch.tensor(o['image']).permute(2, 0, 1)
target = torch.tensor(o['mask']).permute(2, 0, 1)
return image, target
return {'image': image, 'target': target, 'curves': dict(curves) if self.return_curves else None}

def __len__(self):
return len(self.imgs)
57 changes: 57 additions & 0 deletions kraken/lib/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from scipy.signal import convolve2d
from scipy.spatial.distance import pdist, squareform
from shapely.ops import nearest_points, unary_union
from shapely.geometry import LineString
from shapely.validation import explain_validity
from skimage import draw, filters
from skimage.filters import sobel
Expand All @@ -40,6 +41,8 @@
from skimage.morphology import skeletonize
from skimage.transform import (AffineTransform, PiecewiseAffineTransform, warp)

from scipy.special import comb

from kraken.lib import default_specs
from kraken.lib.exceptions import KrakenInputException

Expand Down Expand Up @@ -1385,3 +1388,57 @@ def extract_polygons(im: Image.Image,
logger.error('bbox {} is outside of image bounds {}'.format(box, im.size))
raise KrakenInputException('Line outside of image bounds')
yield im.crop(box).rotate(angle, expand=True), box

###
# Bézier curve fitting
###


def Mtk(n, t, k):
return t**k * (1-t)**(n-k) * comb(n, k)


def BezierCoeff(ts):
return [[Mtk(3, t, k) for k in range(4)] for t in ts]


def bezier_fit(bl):
x = bl[:, 0]
y = bl[:, 1]
dy = y[1:] - y[:-1]
dx = x[1:] - x[:-1]
dt = (dx ** 2 + dy ** 2)**0.5
t = dt/dt.sum()
t = np.hstack(([0], t))
t = t.cumsum()

Pseudoinverse = np.linalg.pinv(BezierCoeff(t)) # (9,4) -> (4,9)

control_points = Pseudoinverse.dot(bl) # (4,9)*(9,2) -> (4,2)
medi_ctp = control_points[1:-1, :]
return medi_ctp


def to_curve(baseline: torch.FloatTensor,
im_size: Tuple[int, int],
min_points: int = 8) -> torch.FloatTensor:
"""
Fits a polyline as a quadratic Bézier curve.
Args:
baseline: tensor of shape (S, 2) with coordinates in x, y format.
im_size: image size (W, H) used for control point normalization.
min_points: Minimal number of points in the baseline. If the input
baseline contains less than `min_points` additional points
will be interpolated at regular intervals along the line.
Returns:
Tensor of shape (8,)
"""
baseline = np.array(baseline)
if len(baseline) < min_points:
ls = LineString(baseline)
baseline = np.stack([np.array(ls.interpolate(x, normalized=True).coords)[0] for x in np.linspace(0, 1, 8)])
curve = np.concatenate(([baseline[0]], bezier_fit(baseline), [baseline[-1]]))/im_size
curve = curve.flatten()
return torch.from_numpy(curve)
26 changes: 18 additions & 8 deletions kraken/lib/train/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
Union)

from torch.optim import lr_scheduler
from scipy.optimize import linear_sum_assignment
from torch.utils.data import DataLoader, Subset, random_split
from lightning.pytorch.callbacks import EarlyStopping
from torchmetrics.classification import (MultilabelAccuracy,
Expand All @@ -39,7 +40,7 @@

from kraken.lib.xml import XMLPage
from kraken.lib.models import validate_hyper_parameters
from kraken.lib.segmentation import vectorize_lines
from kraken.lib.segmentation import vectorize_lines, to_curve

from .utils import _configure_optimizer_and_lr_scheduler

Expand Down Expand Up @@ -194,7 +195,8 @@ def __init__(self,
valid_baselines=valid_baselines,
merge_baselines=merge_baselines,
valid_regions=valid_regions,
merge_regions=merge_regions)
merge_regions=merge_regions,
return_curves=True)

for page in training_data:
train_set.add(page)
Expand All @@ -206,7 +208,8 @@ def __init__(self,
valid_baselines=valid_baselines,
merge_baselines=merge_baselines,
valid_regions=valid_regions,
merge_regions=merge_regions)
merge_regions=merge_regions,
return_curves=True)

for page in evaluation_data:
val_set.add(page)
Expand Down Expand Up @@ -246,7 +249,7 @@ def training_step(self, batch, batch_idx):
return loss

def validation_step(self, batch, batch_idx):
x, y = batch['image'], batch['target']
x, y, y_curves = batch['image'], batch['target'], batch['curves']
pred, _ = self.nn.nn(x)
# scale target to output size
y = F.interpolate(y, size=(pred.size(2), pred.size(3)), mode='nearest').int()
Expand All @@ -258,19 +261,25 @@ def validation_step(self, batch, batch_idx):
self.val_region_mean_accuracy.update(pred_reg, y_reg)
self.val_region_mean_iu.update(pred_reg, y_reg)
self.val_region_freq_iu.update(pred_reg, y_reg)
# vectorize lines
st_sep = self.nn.user_metadata['class_mapping']['aux']['_start_separator']
end_sep = self.nn.user_metadata['class_mapping']['aux']['_end_separator']
line_idxs = sorted(self.nn.user_metadata['class_mapping']['lines'].values())
for line_idx in line_idxs:
pred_bl = vectorize_lines(pred[:, [st_sep, end_sep, line_idx], ...], text_direction='horizontal')


# vectorize and match lines
for line_cls, line_idx in self.nn.user_metadata['class_mapping']['lines'].items():
pred_curves = torch.stack([to_curve(pred_bl, pred.shape[:2][-1]) for pred_bl in vectorize_lines(pred[:, [st_sep, end_sep, line_idx], ...],
text_direction='horizontal')])
cost_curves = torch.cdist(pred_curves, y_curves[line_cls], p=1).view(len(pred_curves), -1).cpu()
row_ind, col_ind = linear_sum_assignment(cost_curves)
self.val_line_dist.update(cost_curves[row_ind, col_ind])

def on_validation_epoch_end(self):
if not self.trainer.sanity_checking:
pixel_accuracy = self.val_region_px_accuracy.compute()
mean_accuracy = self.val_region_mean_accuracy.compute()
mean_iu = self.val_region_mean_iu.compute()
freq_iu = self.val_region_freq_iu.compute()
mean_line_dist = self.val_line_dist.compute()

if mean_iu > self.best_metric:
logger.debug(f'Updating best region metric from {self.best_metric} ({self.best_epoch}) to {mean_iu} ({self.current_epoch})')
Expand All @@ -283,6 +292,7 @@ def on_validation_epoch_end(self):
self.log('val_region_mean_acc', mean_accuracy, on_step=False, on_epoch=True, prog_bar=True, logger=True)
self.log('val_region_mean_iu', mean_iu, on_step=False, on_epoch=True, prog_bar=True, logger=True)
self.log('val_region_freq_iu', freq_iu, on_step=False, on_epoch=True, prog_bar=True, logger=True)
self.log('val_mean_line_dist', mean_line_dist, on_step=False, on_epoch=True, prog_bar=True, logger=True)
self.log('val_metric', mean_iu, on_step=False, on_epoch=True, prog_bar=False, logger=True)

# reset metrics even if sanity checking
Expand Down

0 comments on commit 05fd70e

Please sign in to comment.