Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better metrics for segmentation training #645

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion kraken/ketos/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ def segtest(ctx, model, evaluation_files, device, workers, threads, threshold,
from torch.utils.data import DataLoader

from kraken.lib.progress import KrakenProgressBar
from kraken.lib.train import BaselineSet, ImageInputTransforms
from kraken.lib.dataset import BaselineSet, ImageInputTransforms
from kraken.lib.vgsl import TorchVGSLModel

logger.info('Building test set from {} documents'.format(len(test_set) + len(evaluation_files)))
Expand Down
64 changes: 34 additions & 30 deletions kraken/lib/dataset/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from torch.utils.data import Dataset
from torchvision import transforms

from kraken.lib.segmentation import scale_regions
from kraken.lib.segmentation import scale_regions, to_curve

if TYPE_CHECKING:
from kraken.containers import Segmentation
Expand All @@ -46,6 +46,25 @@
class BaselineSet(Dataset):
"""
Dataset for training a baseline/region segmentation model.

Args:
line_width: Height of the baseline in the scaled input.
padding: Tuple of ints containing the left/right, top/bottom
padding of the input images.
target_size: Target size of the image as a (height, width) tuple.
augmentation: Enable/disable augmentation.
valid_baselines: Sequence of valid baseline identifiers. If `None`
all are valid.
merge_baselines: Sequence of baseline identifiers to merge. Note
that merging occurs after entities not in valid_*
have been discarded.
valid_regions: Sequence of valid region identifiers. If `None` all
are valid.
merge_regions: Sequence of region identifiers to merge. Note that
merging occurs after entities not in valid_* have
been discarded.
return_curves: Whether to return fitted Bézier curves in addition to
the pixel heatmaps. Used during validation.
"""
def __init__(self,
line_width: int = 4,
Expand All @@ -55,27 +74,8 @@ def __init__(self,
valid_baselines: Sequence[str] = None,
merge_baselines: Dict[str, Sequence[str]] = None,
valid_regions: Sequence[str] = None,
merge_regions: Dict[str, Sequence[str]] = None):
"""
Creates a dataset for a text-line and region segmentation model.

Args:
line_width: Height of the baseline in the scaled input.
padding: Tuple of ints containing the left/right, top/bottom
padding of the input images.
target_size: Target size of the image as a (height, width) tuple.
augmentation: Enable/disable augmentation.
valid_baselines: Sequence of valid baseline identifiers. If `None`
all are valid.
merge_baselines: Sequence of baseline identifiers to merge. Note
that merging occurs after entities not in valid_*
have been discarded.
valid_regions: Sequence of valid region identifiers. If `None` all
are valid.
merge_regions: Sequence of region identifiers to merge. Note that
merging occurs after entities not in valid_* have
been discarded.
"""
merge_regions: Dict[str, Sequence[str]] = None,
return_curves: bool = False):
super().__init__()
self.imgs = []
self.im_mode = '1'
Expand All @@ -91,6 +91,7 @@ def __init__(self,
self.mreg_dict = merge_regions if merge_regions is not None else {}
self.valid_baselines = valid_baselines
self.valid_regions = valid_regions
self.return_curves = return_curves

self.aug = None
if augmentation:
Expand Down Expand Up @@ -162,16 +163,14 @@ def __getitem__(self, idx):
try:
logger.debug(f'Attempting to load {im}')
im = Image.open(im)
im, target = self.transform(im, target)
return {'image': im, 'target': target}
return self.transform(im, target)
except Exception:
self.failed_samples.add(idx)
idx = np.random.randint(0, len(self.imgs))
logger.debug(traceback.format_exc())
logger.info(f'Failed. Replacing with sample {idx}')
return self[idx]
im, target = self.transform(im, target)
return {'image': im, 'target': target}
return self.transform(im, target)

@staticmethod
def _get_ortho_line(lineseg, point, line_width, offset):
Expand All @@ -194,6 +193,7 @@ def transform(self, image, target):
start_sep_cls = self.class_mapping['aux']['_start_separator']
end_sep_cls = self.class_mapping['aux']['_end_separator']

curves = defaultdict(list)
for key, lines in target['baselines'].items():
try:
cls_idx = self.class_mapping['baselines'][key]
Expand All @@ -202,9 +202,8 @@ def transform(self, image, target):
continue
for line in lines:
# buffer out line to desired width
line = [k for k, g in groupby(line)]
line = np.array(line)*scale
shp_line = geom.LineString(line)
line = np.array([k for k, g in groupby(line)], dtype=np.float32)
shp_line = geom.LineString(line*scale)
split_offset = min(5, shp_line.length/2)
line_pol = np.array(shp_line.buffer(self.line_width/2, cap_style=2).boundary.coords, dtype=int)
rr, cc = polygon(line_pol[:, 1], line_pol[:, 0], shape=image.shape[1:])
Expand All @@ -223,6 +222,11 @@ def transform(self, image, target):
rr_s, cc_s = polygon(end_sep[:, 1], end_sep[:, 0], shape=image.shape[1:])
t[end_sep_cls, rr_s, cc_s] = 1
t[end_sep_cls, rr, cc] = 0
# Bézier curve fitting
if self.return_curves:
curves[key].append(to_curve(torch.from_numpy(line), orig_size))
for k, v in curves.items():
curves[k] = torch.stack(v)
for key, regions in target['regions'].items():
try:
cls_idx = self.class_mapping['regions'][key]
Expand All @@ -240,7 +244,7 @@ def transform(self, image, target):
o = self.aug(image=image, mask=target)
image = torch.tensor(o['image']).permute(2, 0, 1)
target = torch.tensor(o['mask']).permute(2, 0, 1)
return image, target
return {'image': image, 'target': target, 'curves': dict(curves) if self.return_curves else None}

def __len__(self):
return len(self.imgs)
59 changes: 59 additions & 0 deletions kraken/lib/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from scipy.signal import convolve2d
from scipy.spatial.distance import pdist, squareform
from shapely.ops import nearest_points, unary_union
from shapely.geometry import LineString
from shapely.validation import explain_validity
from skimage import draw, filters
from skimage.filters import sobel
Expand All @@ -40,6 +41,8 @@
from skimage.morphology import skeletonize
from skimage.transform import (AffineTransform, PiecewiseAffineTransform, warp)

from scipy.special import comb

from kraken.lib import default_specs
from kraken.lib.exceptions import KrakenInputException

Expand All @@ -49,6 +52,7 @@


_T_pil_or_np = TypeVar('_T_pil_or_np', Image.Image, np.ndarray)
_T_tensor_or_np = TypeVar('_T_tensor_or_np', torch.Tensor, np.ndarray)

logger = logging.getLogger('kraken')

Expand Down Expand Up @@ -1385,3 +1389,58 @@ def extract_polygons(im: Image.Image,
logger.error('bbox {} is outside of image bounds {}'.format(box, im.size))
raise KrakenInputException('Line outside of image bounds')
yield im.crop(box).rotate(angle, expand=True), box

###
# Bézier curve fitting
###


def Mtk(n, t, k):
    """
    Evaluates the k-th Bernstein basis polynomial of degree n at t.

    Args:
        n: Degree of the polynomial.
        t: Curve parameter at which the polynomial is evaluated.
        k: Index of the basis polynomial.

    Returns:
        The value t^k * (1 - t)^(n - k) * C(n, k).
    """
    rising = t**k
    falling = (1 - t)**(n - k)
    binomial = comb(n, k)
    return rising * falling * binomial


def BezierCoeff(ts):
    """
    Builds the matrix of cubic Bernstein basis values for a set of parameters.

    Args:
        ts: Iterable of curve parameters.

    Returns:
        A list with one row per parameter in `ts`; each row holds the four
        degree-3 basis values `Mtk(3, t, k)` for k = 0..3.
    """
    degree = 3
    orders = range(degree + 1)
    rows = []
    for t in ts:
        rows.append([Mtk(degree, t, k) for k in orders])
    return rows


def bezier_fit(bl):
    """
    Fits the inner control points of a cubic Bézier curve to a polyline.

    Each point is assigned a curve parameter equal to its cumulative
    segment length divided by the total polyline length, and the control
    points are obtained through the pseudoinverse of the Bernstein matrix
    built by `BezierCoeff`.

    Args:
        bl: Array of shape (N, 2) with one x, y point per row.

    Returns:
        Array of shape (2, 2) containing the second and third control
        points; the first and last fitted control points are dropped.
    """
    # per-segment Euclidean lengths between consecutive points
    steps = bl[1:] - bl[:-1]
    seg_len = (steps[:, 0] ** 2 + steps[:, 1] ** 2)**0.5

    # parameter values: 0 for the first point, then the running sum of the
    # normalized segment lengths
    params = np.concatenate(([0], seg_len/seg_len.sum())).cumsum()

    # (N, 4) Bernstein matrix -> (4, N) pseudoinverse
    pinv = np.linalg.pinv(BezierCoeff(params))

    # (4, N) x (N, 2) -> (4, 2) control points
    ctrl = pinv.dot(bl)
    return ctrl[1:-1, :]


def to_curve(baseline: torch.FloatTensor,
             im_size: Tuple[int, int],
             min_points: int = 8) -> torch.FloatTensor:
    """
    Fits a polyline as a cubic Bézier curve and normalizes its control points.

    Args:
        baseline: tensor of shape (S, 2) with coordinates in x, y format.
        im_size: image size (W, H) used for control point normalization.
        min_points: Minimal number of points in the baseline. If the input
                    baseline contains fewer than `min_points` points it is
                    resampled to `min_points` points placed at regular
                    intervals along the line.

    Returns:
        Tensor of shape (8,) containing the four control points (start
        point, two fitted inner control points, end point) as flattened
        x, y pairs, each divided by the matching entry of `im_size`.
    """
    # Work on a numpy array from the start so that shapely and the fitting
    # code never have to deal with a torch tensor.
    points = baseline.numpy()
    if len(points) < min_points:
        ls = LineString(points)
        # Resample to `min_points` (not a hard-coded count) so the argument
        # is actually honoured; keep the dtype of the input.
        points = np.array([ls.interpolate(frac, normalized=True).coords[0]
                           for frac in np.linspace(0, 1, min_points)],
                          dtype=points.dtype)
    # The end points of the polyline are used verbatim as the outer control
    # points; only the two inner ones come from the fit.
    curve = np.concatenate(([points[0]], bezier_fit(points), [points[-1]]))
    curve = (curve/im_size).flatten()
    return torch.from_numpy(curve.astype(points.dtype))
Loading