Skip to content

Commit

Permalink
patch in sentencepiece codec in training
Browse files Browse the repository at this point in the history
  • Loading branch information
mittagessen committed Apr 5, 2024
1 parent d539fde commit 8442261
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 14 deletions.
117 changes: 115 additions & 2 deletions kraken/lib/codec.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,19 @@
"""
import logging
from collections import Counter
from typing import Dict, List, Sequence, Set, Tuple, Union
from typing import Dict, List, Sequence, Set, Tuple, Union, Optional, Iterator, TYPE_CHECKING

import io
import numpy as np
import sentencepiece as spm
from torch import IntTensor

from kraken.lib.exceptions import KrakenCodecException, KrakenEncodeException

__all__ = ['PytorchCodec']
if TYPE_CHECKING:
from os import PathLike

__all__ = ['PytorchCodec', 'SentencePieceCodec']

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -263,3 +268,111 @@ def add_labels(self, charset: Union[Dict[str, Sequence[int]], Sequence[str], str

def __repr__(self):
return f'PytorchCodec({self.c2l})'


class SentencePieceCodec(object):
"""
Builds a codec converting between code point and integer label sequences
using the SentencePiece algorithm.
The `model` and `sentences` argument are mutually exclusive.
Args:
model: path to sentencepiece model to load
sentences: Iterator of strings to use for training a sentencepiece
model
strict: Flag indicating if encoding/decoding errors should be ignored
or cause an exception.
"""
def __init__(self,
model: Optional[Union['PathLike', str]] = None,
sentences: Optional[Iterator[str]] = None,
strict: bool = False):
super().__init__()
if model and sentences:
raise ValueError('`model` and `sentences` arguments are mutually exclusive')

if model:
self.spp = spm.SentencePieceProcessor(model_file=model)
if sentences:
_model = io.BytesIO()
spm.SentencePieceTrainer.train(sentence_iterator=sentences,
model_writer=_model,
normalization_rule_name='identity',
remove_extra_whitespaces=False,
split_by_whitespace=False,
character_coverage=1.0)
self.spp = spm.SentencePieceProcessor(model_proto=_model.getvalue())

def __len__(self) -> int:
"""
Total number of input labels the codec can decode.
"""
return self.spp.vocab_size()

@property
def is_valid(self) -> bool:
"""
Returns True if the codec is prefix-free (in label space) and
non-singular (in both directions).
"""
return True

@property
def max_label(self) -> int:
"""
Returns the maximum label value.
"""
return self.spp.vocab_size() - 1

def encode(self, s: str) -> IntTensor:
"""
Encodes a string into a sequence of labels.
Args:
s: Input unicode string
Returns:
Encoded label sequence
Raises:
KrakenEncodeException: if the a subsequence is not encodable and the
codec is set to strict mode.
"""
labels = self.spp.encode(s)
if 0 in labels:
if self.strict:
raise KrakenEncodeException(f'Non-encodable sequence {s}. encountered.')
logger.warning(f'Non-encodable sequence {s} encountered.')
return IntTensor(labels)

def decode(self, labels: Sequence[Tuple[int, int, int, float]]) -> List[Tuple[str, int, int, float]]:
"""
Decodes a labelling.
Given a labelling with cuts and confidences returns a string with the
cuts and confidences aggregated across label-code point
correspondences. When decoding multilabels to code points the resulting
cuts are min/max, confidences are averaged.
Args:
labels: Input containing tuples (label, start, end,
confidence).
Returns:
A list of tuples (code point, start, end, confidence)
"""
proto = self.spp.decode_ids_as_immutable_proto([int(x[0]) for x in labels])
return [(piece.surface,) + label[1:] for piece, label in zip(proto.pieces, labels)]

def merge(self, codec) -> None:
"""
Not supported for Sentencepiece codecs
"""
raise ValueError('Merging of sentencepiece codecs is not supported.')

def add_labels(self, charset) -> None:
"""
Not supported for Sentencepiece codecs
"""
raise ValueError('Adding labels to sentencepiece codecs is not supported.')
9 changes: 6 additions & 3 deletions kraken/lib/dataset/recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

from kraken.containers import BaselineLine, BBoxLine, Segmentation
from kraken.lib import functional_im_transforms as F_t
from kraken.lib.codec import PytorchCodec
from kraken.lib.codec import PytorchCodec, SentencePieceCodec
from kraken.lib.exceptions import KrakenEncodeException, KrakenInputException
from kraken.lib.segmentation import extract_polygons
from kraken.lib.util import is_bitonal
Expand Down Expand Up @@ -231,7 +231,7 @@ def _apply_text_transform(self, sample) -> str:
raise KrakenInputException('empty text line')
return text

def encode(self, codec: Optional[PytorchCodec] = None) -> None:
def encode(self, codec: Optional[SentencePieceCodec] = None) -> None:
"""
Adds a codec to the dataset.
"""
Expand All @@ -249,7 +249,10 @@ def encode(self, codec: Optional[PytorchCodec] = None) -> None:
except KrakenInputException:
pass
else:
self.codec = PytorchCodec(''.join(self.alphabet.keys()))
def _iter_arrow_text():
for index in range(self._num_lines):
yield self._apply_text_transform(self.arrow_table.column('lines')[index].as_py())
self.codec = SentencePieceCodec(sentences=_iter_arrow_text())

def no_encode(self) -> None:
"""
Expand Down
12 changes: 3 additions & 9 deletions kraken/lib/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@

from kraken.containers import Segmentation
from kraken.lib import default_specs, models, progress, vgsl
from kraken.lib.codec import PytorchCodec
from kraken.lib.codec import PytorchCodec, SentencePieceCodec
from kraken.lib.dataset import (ArrowIPCRecognitionDataset, BaselineSet,
GroundTruthDataset, ImageInputTransforms,
PolygonGTDataset, collate_sequences)
Expand Down Expand Up @@ -615,14 +615,8 @@ def setup(self, stage: Optional[str] = None):
self.nn.init_weights()
self.nn.add_codec(self.train_set.dataset.codec)

val_diff = set(self.val_set.dataset.alphabet).difference(
set(self.train_set.dataset.codec.c2l.keys())
)
logger.info(f'Adding {len(val_diff)} dummy labels to validation set codec.')

val_codec = self.nn.codec.add_labels(val_diff)
self.val_set.dataset.encode(val_codec)
self.val_codec = val_codec
self.val_set.dataset.encode(self.train_set.dataset.codec)
self.val_codec = self.train_set.dataset.codec

if self.nn.one_channel_mode and self.train_set.dataset.im_mode != self.nn.one_channel_mode:
logger.warning(f'Neural network has been trained on mode {self.nn.one_channel_mode} images, '
Expand Down

0 comments on commit 8442261

Please sign in to comment.