From ba5444e576db3ed46e31e2dbc0d8bd13f420d617 Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Mon, 23 Jan 2023 12:40:08 +0100 Subject: [PATCH 01/68] partial implementation of new xml parser --- kraken/lib/xml.py | 434 +++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 430 insertions(+), 4 deletions(-) diff --git a/kraken/lib/xml.py b/kraken/lib/xml.py index 608cd7dc5..14a4bf450 100644 --- a/kraken/lib/xml.py +++ b/kraken/lib/xml.py @@ -22,7 +22,7 @@ from itertools import groupby from lxml import etree from PIL import Image -from typing import Union, Dict, Any, Sequence, Tuple +from typing import Union, Dict, Any, Sequence, Tuple, Literal from os import PathLike from collections import defaultdict @@ -52,8 +52,8 @@ # same for ALTO alto_regions = {'TextBlock': 'text', - 'IllustrationType': 'illustration', - 'GraphicalElementType': 'graphic', + 'Illustration': 'illustration', + 'GraphicalElement': 'graphic', 'ComposedBlock': 'composed'} @@ -423,7 +423,7 @@ def _parse_pointstype(coords: str) -> Sequence[Tuple[float, float]]: if boundary == page_boundary and rtype == 'text': logger.info('Skipping TextBlock with same size as page image.') continue - region_data[rtype].append(boundary) + region_data[rtype].append({'id': region.get('ID'), 'boundary': boundary}) data['regions'] = region_data tag_set = set(('default',)) @@ -476,3 +476,429 @@ def _parse_pointstype(coords: str) -> Sequence[Tuple[float, float]]: else: data['tags'] = False return data + +class XMLPage(object): + + type: Literal['baselines', 'bbox'] == 'baselines' + base_dir: Optional[Literal['L', 'R']] = None + imagename: pathlib.Path = None + _order: Dict[str, Dict[str, Any]] = None + has_tags: bool = False + _tag_set: Optional[Dict] = None + has_splits: bool = False + _split_set: Optional[List] = None + + def __init__(self, + filename: Union[str, pathlib.Path], + filetype: Literal['xml', 'alto', 'page'] = 'xml'): + super().__init__() + self.filename = Path(filename) + self.filetype = filetype + + if filetype == 'xml': + self._parse_xml() + elif filetype == 'alto': + self._parse_alto() + elif filetype == 'page': + self._parse_page() + + self._regions = {} + self._baselines = {} + self._orders = {'line_implicit': {'order': [], 'is_total': True, 'description': 'Implicit line order derived from element sequence'}, + 'region_implicit': {'order': [], 'is_total': True, 'description': 'Implicit region order derived from element sequence'}} + + def _parse_alto(self): + with open(self.filename, 'rb') as fp: + base_directory = self.filename.parent + try: + doc = etree.parse(fp) + except etree.XMLSyntaxError as e: + raise KrakenInputException('Parsing {} failed: {}'.format(self.filename, e)) + image = doc.find('.//{*}fileName') + if image is None or not image.text: + raise KrakenInputException('No valid image filename found in ALTO file {self.filename}') + self.imagename = base_directory.joinpath(image.text) + + lines = doc.findall('.//{*}TextLine') + # find all image regions in order + regions = [] + for el in doc.iterfind('./{*}Layout/{*}Page/{*}PrintSpace/{*}*'): + for block_type in alto_regions.keys(): + if el.tag.endswith(block_type): + regions.append(el) + # find overall dimensions to filter out dummy TextBlocks + ps = doc.find('./{*}Layout/{*}Page/{*}PrintSpace') + x_min = int(float(ps.get('HPOS'))) + y_min = int(float(ps.get('VPOS'))) + width = int(float(ps.get('WIDTH'))) + height = int(float(ps.get('HEIGHT'))) + page_boundary = [(x_min, y_min), + (x_min, y_min + height), + (x_min + width, y_min + height), + (x_min + width, y_min)] + + # parse tagrefs + cls_map = {} + tags = doc.find('.//{*}Tags') + if tags is not None: + for x in ['StructureTag', 'LayoutTag', 'OtherTag']: + for tag in tags.findall('./{{*}}{}'.format(x)): + cls_map[tag.get('ID')] = (x[:-3].lower(), tag.get('LABEL')) + # parse region type and coords + region_data = defaultdict(list) + for region in regions: + # try to find shape object + coords = region.find('./{*}Shape/{*}Polygon') + if coords is not None: + boundary = _parse_pointstype(coords.get('POINTS')) + elif (region.get('HPOS') is not None and region.get('VPOS') is not None and + region.get('WIDTH') is not None and region.get('HEIGHT') is not None): + # use rectangular definition + x_min = int(float(region.get('HPOS'))) + y_min = int(float(region.get('VPOS'))) + width = int(float(region.get('WIDTH'))) + height = int(float(region.get('HEIGHT'))) + boundary = [(x_min, y_min), + (x_min, y_min + height), + (x_min + width, y_min + height), + (x_min + width, y_min)] + else: + continue + rtype = region.get('TYPE') + # fall back to default region type if nothing is given + tagrefs = region.get('TAGREFS') + if tagrefs is not None and rtype is None: + for tagref in tagrefs.split(): + ttype, rtype = cls_map.get(tagref, (None, None)) + if rtype is not None and ttype: + break + if rtype is None: + rtype = alto_regions[region.tag.split('}')[-1]] + if boundary == page_boundary and rtype == 'text': + logger.info('Skipping TextBlock with same size as page image.') + continue + region_data[rtype].append({'id': region.get('ID'), 'boundary': boundary}) + # register implicit reading order + self._orders['region_implicit']['order'].append(region.get('ID')) + self.regions = region_data + + self._tag_set = set(('default',)) + self.lines = [] + for line in lines: + if line.get('BASELINE') is None: + logger.info('TextLine {} without baseline'.format(line.get('ID'))) + continue + pol = line.find('./{*}Shape/{*}Polygon') + boundary = None + if pol is not None: + try: + boundary = self._parse_alto_pointstype(pol.get('POINTS')) + except ValueError: + logger.info('TextLine {} without polygon'.format(line.get('ID'))) + else: + logger.info('TextLine {} without polygon'.format(line.get('ID'))) + + baseline = None + try: + baseline = self._parse_alto_pointstype(line.get('BASELINE')) + except ValueError: + logger.info('TextLine {} without baseline'.format(line.get('ID'))) + + text = '' + for el in line.xpath(".//*[local-name() = 'String'] | .//*[local-name() = 'SP']"): + text += el.get('CONTENT') if el.get('CONTENT') else ' ' + # find line type + tags = {'type': 'default'} + split_type = None + tagrefs = line.get('TAGREFS') + if tagrefs is not None: + for tagref in tagrefs.split(): + ttype, ltype = cls_map.get(tagref, (None, None)) + if ltype is not None: + self._tag_set.add(ltype) + if ttype == 'other': + tags['type'] = ltype + else: + tags[ttype] = ltype + if ltype in ['train', 'validation', 'test']: + split_type = ltype + self.lines.append({'id': line.get('ID'), + 'baseline': baseline, + 'boundary': boundary, + 'text': text, + 'tags': tags, + 'split': split_type}) + # register implicit reading order + self._orders['line_implicit']['order'].append(line.get('ID')) + + if len(self._tag_set) > 1: + self.has_tags = True + else: + self.has_tags = False + + # parse explicit reading orders if they exist + ro_el = doc.find('.//{*}ReadingOrder') + if ro_el is not None: + reading_orders = ro_el.getchildren() + # UnorderedGroup at top-level => treated as multiple reading orders + if len(reading_orders) == 1 and reading_orders[0].tag.endswith('UnorderedGroup'): + reading_orders = reading_orders.getchildren() + else: + reading_orders = [reading_orders] + def _parse_group(el): + _ro = [] + if el.tag.endswith('UnorderedGroup'): + _ro.append([_parse_group(x) for x in el.iterchildren()]) + is_total = False + elif el.tag.endswith('OrderedGroup'): + _ro.extend(_parse_group(x) for x in el.iterchildren()) + else: + return el.get('REF') + return _ro + + for ro in reading_orders: + is_total = True + joint_order = _parse_group(ro) + tag = ro.get('TAGREFS') + self._orders[ro.get('ID')] = {'order': joint_order, + 'is_total': is_total, + 'description': cls_map[tag] if tag and tag in cls_map else ''} + + def _parse_page(self): + with open(self.filename, 'rb') as fp: + base_directory = self.filename.parent + + try: + doc = etree.parse(fp) + except etree.XMLSyntaxError as e: + raise KrakenInputException('Parsing {} failed: {}'.format(self.filename, e)) + image = doc.find('.//{*}Page') + if image is None or image.get('imageFilename') is None: + raise KrakenInputException('No valid image filename found in PageXML file {}'.format(self.filename)) + try: + self.base_dir = {'left-to-right': 'L', + 'right-to-left': 'R', + 'top-to-bottom': 'L', + 'bottom-to-top': 'R', + None: None}[image.get('readingDirection')] + except KeyError: + logger.warning(f'Invalid value {image.get("readingDirection")} encountered in page-level reading direction.') + lines = doc.findall('.//{*}TextLine') + self.imagename = base_dir.joinpath(image.get('imageFilename')) + # find all image regions + regions = [] + for x in page_regions.keys(): + regions.extend(doc.findall('.//{{*}}{}'.format(x))) + # parse region type and coords + region_data = defaultdict(list) + tr_region_order = [] + for region in regions: + coords = region.find('{*}Coords') + if coords is not None and not coords.get('points').isspace() and len(coords.get('points')): + try: + coords = _parse_coords(coords.get('points')) + except Exception: + logger.warning('Region {} without coordinates'.format(region.get('id'))) + continue + else: + logger.warning('Region {} without coordinates'.format(region.get('id'))) + continue + rtype = region.get('type') + # parse transkribus-style custom field if possible + custom_str = region.get('custom') + if not rtype and custom_str: + cs = _parse_page_custom(custom_str) + if 'structure' in cs and 'type' in cs['structure']: + rtype = cs['structure']['type'] + # transkribus-style reading order + if 'readingOrder' in cs and 'index'in cs['readingOrder']: + tr_region_order.append((region.get('id'), int(cs['readingOrder']['index']))) + # fall back to default region type if nothing is given + if not rtype: + rtype = page_regions[region.tag.split('}')[-1]] + region_data[rtype].append({'id': region.get('id'), 'boundary': coords}) + # register implicit reading order + self._orders['region_implicit']['order'].append(region.get('id')) + # add transkribus-style region order + self._order['region_transkribus'] = {'order': [x[1] for x in sorted(tr_region_order, key=lambda k: k[0])], + 'is_total': True if len(set(map(lambda x: x[0], tr_region_order))) == len(tr_region_order) else False, + 'description': 'Explicit region order from `custom` attribute'} + + self.regions = region_data + + # parse line information + self._tag_set = set(('default',)) + tmp_transkribus_line_order = defaultdict(list) + valid_tr_lo = True + for line in lines: + pol = line.find('./{*}Coords') + boundary = None + if pol is not None and not pol.get('points').isspace() and len(pol.get('points')): + try: + boundary = self._parse_coords(pol.get('points')) + except Exception: + logger.info('TextLine {} without polygon'.format(line.get('id'))) + else: + logger.info('TextLine {} without polygon'.format(line.get('id'))) + base = line.find('./{*}Baseline') + baseline = None + if base is not None and not base.get('points').isspace() and len(base.get('points')): + try: + baseline = self._parse_coords(base.get('points')) + except Exception: + logger.info('TextLine {} without baseline'.format(line.get('id'))) + continue + else: + logger.info('TextLine {} without baseline'.format(line.get('id'))) + continue + text = '' + manual_transcription = line.find('./{*}TextEquiv') + if manual_transcription is not None: + transcription = manual_transcription + else: + transcription = line + for el in transcription.findall('.//{*}Unicode'): + if el.text: + text += el.text + # retrieve line tags if custom string is set and contains + tags = {'type': 'default'} + split_type = None + custom_str = line.get('custom') + if custom_str: + cs = _parse_page_custom(custom_str) + if 'structure' in cs and 'type' in cs['structure']: + tags['type'] = cs['structure']['type'] + self._tag_set.add(tags['type']) + # retrieve data split if encoded in custom string. + if 'split' in cs and 'type' in cs['split'] and cs['split']['type'] in ['train', 'validation', 'test']: + split_type = cs['split']['type'] + tags['split'] = split_type + self._tag_set.add(split_type) + if 'readingOrder' in cs and 'index' in cs['readingOrder']: + # look up region index from parent + reg_cus = _parse_page_custom(line.getparent().get('custom')) + if 'readingOrder' not in reg_cus or 'index' not in reg_cus['readingOrder']: + logger.warning('Incomplete `custom` attribute reading order found.') + valid_tr_lo = False + else: + tmp_transkribus_line_order[int(reg_cus['readingOrder']['index'])].append((int(cs['readingOrder']['index']), line.get('id'))) + + self.lines.append({'id': line.get('id'), + 'baseline': baseline, + 'boundary': boundary, + 'text': text, + 'split': split_type, + 'tags': tags}) + # register implicit reading order + self._orders['line_implicit']['order'].append(line.get('id')) + if tmp_transkribus_line_order: + # sort by regions + tmp_reg_order = sorted(((k, v) for k, v in tmp_transkribus_line_order.items()), key=lambda k: k[0]) + # flatten + tr_line_order = [] + for _, lines in tmp_reg_order: + tr_line_order.extend([x[1] for x in sorted(lines, key=lambda k: k[0])]) + self._order['line_transkribus'] = {'order': tr_line_order, + 'is_total': True, + 'description': 'Explicit line order from `custom` attribute'} + + # parse explicit reading orders if they exist + ro_el = doc.find('.//{*}ReadingOrder') + if ro_el is not None: + reading_orders = ro_el.getchildren() + # UnorderedGroup at top-level => treated as multiple reading orders + if len(reading_orders) == 1 and reading_orders[0].tag.endswith('UnorderedGroup'): + reading_orders = reading_orders.getchildren() + else: + reading_orders = [reading_orders] + def _parse_group(el): + _ro = [] + if el.tag.endswith('UnorderedGroup'): + _ro.append([_parse_group(x) for x in el.iterchildren()]) + is_total = False + elif el.tag.endswith('OrderedGroup'): + _ro.extend(_parse_group(x) for x in el.iterchildren()) + else: + return el.get('regionRef') + return _ro + + for ro in reading_orders: + is_total = True + self._orders[ro.get('id')] = {'order': _parse_group(ro), + 'is_total': is_total, + 'description': ro.get('caption') if ro.get('caption') else ''} + + + if len(self._tag_set) > 1: + self.has_tags = True + else: + self.has_tags = False + + @property + def regions(self): + return self._regions + + @property + def baselines(self): + return self._baselines + + def get_baselines_by_region(self, region): + pass + + def get_baselines_by_tag(self, key, value): + pass + + def get_baselines_by_split(self, split: Literal['train', 'validation', 'test']): + pass + + @property + def tags(self): + return self._tag_set + + @property + def splits(self): + return self._split_set + + @staticmethod + def _parse_alto_pointstype(coords: str) -> Sequence[Tuple[float, float]]: + """ + ALTO's PointsType is underspecified so a variety of serializations are valid: + + x0, y0 x1, y1 ... + x0 y0 x1 y1 ... + (x0, y0) (x1, y1) ... + (x0 y0) (x1 y1) ... + + Returns: + A list of tuples [(x0, y0), (x1, y1), ...] + """ + float_re = re.compile(r'[-+]?(\d+(\.\d*)?|\.\d+)([eE][-+]?\d+)?') + points = [float(point.group()) for point in float_re.finditer(coords)] + if len(points) % 2: + raise ValueError(f'Odd number of points in points sequence: {points}') + pts = zip(points[::2], points[1::2]) + return [k for k, g in groupby(pts)] + + @staticmethod + def _parse_page_custom(s): + o = {} + s = s.strip() + l_chunks = [l_chunk for l_chunk in s.split('}') if l_chunk.strip()] + if l_chunks: + for chunk in l_chunks: + tag, vals = chunk.split('{') + tag_vals = {} + vals = [val.strip() for val in vals.split(';') if val.strip()] + for val in vals: + key, *val = val.split(':') + tag_vals[key] = ":".join(val) + o[tag.strip()] = tag_vals + return o + + @staticmethod + def _parse_page_coords(coords): + points = [x for x in coords.split(' ')] + points = [int(c) for point in points for c in point.split(',')] + pts = zip(points[::2], points[1::2]) + return [k for k, g in groupby(pts)] + From c36c721e23626c5461b915ba5257f96ffee1bbf6 Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Thu, 26 Jan 2023 15:08:11 +0100 Subject: [PATCH 02/68] skip ROs in ALTO with sub-line elements --- kraken/lib/xml.py | 170 +++++++++++++++++++++++++++------------------- 1 file changed, 101 insertions(+), 69 deletions(-) diff --git a/kraken/lib/xml.py b/kraken/lib/xml.py index 14a4bf450..280b2cf7b 100644 --- a/kraken/lib/xml.py +++ b/kraken/lib/xml.py @@ -22,12 +22,11 @@ from itertools import groupby from lxml import etree from PIL import Image -from typing import Union, Dict, Any, Sequence, Tuple, Literal +from typing import Union, Dict, Any, Sequence, Tuple, Literal, Optional, List from os import PathLike from collections import defaultdict from kraken.lib.segmentation import calculate_polygonal_environment -from kraken.lib.exceptions import KrakenInputException logger = logging.getLogger(__name__) @@ -89,7 +88,7 @@ def preparse_xml_data(filenames: Sequence[Union[str, PathLike]], for fn in filenames: try: data = parse_fn(fn) - except KrakenInputException as e: + except ValueError as e: logger.warning(e) continue try: @@ -482,7 +481,7 @@ class XMLPage(object): type: Literal['baselines', 'bbox'] == 'baselines' base_dir: Optional[Literal['L', 'R']] = None imagename: pathlib.Path = None - _order: Dict[str, Dict[str, Any]] = None + _orders: Dict[str, Dict[str, Any]] = None has_tags: bool = False _tag_set: Optional[Dict] = None has_splits: bool = False @@ -495,6 +494,11 @@ def __init__(self, self.filename = Path(filename) self.filetype = filetype + self._regions = {} + self._baselines = {} + self._orders = {'line_implicit': {'order': [], 'is_total': True, 'description': 'Implicit line order derived from element sequence'}, + 'region_implicit': {'order': [], 'is_total': True, 'description': 'Implicit region order derived from element sequence'}} + if filetype == 'xml': self._parse_xml() elif filetype == 'alto': @@ -502,10 +506,18 @@ def __init__(self, elif filetype == 'page': self._parse_page() - self._regions = {} - self._baselines = {} - self._orders = {'line_implicit': {'order': [], 'is_total': True, 'description': 'Implicit line order derived from element sequence'}, - 'region_implicit': {'order': [], 'is_total': True, 'description': 'Implicit region order derived from element sequence'}} + def _parse_xml(self): + with open(self.filename, 'rb') as fp: + try: + doc = etree.parse(fp) + except etree.XMLSyntaxError as e: + raise ValueError(f'Parsing {self.filename} failed: {e}') + if doc.getroot().tag.endswith('alto'): + return self._parse_alto() + elif doc.getroot().tag.endswith('PcGts'): + return self._parse_page() + else: + raise ValueError(f'Unknown XML format in {self.filename}') def _parse_alto(self): with open(self.filename, 'rb') as fp: @@ -513,10 +525,10 @@ def _parse_alto(self): try: doc = etree.parse(fp) except etree.XMLSyntaxError as e: - raise KrakenInputException('Parsing {} failed: {}'.format(self.filename, e)) + raise ValueError('Parsing {} failed: {}'.format(self.filename, e)) image = doc.find('.//{*}fileName') if image is None or not image.text: - raise KrakenInputException('No valid image filename found in ALTO file {self.filename}') + raise ValueError('No valid image filename found in ALTO file {self.filename}') self.imagename = base_directory.joinpath(image.text) lines = doc.findall('.//{*}TextLine') @@ -550,7 +562,7 @@ def _parse_alto(self): # try to find shape object coords = region.find('./{*}Shape/{*}Polygon') if coords is not None: - boundary = _parse_pointstype(coords.get('POINTS')) + boundary = self._parse_alto_pointstype(coords.get('POINTS')) elif (region.get('HPOS') is not None and region.get('VPOS') is not None and region.get('WIDTH') is not None and region.get('HEIGHT') is not None): # use rectangular definition @@ -580,7 +592,7 @@ def _parse_alto(self): region_data[rtype].append({'id': region.get('ID'), 'boundary': boundary}) # register implicit reading order self._orders['region_implicit']['order'].append(region.get('ID')) - self.regions = region_data + self._regions = region_data self._tag_set = set(('default',)) self.lines = [] @@ -639,30 +651,42 @@ def _parse_alto(self): # parse explicit reading orders if they exist ro_el = doc.find('.//{*}ReadingOrder') if ro_el is not None: - reading_orders = ro_el.getchildren() - # UnorderedGroup at top-level => treated as multiple reading orders - if len(reading_orders) == 1 and reading_orders[0].tag.endswith('UnorderedGroup'): - reading_orders = reading_orders.getchildren() - else: - reading_orders = [reading_orders] - def _parse_group(el): - _ro = [] - if el.tag.endswith('UnorderedGroup'): - _ro.append([_parse_group(x) for x in el.iterchildren()]) - is_total = False - elif el.tag.endswith('OrderedGroup'): - _ro.extend(_parse_group(x) for x in el.iterchildren()) - else: - return el.get('REF') - return _ro - - for ro in reading_orders: - is_total = True - joint_order = _parse_group(ro) - tag = ro.get('TAGREFS') - self._orders[ro.get('ID')] = {'order': joint_order, - 'is_total': is_total, - 'description': cls_map[tag] if tag and tag in cls_map else ''} + reading_orders = ro_el.getchildren() + # UnorderedGroup at top-level => treated as multiple reading orders + if len(reading_orders) == 1 and reading_orders[0].tag.endswith('UnorderedGroup'): + reading_orders = reading_orders[0].getchildren() + else: + reading_orders = [reading_orders] + def _parse_group(el): + _ro = [] + if el.tag.endswith('UnorderedGroup'): + _ro.append([_parse_group(x) for x in el.iterchildren()]) + is_total = False + elif el.tag.endswith('OrderedGroup'): + _ro.extend(_parse_group(x) for x in el.iterchildren()) + else: + ref = el.get('REF') + res = doc.find(f'.//{{*}}*[@ID="{ref}"]') + if res is None: + logger.warning(f'Nonexistant element with ID {ref} in reading order. Skipping RO {ro.get("ID")}.') + is_valid = False + return _ro + tag = res.tag.split('}')[-1] + if tag not in alto_regions.keys() and tag != 'TextLine': + logger.warning(f'Sub-line element with ID {ref} in reading order. Skipping RO {ro.get("ID")}.') + is_valid = False + return _ro + + for ro in reading_orders: + is_total = True + is_valid = True + joint_order = _parse_group(ro) + if is_valid: + tag = ro.get('TAGREFS') + self._orders[ro.get('ID')] = {'order': joint_order, + 'is_total': is_total, + 'description': cls_map[tag] if tag and tag in cls_map else ''} + self.filetype = 'alto' def _parse_page(self): with open(self.filename, 'rb') as fp: @@ -671,10 +695,10 @@ def _parse_page(self): try: doc = etree.parse(fp) except etree.XMLSyntaxError as e: - raise KrakenInputException('Parsing {} failed: {}'.format(self.filename, e)) + raise ValueError(f'Parsing {self.filename} failed: {e}') image = doc.find('.//{*}Page') if image is None or image.get('imageFilename') is None: - raise KrakenInputException('No valid image filename found in PageXML file {}'.format(self.filename)) + raise ValueError(f'No valid image filename found in PageXML file {self.filename}') try: self.base_dir = {'left-to-right': 'L', 'right-to-left': 'R', @@ -684,19 +708,17 @@ def _parse_page(self): except KeyError: logger.warning(f'Invalid value {image.get("readingDirection")} encountered in page-level reading direction.') lines = doc.findall('.//{*}TextLine') - self.imagename = base_dir.joinpath(image.get('imageFilename')) + self.imagename = base_directory.joinpath(image.get('imageFilename')) # find all image regions - regions = [] - for x in page_regions.keys(): - regions.extend(doc.findall('.//{{*}}{}'.format(x))) + regions = [reg for reg in image.iterfind('./{*}*')] # parse region type and coords region_data = defaultdict(list) tr_region_order = [] for region in regions: - coords = region.find('{*}Coords') + coords = region.find('./{*}Coords') if coords is not None and not coords.get('points').isspace() and len(coords.get('points')): try: - coords = _parse_coords(coords.get('points')) + coords = self._parse_page_coords(coords.get('points')) except Exception: logger.warning('Region {} without coordinates'.format(region.get('id'))) continue @@ -706,9 +728,9 @@ def _parse_page(self): rtype = region.get('type') # parse transkribus-style custom field if possible custom_str = region.get('custom') - if not rtype and custom_str: - cs = _parse_page_custom(custom_str) - if 'structure' in cs and 'type' in cs['structure']: + if custom_str: + cs = self._parse_page_custom(custom_str) + if not rtype and 'structure' in cs and 'type' in cs['structure']: rtype = cs['structure']['type'] # transkribus-style reading order if 'readingOrder' in cs and 'index'in cs['readingOrder']: @@ -720,11 +742,11 @@ def _parse_page(self): # register implicit reading order self._orders['region_implicit']['order'].append(region.get('id')) # add transkribus-style region order - self._order['region_transkribus'] = {'order': [x[1] for x in sorted(tr_region_order, key=lambda k: k[0])], - 'is_total': True if len(set(map(lambda x: x[0], tr_region_order))) == len(tr_region_order) else False, - 'description': 'Explicit region order from `custom` attribute'} + self._orders['region_transkribus'] = {'order': [x[0] for x in sorted(tr_region_order, key=lambda k: k[1])], + 'is_total': True if len(set(map(lambda x: x[0], tr_region_order))) == len(tr_region_order) else False, + 'description': 'Explicit region order from `custom` attribute'} - self.regions = region_data + self._regions = region_data # parse line information self._tag_set = set(('default',)) @@ -744,7 +766,7 @@ def _parse_page(self): baseline = None if base is not None and not base.get('points').isspace() and len(base.get('points')): try: - baseline = self._parse_coords(base.get('points')) + baseline = self._parse_page_coords(base.get('points')) except Exception: logger.info('TextLine {} without baseline'.format(line.get('id'))) continue @@ -765,7 +787,7 @@ def _parse_page(self): split_type = None custom_str = line.get('custom') if custom_str: - cs = _parse_page_custom(custom_str) + cs = self._parse_page_custom(custom_str) if 'structure' in cs and 'type' in cs['structure']: tags['type'] = cs['structure']['type'] self._tag_set.add(tags['type']) @@ -776,19 +798,19 @@ def _parse_page(self): self._tag_set.add(split_type) if 'readingOrder' in cs and 'index' in cs['readingOrder']: # look up region index from parent - reg_cus = _parse_page_custom(line.getparent().get('custom')) + reg_cus = self._parse_page_custom(line.getparent().get('custom')) if 'readingOrder' not in reg_cus or 'index' not in reg_cus['readingOrder']: logger.warning('Incomplete `custom` attribute reading order found.') valid_tr_lo = False else: tmp_transkribus_line_order[int(reg_cus['readingOrder']['index'])].append((int(cs['readingOrder']['index']), line.get('id'))) - self.lines.append({'id': line.get('id'), - 'baseline': baseline, - 'boundary': boundary, - 'text': text, - 'split': split_type, - 'tags': tags}) + self._baselines[line.get('id')] = {'baseline': baseline, + 'boundary': boundary, + 'text': text, + 'split': split_type, + 'tags': tags} + # register implicit reading order self._orders['line_implicit']['order'].append(line.get('id')) if tmp_transkribus_line_order: @@ -798,9 +820,9 @@ def _parse_page(self): tr_line_order = [] for _, lines in tmp_reg_order: tr_line_order.extend([x[1] for x in sorted(lines, key=lambda k: k[0])]) - self._order['line_transkribus'] = {'order': tr_line_order, - 'is_total': True, - 'description': 'Explicit line order from `custom` attribute'} + self._orders['line_transkribus'] = {'order': tr_line_order, + 'is_total': True, + 'description': 'Explicit line order from `custom` attribute'} # parse explicit reading orders if they exist ro_el = doc.find('.//{*}ReadingOrder') @@ -828,11 +850,12 @@ def _parse_group(el): 'is_total': is_total, 'description': ro.get('caption') if ro.get('caption') else ''} + if len(self._tag_set) > 1: + self.has_tags = True + else: + self.has_tags = False - if len(self._tag_set) > 1: - self.has_tags = True - else: - self.has_tags = False + self.filetype = 'page' @property def regions(self): @@ -842,14 +865,18 @@ def regions(self): def baselines(self): return self._baselines + @property + def reading_orders(self): + return self._orders + def get_baselines_by_region(self, region): pass def get_baselines_by_tag(self, key, value): - pass + return {k: v for k, v in self._baselines.items() if v['tags'].get(key) == value} def get_baselines_by_split(self, split: Literal['train', 'validation', 'test']): - pass + return {k: v for k, v in self._baselines.items() if v['tags'].get(key) == split} @property def tags(self): @@ -902,3 +929,8 @@ def _parse_page_coords(coords): pts = zip(points[::2], points[1::2]) return [k for k, g in groupby(pts)] + def __str__(self): + return f'XMLPage {self.filename} (format: {self.filetype}, image: {self.imagename})' + + def __repr__(self): + return f'XMLPage(filename={self.filename}, filetype={self.filetype})' From b1977055e69f3a1a6f444839971d89665299828b Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Mon, 30 Jan 2023 11:47:13 +0100 Subject: [PATCH 03/68] wip --- kraken/lib/xml.py | 300 ++++++++++++++++++++++++++-------------------- 1 file changed, 171 insertions(+), 129 deletions(-) diff --git a/kraken/lib/xml.py b/kraken/lib/xml.py index 280b2cf7b..2d6f72e29 100644 --- a/kraken/lib/xml.py +++ b/kraken/lib/xml.py @@ -495,7 +495,7 @@ def __init__(self, self.filetype = filetype self._regions = {} - self._baselines = {} + self._lines = {} self._orders = {'line_implicit': {'order': [], 'is_total': True, 'description': 'Implicit line order derived from element sequence'}, 'region_implicit': {'order': [], 'is_total': True, 'description': 'Implicit region order derived from element sequence'}} @@ -531,7 +531,6 @@ def _parse_alto(self): raise ValueError('No valid image filename found in ALTO file {self.filename}') self.imagename = base_directory.joinpath(image.text) - lines = doc.findall('.//{*}TextLine') # find all image regions in order regions = [] for el in doc.iterfind('./{*}Layout/{*}Page/{*}PrintSpace/{*}*'): @@ -556,6 +555,9 @@ def _parse_alto(self): for x in ['StructureTag', 'LayoutTag', 'OtherTag']: for tag in tags.findall('./{{*}}{}'.format(x)): cls_map[tag.get('ID')] = (x[:-3].lower(), tag.get('LABEL')) + + self._tag_set = set(('default',)) + # parse region type and coords region_data = defaultdict(list) for region in regions: @@ -574,8 +576,6 @@ def _parse_alto(self): (x_min, y_min + height), (x_min + width, y_min + height), (x_min + width, y_min)] - else: - continue rtype = region.get('TYPE') # fall back to default region type if nothing is given tagrefs = region.get('TAGREFS') @@ -586,62 +586,61 @@ def _parse_alto(self): break if rtype is None: rtype = alto_regions[region.tag.split('}')[-1]] - if boundary == page_boundary and rtype == 'text': - logger.info('Skipping TextBlock with same size as page image.') - continue - region_data[rtype].append({'id': region.get('ID'), 'boundary': boundary}) + region_id = region.get('ID') + region_data[rtype].append({'id': region_id, 'boundary': boundary}) # register implicit reading order - self._orders['region_implicit']['order'].append(region.get('ID')) - self._regions = region_data + self._orders['region_implicit']['order'].append(region_id) - self._tag_set = set(('default',)) - self.lines = [] - for line in lines: - if line.get('BASELINE') is None: - logger.info('TextLine {} without baseline'.format(line.get('ID'))) - continue - pol = line.find('./{*}Shape/{*}Polygon') - boundary = None - if pol is not None: + # parse lines in region + for line in region.iterfind('./{*}TextLine'): + if line.get('BASELINE') is None: + logger.info('TextLine {} without baseline'.format(line.get('ID'))) + continue + pol = line.find('./{*}Shape/{*}Polygon') + boundary = None + if pol is not None: + try: + boundary = self._parse_alto_pointstype(pol.get('POINTS')) + except ValueError: + logger.info('TextLine {} without polygon'.format(line.get('ID'))) + else: + logger.info('TextLine {} without polygon'.format(line.get('ID'))) + + baseline = None try: - boundary = self._parse_alto_pointstype(pol.get('POINTS')) + baseline = self._parse_alto_pointstype(line.get('BASELINE')) except ValueError: - logger.info('TextLine {} without polygon'.format(line.get('ID'))) - else: - logger.info('TextLine {} without polygon'.format(line.get('ID'))) + logger.info('TextLine {} without baseline'.format(line.get('ID'))) + + text = '' + for el in line.xpath(".//*[local-name() = 'String'] | .//*[local-name() = 'SP']"): + text += el.get('CONTENT') if el.get('CONTENT') else ' ' + # find line type + tags = {'type': 'default'} + split_type = None + tagrefs = line.get('TAGREFS') + if tagrefs is not None: + for tagref in tagrefs.split(): + ttype, ltype = cls_map.get(tagref, (None, None)) + if ltype is not None: + self._tag_set.add(ltype) + if ttype == 'other': + tags['type'] = ltype + else: + tags[ttype] = ltype + if ltype in ['train', 'validation', 'test']: + split_type = ltype + self._lines[line.get('ID')] = {'baseline': baseline, + 'boundary': boundary, + 'text': text, + 'tags': tags, + 'split': split_type, + 'region': region_id} + # register implicit reading order + self._orders['line_implicit']['order'].append(line.get('ID')) + + self._regions = region_data - baseline = None - try: - baseline = self._parse_alto_pointstype(line.get('BASELINE')) - except ValueError: - logger.info('TextLine {} without baseline'.format(line.get('ID'))) - - text = '' - for el in line.xpath(".//*[local-name() = 'String'] | .//*[local-name() = 'SP']"): - text += el.get('CONTENT') if el.get('CONTENT') else ' ' - # find line type - tags = {'type': 'default'} - split_type = None - tagrefs = line.get('TAGREFS') - if tagrefs is not None: - for tagref in tagrefs.split(): - ttype, ltype = cls_map.get(tagref, (None, None)) - if ltype is not None: - self._tag_set.add(ltype) - if ttype == 'other': - tags['type'] = ltype - else: - tags[ttype] = ltype - if ltype in ['train', 'validation', 'test']: - split_type = ltype - self.lines.append({'id': line.get('ID'), - 'baseline': baseline, - 'boundary': boundary, - 'text': text, - 'tags': tags, - 'split': split_type}) - # register implicit reading order - self._orders['line_implicit']['order'].append(line.get('ID')) if len(self._tag_set) > 1: self.has_tags = True @@ -707,13 +706,17 @@ def _parse_page(self): None: None}[image.get('readingDirection')] except KeyError: logger.warning(f'Invalid value {image.get("readingDirection")} encountered in page-level reading direction.') - lines = doc.findall('.//{*}TextLine') self.imagename = base_directory.joinpath(image.get('imageFilename')) # find all image regions regions = [reg for reg in image.iterfind('./{*}*')] # parse region type and coords region_data = defaultdict(list) tr_region_order = [] + + self._tag_set = set(('default',)) + tmp_transkribus_line_order = defaultdict(list) + valid_tr_lo = True + for region in regions: coords = region.find('./{*}Coords') if coords is not None and not coords.get('points').isspace() and len(coords.get('points')): @@ -721,10 +724,10 @@ def _parse_page(self): coords = self._parse_page_coords(coords.get('points')) except Exception: logger.warning('Region {} without coordinates'.format(region.get('id'))) - continue + coords = None else: logger.warning('Region {} without coordinates'.format(region.get('id'))) - continue + coords = None rtype = region.get('type') # parse transkribus-style custom field if possible custom_str = region.get('custom') @@ -741,78 +744,78 @@ def _parse_page(self): region_data[rtype].append({'id': region.get('id'), 'boundary': coords}) # register implicit reading order self._orders['region_implicit']['order'].append(region.get('id')) - # add transkribus-style region order - self._orders['region_transkribus'] = {'order': [x[0] for x in sorted(tr_region_order, key=lambda k: k[1])], - 'is_total': True if len(set(map(lambda x: x[0], tr_region_order))) == len(tr_region_order) else False, - 'description': 'Explicit region order from `custom` attribute'} - - self._regions = region_data - # parse line information - self._tag_set = set(('default',)) - tmp_transkribus_line_order = defaultdict(list) - valid_tr_lo = True - for line in lines: - pol = line.find('./{*}Coords') - boundary = None - if pol is not None and not pol.get('points').isspace() and len(pol.get('points')): - try: - boundary = self._parse_coords(pol.get('points')) - except Exception: + # parse line information + for line in region.iterfind('./{*}TextLine'): + pol = line.find('./{*}Coords') + boundary = None + if pol is not None and not pol.get('points').isspace() and len(pol.get('points')): + try: + boundary = self._parse_page_coords(pol.get('points')) + except Exception: + logger.info('TextLine {} without polygon'.format(line.get('id'))) + else: logger.info('TextLine {} without polygon'.format(line.get('id'))) - else: - logger.info('TextLine {} without polygon'.format(line.get('id'))) - base = line.find('./{*}Baseline') - baseline = None - if base is not None and not base.get('points').isspace() and len(base.get('points')): - try: - baseline = self._parse_page_coords(base.get('points')) - except Exception: + base = line.find('./{*}Baseline') + baseline = None + if base is not None and not base.get('points').isspace() and len(base.get('points')): + try: + baseline = self._parse_page_coords(base.get('points')) + except Exception: + logger.info('TextLine {} without baseline'.format(line.get('id'))) + continue + else: logger.info('TextLine {} without baseline'.format(line.get('id'))) continue - else: - logger.info('TextLine {} without baseline'.format(line.get('id'))) - continue - text = '' - manual_transcription = line.find('./{*}TextEquiv') - if manual_transcription is not None: - transcription = manual_transcription - else: - transcription = line - for el in transcription.findall('.//{*}Unicode'): - if el.text: - text += el.text - # retrieve line tags if custom string is set and contains - tags = {'type': 'default'} - split_type = None - custom_str = line.get('custom') - if custom_str: - cs = self._parse_page_custom(custom_str) - if 'structure' in cs and 'type' in cs['structure']: - tags['type'] = cs['structure']['type'] - self._tag_set.add(tags['type']) - # retrieve data split if encoded in custom string. - if 'split' in cs and 'type' in cs['split'] and cs['split']['type'] in ['train', 'validation', 'test']: - split_type = cs['split']['type'] - tags['split'] = split_type - self._tag_set.add(split_type) - if 'readingOrder' in cs and 'index' in cs['readingOrder']: - # look up region index from parent - reg_cus = self._parse_page_custom(line.getparent().get('custom')) - if 'readingOrder' not in reg_cus or 'index' not in reg_cus['readingOrder']: - logger.warning('Incomplete `custom` attribute reading order found.') - valid_tr_lo = False - else: - tmp_transkribus_line_order[int(reg_cus['readingOrder']['index'])].append((int(cs['readingOrder']['index']), line.get('id'))) + text = '' + manual_transcription = line.find('./{*}TextEquiv') + if manual_transcription is not None: + transcription = manual_transcription + else: + transcription = line + for el in transcription.findall('.//{*}Unicode'): + if el.text: + text += el.text + # retrieve line tags if custom string is set and contains + tags = {'type': 'default'} + split_type = None + custom_str = line.get('custom') + if custom_str: + cs = self._parse_page_custom(custom_str) + if 'structure' in cs and 'type' in cs['structure']: + tags['type'] = cs['structure']['type'] + self._tag_set.add(tags['type']) + # retrieve data split if encoded in custom string. + if 'split' in cs and 'type' in cs['split'] and cs['split']['type'] in ['train', 'validation', 'test']: + split_type = cs['split']['type'] + tags['split'] = split_type + self._tag_set.add(split_type) + if 'readingOrder' in cs and 'index' in cs['readingOrder']: + # look up region index from parent + reg_cus = self._parse_page_custom(line.getparent().get('custom')) + if 'readingOrder' not in reg_cus or 'index' not in reg_cus['readingOrder']: + logger.warning('Incomplete `custom` attribute reading order found.') + valid_tr_lo = False + else: + tmp_transkribus_line_order[int(reg_cus['readingOrder']['index'])].append((int(cs['readingOrder']['index']), line.get('id'))) - self._baselines[line.get('id')] = {'baseline': baseline, + self._lines[line.get('id')] = {'baseline': baseline, 'boundary': boundary, 'text': text, 'split': split_type, - 'tags': tags} + 'tags': tags, + 'region': region.get('id')} + + # register implicit reading order + self._orders['line_implicit']['order'].append(line.get('id')) + + # add transkribus-style region order + self._orders['region_transkribus'] = {'order': [x[0] for x in sorted(tr_region_order, key=lambda k: k[1])], + 'is_total': True if len(set(map(lambda x: x[0], tr_region_order))) == len(tr_region_order) else False, + 'description': 'Explicit region order from `custom` attribute'} + + self._regions = region_data - # register implicit reading order - self._orders['line_implicit']['order'].append(line.get('id')) if tmp_transkribus_line_order: # sort by regions tmp_reg_order = sorted(((k, v) for k, v in tmp_transkribus_line_order.items()), key=lambda k: k[0]) @@ -862,21 +865,60 @@ def regions(self): return self._regions @property - def baselines(self): - return self._baselines + def lines(self): + return self._lines @property def reading_orders(self): return self._orders - def get_baselines_by_region(self, region): - pass + def get_sorted_lines(self, ro='line_implicit'): + """ + Returns ordered baselines from particular reading order. + """ + if ro not in self.reading_orders: + raise ValueError(f'Unknown reading order {ro}') + def _traverse_ro(el): + _ro = [] + if isinstance(el, list): + _ro.append([_traverse_ro(x) for x in el]) + else: + # if line directly append to ro + if el in self.lines: + return self.lines[el] + # substitute lines if region in RO + elif el in [reg['id'] for regs in doc.regions.values() for reg in regs]: + _ro.extend(self.get_lines_by_region(el)) + else: + raise ValueError(f'Invalid reading order {ro}') + return _ro + + _ro = self.reading_orders[ro] + return _traverse_ro(_ro['order']) + + def get_sorted_regions(self, ro='region_implicit'): + """ + Returns ordered regions from particular reading order. + """ + + + def get_sorted_lines_by_region(self, region, ro='line_implicit'): + """ + Returns ordered lines in region. + """ + if self.reading_orders[ro]['is_total'] is False: + raise ValueError('Fetching lines by region of a non-total order is not supported') + lines = [(id, line) for id, line in self._lines.items() if line['region'] == region] + for line in lines: + if line[0] not in self.reading_orders[ro]['order']: + raise ValueError('Fetching lines by region is only possible for flat orders') + return sorted(lines, key=lambda k: self.reading_orders[ro]['order'].index(k[0])) - def get_baselines_by_tag(self, key, value): - return {k: v for k, v in self._baselines.items() if v['tags'].get(key) == value} + def get_lines_by_tag(self, key, value): + return {k: v for k, v in self._lines.items() if v['tags'].get(key) == value} - def get_baselines_by_split(self, split: Literal['train', 'validation', 'test']): - return {k: v for k, v in self._baselines.items() if v['tags'].get(key) == split} + def get_lines_by_split(self, split: Literal['train', 'validation', 'test']): + return {k: v for k, v in self._lines.items() if v['tags'].get(key) == split} @property def tags(self): From 65ad0ea118eda06f35fb82cd2218781d83284f1a Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Mon, 6 Feb 2023 17:13:44 +0100 Subject: [PATCH 04/68] non-working xml parsing tests --- tests/test_xml.py | 60 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 tests/test_xml.py diff --git a/tests/test_xml.py b/tests/test_xml.py new file mode 100644 index 000000000..cee436ad8 --- /dev/null +++ b/tests/test_xml.py @@ -0,0 +1,60 @@ +# -*- coding: utf-8 -*- +import json +import unittest +import tempfile +import numpy as np + +from pathlib import Path +from pytest import raises + +from kraken.lib import xml + +thisfile = Path(__file__).resolve().parent +resources = thisfile / 'resources' + +class TestXMLParser(unittest.TestCase): + """ + Tests XML (ALTO/PAGE) parsing + """ + def setUp(self): + self.page_doc = resources / 'cPAS-2000.xml' + self.alto_doc = resources / 'bsb00084914_00007.xml' + + def test_page_parsing(self): + """ + Test parsing of PAGE XML files with reading order. + """ + doc = xml.XMLPage(self.page_doc, filetype='page') + self.assertEqual(len(doc.baselines), 97) + self.assertEqual(len([item for x in doc.regions.values() for item in x]), 4) + self.assertEqual( + + def test_alto_parsing(self): + """ + Test parsing of ALTO XML files with reading order. + """ + doc = xml.XMLPage(self.alto_doc, filetype='alto') + + def test_auto_parsing(self): + """ + Test parsing of PAGE and ALTO XML files with auto-format determination. + """ + doc = xml.XMLPage(self.page_doc, filetype='xml') + self.assertEqual(doc.filetype, 'page') + doc = xml.XMLPage(self.alto_doc, filetype='xml') + self.assertEqual(doc.filetype, 'alto') + + def test_failure_page_alto_parsing(self): + """ + Test that parsing ALTO files with PAGE as format fails. + """ + with raises(ValueError): + xml.XMLPage(self.alto_doc, filetype='page') + + def test_failure_alto_page_parsing(self): + """ + Test that parsing PAGE files with ALTO as format fails. + """ + with raises(ValueError): + xml.XMLPage(self.page_doc, filetype='alto') + From 2a1be8be8e115388351d5459f539a1d40e2e548f Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Wed, 22 Feb 2023 15:47:01 +0100 Subject: [PATCH 05/68] Fix ALTO region order parsing --- kraken/lib/xml.py | 43 +++++++++++++++++++++++++++++++++++-------- 1 file changed, 35 insertions(+), 8 deletions(-) diff --git a/kraken/lib/xml.py b/kraken/lib/xml.py index 2d6f72e29..2f478cc54 100644 --- a/kraken/lib/xml.py +++ b/kraken/lib/xml.py @@ -17,6 +17,8 @@ """ import re import logging + +from os import PathLike from pathlib import Path from itertools import groupby @@ -480,7 +482,7 @@ class XMLPage(object): type: Literal['baselines', 'bbox'] == 'baselines' base_dir: Optional[Literal['L', 'R']] = None - imagename: pathlib.Path = None + imagename: PathLike = None _orders: Dict[str, Dict[str, Any]] = None has_tags: bool = False _tag_set: Optional[Dict] = None @@ -488,7 +490,7 @@ class XMLPage(object): _split_set: Optional[List] = None def __init__(self, - filename: Union[str, pathlib.Path], + filename: Union[str, PathLike], filetype: Literal['xml', 'alto', 'page'] = 'xml'): super().__init__() self.filename = Path(filename) @@ -653,13 +655,16 @@ def _parse_alto(self): reading_orders = ro_el.getchildren() # UnorderedGroup at top-level => treated as multiple reading orders if len(reading_orders) == 1 and reading_orders[0].tag.endswith('UnorderedGroup'): - reading_orders = reading_orders[0].getchildren() + reading_orders = reading_orders[0].getchildren() else: reading_orders = [reading_orders] + def _parse_group(el): + nonlocal is_valid + _ro = [] if el.tag.endswith('UnorderedGroup'): - _ro.append([_parse_group(x) for x in el.iterchildren()]) + _ro = [_parse_group(x) for x in el.iterchildren()] is_total = False elif el.tag.endswith('OrderedGroup'): _ro.extend(_parse_group(x) for x in el.iterchildren()) @@ -667,13 +672,15 @@ def _parse_group(el): ref = el.get('REF') res = doc.find(f'.//{{*}}*[@ID="{ref}"]') if res is None: - logger.warning(f'Nonexistant element with ID {ref} in reading order. Skipping RO {ro.get("ID")}.') + logger.warning(f'Nonexistent element with ID {ref} in reading order. Skipping RO {ro.get("ID")}.') is_valid = False return _ro tag = res.tag.split('}')[-1] if tag not in alto_regions.keys() and tag != 'TextLine': logger.warning(f'Sub-line element with ID {ref} in reading order. Skipping RO {ro.get("ID")}.') is_valid = False + return _ro + return ref return _ro for ro in reading_orders: @@ -839,7 +846,7 @@ def _parse_page(self): def _parse_group(el): _ro = [] if el.tag.endswith('UnorderedGroup'): - _ro.append([_parse_group(x) for x in el.iterchildren()]) + _ro = [_parse_group(x) for x in el.iterchildren()] is_total = False elif el.tag.endswith('OrderedGroup'): _ro.extend(_parse_group(x) for x in el.iterchildren()) @@ -887,8 +894,8 @@ def _traverse_ro(el): if el in self.lines: return self.lines[el] # substitute lines if region in RO - elif el in [reg['id'] for regs in doc.regions.values() for reg in regs]: - _ro.extend(self.get_lines_by_region(el)) + elif el in [reg['id'] for regs in self.regions.values() for reg in regs]: + _ro.extend(self.get_sorted_lines_by_region(el)) else: raise ValueError(f'Invalid reading order {ro}') return _ro @@ -900,12 +907,32 @@ def get_sorted_regions(self, ro='region_implicit'): """ Returns ordered regions from particular reading order. """ + if ro not in self.reading_orders: + raise ValueError(f'Unknown reading order {ro}') + regions = {reg['id']: key for key, regs in self.regions.items() for reg in regs} + + def _traverse_ro(el): + _ro = [] + if isinstance(el, list): + _ro.append([_traverse_ro(x) for x in el]) + else: + # if region directly append to ro + if el in regions.keys(): + return [reg for reg in self.regions[regions[el]] if reg['id'] == el][0] + else: + raise ValueError(f'Invalid reading order {ro}') + return _ro + + _ro = self.reading_orders[ro] + return _traverse_ro(_ro['order']) def get_sorted_lines_by_region(self, region, ro='line_implicit'): """ Returns ordered lines in region. """ + if ro not in self.reading_orders: + raise ValueError(f'Unknown reading order {ro}') if self.reading_orders[ro]['is_total'] is False: raise ValueError('Fetching lines by region of a non-total order is not supported') lines = [(id, line) for id, line in self._lines.items() if line['region'] == region] From bb4999ee7d7241d3f8daa86a6531ce8c47940a60 Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Fri, 3 Mar 2023 12:45:38 +0100 Subject: [PATCH 06/68] working training code for RO --- kraken/lib/dataset/__init__.py | 1 + kraken/lib/dataset/ro.py | 140 ++++++++++++++++++++++++++++++ kraken/lib/default_specs.py | 23 +++++ kraken/lib/ro/__init__.py | 19 ++++ kraken/lib/ro/layers.py | 27 ++++++ kraken/lib/ro/model.py | 154 +++++++++++++++++++++++++++++++++ kraken/lib/ro/util.py | 66 ++++++++++++++ kraken/lib/xml.py | 8 +- 8 files changed, 434 insertions(+), 4 deletions(-) create mode 100644 kraken/lib/dataset/ro.py create mode 100644 kraken/lib/ro/__init__.py create mode 100644 kraken/lib/ro/layers.py create mode 100644 kraken/lib/ro/model.py create mode 100644 kraken/lib/ro/util.py diff --git a/kraken/lib/dataset/__init__.py b/kraken/lib/dataset/__init__.py index 960ef8499..5388cfe6b 100644 --- a/kraken/lib/dataset/__init__.py +++ b/kraken/lib/dataset/__init__.py @@ -17,4 +17,5 @@ """ from .recognition import ArrowIPCRecognitionDataset, PolygonGTDataset, GroundTruthDataset # NOQA from .segmentation import BaselineSet # NOQA +from .ro import ROSet #NOQA from .utils import ImageInputTransforms, collate_sequences, global_align, compute_confusions # NOQA diff --git a/kraken/lib/dataset/ro.py b/kraken/lib/dataset/ro.py new file mode 100644 index 000000000..0b81912d1 --- /dev/null +++ b/kraken/lib/dataset/ro.py @@ -0,0 +1,140 @@ +# +# Copyright 2015 Benjamin Kiessling +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +""" +Utility functions for data loading and training of VGSL networks. +""" +import json +import torch +import traceback +import numpy as np +import torch.nn.functional as F +import shapely.geometry as geom + +from math import factorial +from os import path, PathLike +from PIL import Image +from shapely.ops import split +from itertools import groupby +from torchvision import transforms +from collections import defaultdict +from torch.utils.data import Dataset +from typing import Dict, List, Tuple, Sequence, Callable, Any, Union, Literal, Optional + +from kraken.lib.xml import parse_alto, parse_page, parse_xml, XMLPage + +from kraken.lib.exceptions import KrakenInputException + +__all__ = ['BaselineSet'] + +import logging + +logger = logging.getLogger(__name__) + + +class ROSet(Dataset): + """ + Dataset for training a reading order determination model. + """ + def __init__(self, files: Sequence[Union[PathLike, str]] = None, + mode: Optional[Literal['alto', 'page', 'xml']] = 'path', + ro_type: Literal['region', 'line'] = 'line', + ro_id: str = 'line_implicit', + class_mapping: Optional[Dict[str, int]] = None): + """ + Samples pairs lines/regions from XML files for training a reading order + model . + + Args: + mode: Either alto, page, xml, None. In alto, page, and xml + mode the baseline paths and image data is retrieved from an + ALTO/PageXML file. In `None` mode data is iteratively added + through the `add` method. + ro_id: ID of the reading order to sample from. + """ + super().__init__() + + self._num_pairs = 0 + if class_mapping: + self.class_mapping = class_mapping + self.num_classes = len(class_mapping) + 1 + else: + self.num_classes = 1 + self.class_mapping = {} + + self.data = [] + + if mode in ['alto', 'page', 'xml']: + for file in files: + try: + doc = XMLPage(file, filetype=mode) + for tag in doc.tags: + if tag not in self.class_mapping: + self.class_mapping[tag] = self.num_classes + self.num_classes += 1 + except KrakenInputException as e: + files.pop(file) + logger.warning(e) + continue + for file in files: + try: + doc = XMLPage(file, filetype=mode) + if ro_type == 'line': + order = doc.get_sorted_lines(ro_id) + elif ro_type == 'region': + order = doc.get_sorted_regions(ro_id) + else: + raise ValueError(f'Invalid RO type {ro_type}') + # traverse RO and substitute features. + h,w = Image.open(doc.imagename).size + sorted_lines = [] + for line in order: + line_coords = np.array(line['baseline']) / (w, h) + line_center = np.mean(line_coords, axis=0) + cl = torch.zeros(self.num_classes, dtype=torch.float) + # if class is not in class mapping default to None class (idx 0) + cl[self.class_mapping.get(line['tags']['type'], 0)] = 1 + line_data = {'type': line['tags']['type'], + 'features': torch.cat((cl, # one hot encoded line type + torch.tensor(line_center, dtype=torch.float), # line center + torch.tensor(line_coords[0, :], dtype=torch.float), # start_point coord + torch.tensor(line_coords[-1, :], dtype=torch.float), # end point coord) + )) + } + sorted_lines.append(line_data) + self.data.append(sorted_lines) + self._num_pairs += int(factorial(len(sorted_lines))/factorial(len(sorted_lines)-2)) + + except KrakenInputException as e: + logger.warning(e) + continue + else: + raise Exception('invalid dataset mode') + + def __getitem__(self, idx): + lines = [] + while len(lines) < 2: + lines = self.data[torch.randint(len(self.data), (1,))[0]] + idx0, idx1 = 0, 0 + while idx0 == idx1: + idx0, idx1 = torch.randint(len(lines), (2,)) + x = torch.cat((lines[idx0]['features'], lines[idx1]['features'])) + y = torch.tensor(0 if idx0 >= idx1 else 1, dtype=torch.float) + return {'sample': x, 'target': y} + + def get_feature_dim(self): + return 2 * self.num_classes + 12 + + def __len__(self): + return self._num_pairs diff --git a/kraken/lib/default_specs.py b/kraken/lib/default_specs.py index 4830ee1fb..bc92f6979 100644 --- a/kraken/lib/default_specs.py +++ b/kraken/lib/default_specs.py @@ -19,6 +19,29 @@ SEGMENTATION_SPEC = '[1,1800,0,3 Cr7,7,64,2,2 Gn32 Cr3,3,128,2,2 Gn32 Cr3,3,128 Gn32 Cr3,3,256 Gn32 Cr3,3,256 Gn32 Lbx32 Lby32 Cr1,1,32 Gn32 Lby32 Lbx32]' # NOQA RECOGNITION_SPEC = '[1,120,0,1 Cr3,13,32 Do0.1,2 Mp2,2 Cr3,13,32 Do0.1,2 Mp2,2 Cr3,9,64 Do0.1,2 Mp2,2 Cr3,9,64 Do0.1,2 S1(1x0)1,3 Lbx200 Do0.1,2 Lbx200 Do0.1,2 Lbx200 Do]' # NOQA +READING_ORDER_HYPER_PARAMS = {'lrate': 0.001, + 'freq': 1.0, + 'batch_size': 15000, + 'epochs': 3000, + 'lag': 300, + 'quit': 'early', + 'optimizer': 'Adam', + 'momentum': 0.9, + 'weight_decay': 0.01, + 'schedule': 'cosine', + 'completed_epochs': 0, + # lr scheduler params + # step/exp decay + 'step_size': 10, + 'gamma': 0.1, + # reduce on plateau + 'rop_factor': 0.1, + 'rop_patience': 5, + # cosine + 'cos_t_max': 100, + 'warmup': 0, + } + RECOGNITION_PRETRAIN_HYPER_PARAMS = {'pad': 16, 'freq': 1.0, 'batch_size': 64, diff --git a/kraken/lib/ro/__init__.py b/kraken/lib/ro/__init__.py new file mode 100644 index 000000000..4e370b855 --- /dev/null +++ b/kraken/lib/ro/__init__.py @@ -0,0 +1,19 @@ +# +# Copyright 2023 Benjamin Kiessling +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +""" +Tools for trainable reading order. +""" + +from .model import ROModel # NOQA diff --git a/kraken/lib/ro/layers.py b/kraken/lib/ro/layers.py new file mode 100644 index 000000000..c86cb0711 --- /dev/null +++ b/kraken/lib/ro/layers.py @@ -0,0 +1,27 @@ +""" +Layers for VGSL models +""" +from torch import nn + +# all tensors are ordered NCHW, the "feature" dimension is C, so the output of +# an LSTM will be put into C same as the filters of a CNN. + +__all__ = ['MLP'] + + +class MLP(nn.Module): + """ + A simple 2 layer MLP for reading order determination. + """ + def __init__(self, feature_size: int, hidden_size: int): + super(MLP, self).__init__() + self.fc1 = nn.Linear(feature_size, hidden_size) + self.relu = nn.ReLU() + self.fc2 = nn.Linear(hidden_size, 1) + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + return self.sigmoid(x) diff --git a/kraken/lib/ro/model.py b/kraken/lib/ro/model.py new file mode 100644 index 000000000..63b0bf349 --- /dev/null +++ b/kraken/lib/ro/model.py @@ -0,0 +1,154 @@ +# +# Copyright 2023 Benjamin Kiessling +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +""" +Pytorch-lightning modules for reading order training. + +Adapted from: +""" +import re +import math +import torch +import logging +import numpy as np +import torch.nn.functional as F +import pytorch_lightning as pl + +from os import PathLike +from typing import Dict, Optional, Sequence, Union, Any, Literal + +from kraken.lib import vgsl, default_specs, layers +from kraken.lib.dataset import ROSet +from kraken.lib.train import _configure_optimizer_and_lr_scheduler +from kraken.lib.ro.layers import MLP + +from torch.utils.data import DataLoader, random_split, Subset + + +logger = logging.getLogger(__name__) + + +class ROModel(pl.LightningModule): + def __init__(self, + hyper_params: Dict[str, Any] = None, + output: str = 'model', + model: Optional[Union[PathLike, str]] = None, + training_data: Union[Sequence[Union[PathLike, str]], Sequence[Dict[str, Any]]] = None, + evaluation_data: Optional[Union[Sequence[Union[PathLike, str]], Sequence[Dict[str, Any]]]] = None, + partition: Optional[float] = 0.9, + num_workers: int = 1, + format_type: Literal['alto', 'page', 'xml'] = 'xml', + load_hyper_parameters: bool = False): + """ + A LightningModule encapsulating the unsupervised pretraining setup for + a text recognition model. + + Setup parameters (load, training_data, evaluation_data, ....) are + named, model hyperparameters (everything in + `kraken.lib.default_specs.RECOGNITION_HYPER_PARAMS`) are in in the + `hyper_params` argument. + + Args: + hyper_params (dict): Hyperparameter dictionary containing all fields + from + kraken.lib.default_specs.RECOGNITION_PRETRAIN_HYPER_PARAMS + **kwargs: Setup parameters, i.e. CLI parameters of the train() command. + """ + super().__init__() + hyper_params_ = default_specs.READING_ORDER_HYPER_PARAMS + if model: + logger.info(f'Loading existing model from {model} ') + self.nn = vgsl.TorchVGSLModel.load_model(model) + + if self.nn.model_type not in [None, 'segmentation']: + raise ValueError(f'Model {model} is of type {self.nn.model_type} while `segmentation` is expected.') + + if load_hyper_parameters: + hp = self.nn.hyper_params + else: + hp = {} + hyper_params_.update(hp) + else: + self.ro_net = None + + if hyper_params: + hyper_params_.update(hyper_params) + self.save_hyperparameters(hyper_params_) + + if not evaluation_data: + np.random.shuffle(training_data) + training_data = training_data[:int(partition*len(training_data))] + evaluation_data = training_data[int(partition*len(training_data)):] + self.train_set = ROSet(training_data, mode=format_type) + self.val_set = ROSet(evaluation_data, mode=format_type, class_mapping=self.train_set.class_mapping) + + if len(self.train_set) == 0 or len(self.val_set) == 0: + raise ValueError('No valid training data was provided to the train ' + 'command. Please add valid XML, line, or binary data.') + + logger.info(f'Training set {len(self.train_set)} lines, validation set ' + f'{len(self.val_set)} lines') + + self.model = model + self.output = output + self.criterion = torch.nn.BCELoss() + + self.num_workers = num_workers + + self.best_epoch = 0 + self.best_metric = math.inf + + logger.info(f'Creating new RO model') + self.ro_net = torch.jit.script(MLP(self.train_set.get_feature_dim(), 128)) + + if 'file_system' in torch.multiprocessing.get_all_sharing_strategies(): + logger.debug('Setting multiprocessing tensor sharing strategy to file_system') + torch.multiprocessing.set_sharing_strategy('file_system') + + logger.info('Encoding training set') + + def forward(self, x): + return self.ro_net(x) + + def validation_step(self, batch, batch_idx): + x, y = batch['sample'], batch['target'] + yhat = self.ro_net(x) + loss = self.criterion(yhat.squeeze(), y) + self.log('loss', loss) + return loss + + def training_step(self, batch, batch_idx): + x, y = batch['sample'], batch['target'] + yhat = self.ro_net(x) + loss = self.criterion(yhat.squeeze(), y) + self.log('loss', loss) + return loss + + def configure_optimizers(self): + return _configure_optimizer_and_lr_scheduler(self.hparams, + self.ro_net.parameters(), + len_train_set=len(self.train_set), + loss_tracking_mode='min') + + def train_dataloader(self): + return DataLoader(self.train_set, + batch_size=self.hparams.batch_size, + num_workers=self.num_workers, + pin_memory=True) + + def val_dataloader(self): + return DataLoader(self.val_set, + batch_size=self.hparams.batch_size, + num_workers=self.num_workers, + pin_memory=True) diff --git a/kraken/lib/ro/util.py b/kraken/lib/ro/util.py new file mode 100644 index 000000000..57fea354b --- /dev/null +++ b/kraken/lib/ro/util.py @@ -0,0 +1,66 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Sequence, Union + +import torch +import random +import numpy as np + + +def positive_integers_with_sum(n, total): + ls = [0] + rv = [] + while len(ls) < n: + c = random.randint(0, total) + ls.append(c) + ls = sorted(ls) + ls.append(total) + for i in range(1, len(ls)): + rv.append(ls[i] - ls[i-1]) + return rv + + +def compute_masks(mask_prob: int, + mask_width: int, + num_neg_samples: int, + seq_lens: Union[torch.Tensor, Sequence[int]]): + """ + Samples num_mask non-overlapping random masks of length mask_width in + sequence of length seq_len. + + Args: + mask_prob: Probability of each individual token being chosen as start + of a masked sequence. Overall number of masks num_masks is + mask_prob * sum(seq_lens) / mask_width. + mask_width: width of each mask + num_neg_samples: Number of samples from unmasked sequence parts (gets + multiplied by num_mask) + seq_lens: sequence lengths + + Returns: + An index array containing 1 for masked bits, 2 for negative samples, + the number of masks, and the actual number of negative samples. + """ + mask_samples = np.zeros(sum(seq_lens)) + num_masks = int(mask_prob * sum(seq_lens.numpy()) // mask_width) + num_neg_samples = num_masks * num_neg_samples + num_masks += num_neg_samples + + indices = [x+mask_width for x in positive_integers_with_sum(num_masks, sum(seq_lens)-num_masks*mask_width)] + start = 0 + mask_slices = [] + for i in indices: + i_start = random.randint(start, i+start-mask_width) + mask_slices.append(slice(i_start, i_start+mask_width)) + start += i + + neg_idx = random.sample(range(len(mask_slices)), num_neg_samples) + neg_slices = [mask_slices.pop(idx) for idx in sorted(neg_idx, reverse=True)] + + mask_samples[np.r_[tuple(mask_slices)]] = 1 + mask_samples[np.r_[tuple(neg_slices)]] = 2 + + return mask_samples, num_masks - num_neg_samples, num_neg_samples diff --git a/kraken/lib/xml.py b/kraken/lib/xml.py index 2f478cc54..ebfd5c877 100644 --- a/kraken/lib/xml.py +++ b/kraken/lib/xml.py @@ -725,6 +725,8 @@ def _parse_page(self): valid_tr_lo = True for region in regions: + if not any([True if region.tag.endswith(k) else False for k in page_regions.keys()]): + continue coords = region.find('./{*}Coords') if coords is not None and not coords.get('points').isspace() and len(coords.get('points')): try: @@ -841,8 +843,6 @@ def _parse_page(self): # UnorderedGroup at top-level => treated as multiple reading orders if len(reading_orders) == 1 and reading_orders[0].tag.endswith('UnorderedGroup'): reading_orders = reading_orders.getchildren() - else: - reading_orders = [reading_orders] def _parse_group(el): _ro = [] if el.tag.endswith('UnorderedGroup'): @@ -888,7 +888,7 @@ def get_sorted_lines(self, ro='line_implicit'): def _traverse_ro(el): _ro = [] if isinstance(el, list): - _ro.append([_traverse_ro(x) for x in el]) + _ro = [_traverse_ro(x) for x in el] else: # if line directly append to ro if el in self.lines: @@ -915,7 +915,7 @@ def get_sorted_regions(self, ro='region_implicit'): def _traverse_ro(el): _ro = [] if isinstance(el, list): - _ro.append([_traverse_ro(x) for x in el]) + _ro = [_traverse_ro(x) for x in el] else: # if region directly append to ro if el in regions.keys(): From b06ebe7086a40227bd48880d8107d86bb2ed2045 Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Fri, 3 Mar 2023 13:32:00 +0100 Subject: [PATCH 07/68] more training code --- docs/ketos.rst | 2 +- kraken/ketos/__init__.py | 2 + kraken/ketos/pretrain.py | 6 +- kraken/ketos/recognition.py | 6 +- kraken/ketos/ro.py | 249 +++++++++++++++++++++++++++++++++++ kraken/ketos/segmentation.py | 6 +- kraken/lib/dataset/ro.py | 18 ++- kraken/lib/default_specs.py | 3 +- kraken/lib/models.py | 2 +- kraken/lib/progress.py | 1 - kraken/lib/ro/model.py | 22 +++- 11 files changed, 291 insertions(+), 26 deletions(-) create mode 100644 kraken/ketos/ro.py diff --git a/docs/ketos.rst b/docs/ketos.rst index c3bd2926a..c96481390 100644 --- a/docs/ketos.rst +++ b/docs/ketos.rst @@ -142,7 +142,7 @@ option action -F, \--savefreq Model save frequency in epochs during training -q, \--quit Stop condition for training. Set to `early` - for early stopping (default) or `dumb` for fixed + for early stopping (default) or `fixed` for fixed number of epochs. -N, \--epochs Number of epochs to train for. \--min-epochs Minimum number of epochs to train for when using early stopping. diff --git a/kraken/ketos/__init__.py b/kraken/ketos/__init__.py index 83e56e82c..4b7087dc4 100644 --- a/kraken/ketos/__init__.py +++ b/kraken/ketos/__init__.py @@ -34,6 +34,7 @@ from .repo import publish from .segmentation import segtrain, segtest from .transcription import extract, transcription +from .ro import rotrain APP_NAME = 'kraken' @@ -76,6 +77,7 @@ def cli(ctx, verbose, seed, deterministic): cli.add_command(segtrain) cli.add_command(segtest) cli.add_command(publish) +cli.add_command(rotrain) # deprecated commands cli.add_command(line_generator) diff --git a/kraken/ketos/pretrain.py b/kraken/ketos/pretrain.py index 7be3cc2f0..aeebecca0 100644 --- a/kraken/ketos/pretrain.py +++ b/kraken/ketos/pretrain.py @@ -56,8 +56,8 @@ show_default=True, default=RECOGNITION_PRETRAIN_HYPER_PARAMS['quit'], type=click.Choice(['early', - 'dumb']), - help='Stop condition for training. Set to `early` for early stooping or `dumb` for fixed number of epochs') + 'fixed']), + help='Stop condition for training. Set to `early` for early stooping or `fixed` for fixed number of epochs') @click.option('-N', '--epochs', show_default=True, @@ -275,7 +275,7 @@ def pretrain(ctx, batch_size, pad, output, spec, load, freq, quit, epochs, trainer = KrakenTrainer(accelerator=accelerator, devices=device, precision=precision, - max_epochs=hyper_params['epochs'] if hyper_params['quit'] == 'dumb' else -1, + max_epochs=hyper_params['epochs'] if hyper_params['quit'] == 'fixed' else -1, min_epochs=hyper_params['min_epochs'], enable_progress_bar=True if not ctx.meta['verbose'] else False, deterministic=ctx.meta['deterministic'], diff --git a/kraken/ketos/recognition.py b/kraken/ketos/recognition.py index 781fe9f47..2d8eaf86b 100644 --- a/kraken/ketos/recognition.py +++ b/kraken/ketos/recognition.py @@ -55,8 +55,8 @@ show_default=True, default=RECOGNITION_HYPER_PARAMS['quit'], type=click.Choice(['early', - 'dumb']), - help='Stop condition for training. Set to `early` for early stooping or `dumb` for fixed number of epochs') + 'fixed']), + help='Stop condition for training. Set to `early` for early stooping or `fixed` for fixed number of epochs') @click.option('-N', '--epochs', show_default=True, @@ -302,7 +302,7 @@ def train(ctx, batch_size, pad, output, spec, append, load, freq, quit, epochs, trainer = KrakenTrainer(accelerator=accelerator, devices=device, precision=precision, - max_epochs=hyper_params['epochs'] if hyper_params['quit'] == 'dumb' else -1, + max_epochs=hyper_params['epochs'] if hyper_params['quit'] == 'fixed' else -1, min_epochs=hyper_params['min_epochs'], freeze_backbone=hyper_params['freeze_backbone'], enable_progress_bar=True if not ctx.meta['verbose'] else False, diff --git a/kraken/ketos/ro.py b/kraken/ketos/ro.py new file mode 100644 index 000000000..9d1222077 --- /dev/null +++ b/kraken/ketos/ro.py @@ -0,0 +1,249 @@ +# +# Copyright 2022 Benjamin Kiessling +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +""" +kraken.ketos.ro +~~~~~~~~~~~~~~~ + +Command line driver for reading order training, evaluation, and handling. +""" +import click +import pathlib +import logging + +from PIL import Image +from typing import Dict + +from kraken.lib.progress import KrakenProgressBar +from kraken.lib.exceptions import KrakenInputException +from kraken.lib.default_specs import READING_ORDER_HYPER_PARAMS + +from kraken.ketos.util import _validate_manifests, _expand_gt, message, to_ptl_device + +logging.captureWarnings(True) +logger = logging.getLogger('kraken') + +# raise default max image size to 20k * 20k pixels +Image.MAX_IMAGE_PIXELS = 20000 ** 2 + +@click.command('rotrain') +@click.pass_context +@click.option('-o', '--output', show_default=True, type=click.Path(), default='model', help='Output model file') +@click.option('-i', '--load', show_default=True, type=click.Path(exists=True, + readable=True), help='Load existing file to continue training') +@click.option('-F', '--freq', show_default=True, default=READING_ORDER_HYPER_PARAMS['freq'], type=click.FLOAT, + help='Model saving and report generation frequency in epochs ' + 'during training. If frequency is >1 it must be an integer, ' + 'i.e. running validation every n-th epoch.') +@click.option('-q', + '--quit', + show_default=True, + default=READING_ORDER_HYPER_PARAMS['quit'], + type=click.Choice(['early', + 'fixed']), + help='Stop condition for training. Set to `early` for early stopping or `fixed` for fixed number of epochs') +@click.option('-N', + '--epochs', + show_default=True, + default=READING_ORDER_HYPER_PARAMS['epochs'], + help='Number of epochs to train for') +@click.option('--min-epochs', + show_default=True, + default=['min_epochs'], + help='Minimal number of epochs to train for when using early stopping.') +@click.option('--lag', + show_default=True, + default=READING_ORDER_HYPER_PARAMS['lag'], + help='Number of evaluations (--report frequence) to wait before stopping training without improvement') +@click.option('--min-delta', + show_default=True, + default=READING_ORDER_HYPER_PARAMS['min_delta'], + type=click.FLOAT, + help='Minimum improvement between epochs to reset early stopping. By default it scales the delta by the best loss') +@click.option('-d', '--device', show_default=True, default='cpu', help='Select device to use (cpu, cuda:0, cuda:1, ...)') +@click.option('--precision', default='32', type=click.Choice(['32', '16']), help='set tensor precision') +@click.option('--optimizer', + show_default=True, + default=READING_ORDER_HYPER_PARAMS['optimizer'], + type=click.Choice(['Adam', + 'SGD', + 'RMSprop', + 'Lamb']), + help='Select optimizer') +@click.option('-r', '--lrate', show_default=True, default=READING_ORDER_HYPER_PARAMS['lrate'], help='Learning rate') +@click.option('-m', '--momentum', show_default=True, default=READING_ORDER_HYPER_PARAMS['momentum'], help='Momentum') +@click.option('-w', '--weight-decay', show_default=True, + default=READING_ORDER_HYPER_PARAMS['weight_decay'], help='Weight decay') +@click.option('--warmup', show_default=True, type=float, + default=READING_ORDER_HYPER_PARAMS['warmup'], help='Number of samples to ramp up to `lrate` initial learning rate.') +@click.option('--schedule', + show_default=True, + type=click.Choice(['constant', + '1cycle', + 'exponential', + 'cosine', + 'step', + 'reduceonplateau']), + default=READING_ORDER_HYPER_PARAMS['schedule'], + help='Set learning rate scheduler. For 1cycle, cycle length is determined by the `--step-size` option.') +@click.option('-g', + '--gamma', + show_default=True, + default=READING_ORDER_HYPER_PARAMS['gamma'], + help='Decay factor for exponential, step, and reduceonplateau learning rate schedules') +@click.option('-ss', + '--step-size', + show_default=True, + default=READING_ORDER_HYPER_PARAMS['step_size'], + help='Number of validation runs between learning rate decay for exponential and step LR schedules') +@click.option('--sched-patience', + show_default=True, + default=READING_ORDER_HYPER_PARAMS['rop_patience'], + help='Minimal number of validation runs between LR reduction for reduceonplateau LR schedule.') +@click.option('--cos-max', + show_default=True, + default=READING_ORDER_HYPER_PARAMS['cos_t_max'], + help='Epoch of minimal learning rate for cosine LR scheduler.') +@click.option('-p', '--partition', show_default=True, default=0.9, + help='Ground truth data partition ratio between train/validation set') +@click.option('-t', '--training-files', show_default=True, default=None, multiple=True, + callback=_validate_manifests, type=click.File(mode='r', lazy=True), + help='File(s) with additional paths to training data') +@click.option('-e', '--evaluation-files', show_default=True, default=None, multiple=True, + callback=_validate_manifests, type=click.File(mode='r', lazy=True), + help='File(s) with paths to evaluation data. Overrides the `-p` parameter') +@click.option('--workers', show_default=True, default=1, help='Number of OpenMP threads and workers when running on CPU.') +@click.option('--load-hyper-parameters/--no-load-hyper-parameters', show_default=True, default=False, + help='When loading an existing model, retrieve hyper-parameters from the model') +@click.option('-f', '--format-type', type=click.Choice(['xml', 'alto', 'page']), default='xml', + help='Sets the training data format. In ALTO and PageXML mode all ' + 'data is extracted from xml files containing both baselines and a ' + 'link to source images.') +@click.option('--logger', 'pl_logger', show_default=True, type=click.Choice(['tensorboard']), default=None, + help='Logger used by PyTorch Lightning to track metrics such as loss and accuracy.') +@click.option('--log-dir', show_default=True, type=click.Path(exists=True, dir_okay=True, writable=True), + help='Path to directory where the logger will store the logs. If not set, a directory will be created in the current working directory.') +@click.option('--level', show_default=True, type=click.Choice(['baselines', 'regions']), default='baselines', + help='Selects level to train reading order model on.') +@click.option('--reading-order', show_default=True, default=None, + help='Select reading order to train. Defaults to `line_implicit`/`region_implicit`') +@click.argument('ground_truth', nargs=-1, callback=_expand_gt, type=click.Path(exists=False, dir_okay=False)) +def rotrain(ctx, output, load, freq, quit, epochs, min_epochs, lag, + min_delta, device, precision, optimizer, lrate, momentum, + weight_decay, warmup, schedule, gamma, step_size, sched_patience, + cos_max, partition, training_files, evaluation_files, workers, + load_hyper_parameters, format_type, pl_logger, log_dir, level, + reading_order, ground_truth): + """ + Trains a baseline labeling model for layout analysis + """ + import shutil + + from kraken.lib.train import KrakenTrainer + from kraken.lib.ro import ROModel + + if not (0 <= freq <= 1) and freq % 1.0 != 0: + raise click.BadOptionUsage('freq', 'freq needs to be either in the interval [0,1.0] or a positive integer.') + + if pl_logger == 'tensorboard': + try: + import tensorboard + except ImportError: + raise click.BadOptionUsage('logger', 'tensorboard logger needs the `tensorboard` package installed.') + + if log_dir is None: + log_dir = pathlib.Path.cwd() + + logger.info('Building ground truth set from {} document images'.format(len(ground_truth) + len(training_files))) + + # populate hyperparameters from command line args + hyper_params = READING_ORDER_HYPER_PARAMS.copy() + hyper_params.update({'freq': freq, + 'quit': quit, + 'epochs': epochs, + 'min_epochs': min_epochs, + 'lag': lag, + 'min_delta': min_delta, + 'optimizer': optimizer, + 'lrate': lrate, + 'momentum': momentum, + 'weight_decay': weight_decay, + 'warmup': warmup, + 'schedule': schedule, + 'gamma': gamma, + 'step_size': step_size, + 'rop_patience': sched_patience, + 'cos_t_max': cos_max, + 'pl_logger': pl_logger,}) + + # disable automatic partition when given evaluation set explicitly + if evaluation_files: + partition = 1 + ground_truth = list(ground_truth) + + # merge training_files into ground_truth list + if training_files: + ground_truth.extend(training_files) + + if len(ground_truth) == 0: + raise click.UsageError('No training data was provided to the train command. Use `-t` or the `ground_truth` argument.') + + try: + accelerator, device = to_ptl_device(device) + except Exception as e: + raise click.BadOptionUsage('device', str(e)) + + if hyper_params['freq'] > 1: + val_check_interval = {'check_val_every_n_epoch': int(hyper_params['freq'])} + else: + val_check_interval = {'val_check_interval': hyper_params['freq']} + + model = ROModel(hyper_params, + output=output, + model=load, + training_data=ground_truth, + evaluation_data=evaluation_files, + partition=partition, + num_workers=workers, + load_hyper_parameters=load_hyper_parameters, + format_type=format_type, + level=level, + reading_order=reading_order) + + message(f'Training RO on following {level} types:') + for k, v in model.train_set.dataset.class_mapping.items(): + message(f' {k}\t{v}') + + if len(model.train_set) == 0: + raise click.UsageError('No valid training data was provided to the train command. Use `-t` or the `ground_truth` argument.') + + trainer = KrakenTrainer(accelerator=accelerator, + devices=device, + max_epochs=hyper_params['epochs'] if hyper_params['quit'] == 'fixed' else -1, + min_epochs=hyper_params['min_epochs'], + enable_progress_bar=True if not ctx.meta['verbose'] else False, + deterministic=ctx.meta['deterministic'], + precision=int(precision), + pl_logger=pl_logger, + log_dir=log_dir, + **val_check_interval) + + trainer.fit(model) + + if quit == 'early': + message('Moving best model {0}_{1}.mlmodel ({2}) to {0}_best.mlmodel'.format( + output, model.best_epoch, model.best_metric)) + logger.info('Moving best model {0}_{1}.mlmodel ({2}) to {0}_best.mlmodel'.format( + output, model.best_epoch, model.best_metric)) + shutil.copy(f'{output}_{model.best_epoch}.mlmodel', f'{output}_best.mlmodel') diff --git a/kraken/ketos/segmentation.py b/kraken/ketos/segmentation.py index 54afa9ae5..dfcb7c739 100644 --- a/kraken/ketos/segmentation.py +++ b/kraken/ketos/segmentation.py @@ -76,8 +76,8 @@ def _validate_merging(ctx, param, value): show_default=True, default=SEGMENTATION_HYPER_PARAMS['quit'], type=click.Choice(['early', - 'dumb']), - help='Stop condition for training. Set to `early` for early stopping or `dumb` for fixed number of epochs') + 'fixed']), + help='Stop condition for training. Set to `early` for early stopping or `fixed` for fixed number of epochs') @click.option('-N', '--epochs', show_default=True, @@ -339,7 +339,7 @@ def segtrain(ctx, output, spec, line_width, pad, load, freq, quit, epochs, trainer = KrakenTrainer(accelerator=accelerator, devices=device, precision=precision, - max_epochs=hyper_params['epochs'] if hyper_params['quit'] == 'dumb' else -1, + max_epochs=hyper_params['epochs'] if hyper_params['quit'] == 'fixed' else -1, min_epochs=hyper_params['min_epochs'], enable_progress_bar=True if not ctx.meta['verbose'] else False, deterministic=ctx.meta['deterministic'], diff --git a/kraken/lib/dataset/ro.py b/kraken/lib/dataset/ro.py index 0b81912d1..88bb32b73 100644 --- a/kraken/lib/dataset/ro.py +++ b/kraken/lib/dataset/ro.py @@ -49,8 +49,8 @@ class ROSet(Dataset): """ def __init__(self, files: Sequence[Union[PathLike, str]] = None, mode: Optional[Literal['alto', 'page', 'xml']] = 'path', - ro_type: Literal['region', 'line'] = 'line', - ro_id: str = 'line_implicit', + level: Literal['regions', 'baselines'] = 'baselines', + ro_id: Optional[str] = None, class_mapping: Optional[Dict[str, int]] = None): """ Samples pairs lines/regions from XML files for training a reading order @@ -61,11 +61,13 @@ def __init__(self, files: Sequence[Union[PathLike, str]] = None, mode the baseline paths and image data is retrieved from an ALTO/PageXML file. In `None` mode data is iteratively added through the `add` method. - ro_id: ID of the reading order to sample from. + ro_id: ID of the reading order to sample from. Defaults to + `line_implicit`/`region_implicit`. """ super().__init__() self._num_pairs = 0 + self.failed_samples = [] if class_mapping: self.class_mapping = class_mapping self.num_classes = len(class_mapping) + 1 @@ -90,12 +92,16 @@ def __init__(self, files: Sequence[Union[PathLike, str]] = None, for file in files: try: doc = XMLPage(file, filetype=mode) - if ro_type == 'line': + if level == 'baselines': + if not ro_id: + ro_id = 'line_implicit' order = doc.get_sorted_lines(ro_id) - elif ro_type == 'region': + elif level == 'regions': + if not ro_id: + ro_id = 'region_implicit' order = doc.get_sorted_regions(ro_id) else: - raise ValueError(f'Invalid RO type {ro_type}') + raise ValueError(f'Invalid RO type {level}') # traverse RO and substitute features. h,w = Image.open(doc.imagename).size sorted_lines = [] diff --git a/kraken/lib/default_specs.py b/kraken/lib/default_specs.py index bc92f6979..af08fd1e5 100644 --- a/kraken/lib/default_specs.py +++ b/kraken/lib/default_specs.py @@ -24,6 +24,7 @@ 'batch_size': 15000, 'epochs': 3000, 'lag': 300, + 'min_delta': None, 'quit': 'early', 'optimizer': 'Adam', 'momentum': 0.9, @@ -107,7 +108,7 @@ SEGMENTATION_HYPER_PARAMS = {'line_width': 8, 'padding': (0, 0), 'freq': 1.0, - 'quit': 'dumb', + 'quit': 'fixed', 'epochs': 50, 'min_epochs': 0, 'lag': 10, diff --git a/kraken/lib/models.py b/kraken/lib/models.py index c8e08dcfb..ac1219b23 100644 --- a/kraken/lib/models.py +++ b/kraken/lib/models.py @@ -217,6 +217,6 @@ def validate_hyper_parameters(hyper_params): """ Validate some model's hyper parameters and modify them in place if need be. """ - if (hyper_params['quit'] == 'dumb' and hyper_params['completed_epochs'] >= hyper_params['epochs']): + if (hyper_params['quit'] == 'fixed' and hyper_params['completed_epochs'] >= hyper_params['epochs']): logger.warning('Maximum epochs reached (might be loaded from given model), starting again from 0.') hyper_params['completed_epochs'] = 0 diff --git a/kraken/lib/progress.py b/kraken/lib/progress.py index bb2e30b3e..0284fa865 100644 --- a/kraken/lib/progress.py +++ b/kraken/lib/progress.py @@ -155,4 +155,3 @@ class RichProgressBarTheme: time: Union[str, Style] = DEFAULT_STYLES['progress.elapsed'] processing_speed: Union[str, Style] = DEFAULT_STYLES['progress.data.speed'] metrics: Union[str, Style] = DEFAULT_STYLES['progress.description'] - diff --git a/kraken/lib/ro/model.py b/kraken/lib/ro/model.py index 63b0bf349..e5a482acd 100644 --- a/kraken/lib/ro/model.py +++ b/kraken/lib/ro/model.py @@ -49,7 +49,9 @@ def __init__(self, partition: Optional[float] = 0.9, num_workers: int = 1, format_type: Literal['alto', 'page', 'xml'] = 'xml', - load_hyper_parameters: bool = False): + load_hyper_parameters: bool = False, + level: Literal['baselines', 'regions'] = 'baselines', + reading_order: Optional[str] = None): """ A LightningModule encapsulating the unsupervised pretraining setup for a text recognition model. @@ -90,8 +92,17 @@ def __init__(self, np.random.shuffle(training_data) training_data = training_data[:int(partition*len(training_data))] evaluation_data = training_data[int(partition*len(training_data)):] - self.train_set = ROSet(training_data, mode=format_type) - self.val_set = ROSet(evaluation_data, mode=format_type, class_mapping=self.train_set.class_mapping) + train_set = ROSet(training_data, + mode=format_type, + level=level, + ro_id=reading_order) + self.train_set = Subset(train_set, range(len(train_set))) + val_set = ROSet(evaluation_data, + mode=format_type, + class_mapping=train_set.class_mapping, + level=level, + ro_id=reading_order) + self.val_set = Subset(val_set, range(len(val_set))) if len(self.train_set) == 0 or len(self.val_set) == 0: raise ValueError('No valid training data was provided to the train ' @@ -106,11 +117,8 @@ def __init__(self, self.num_workers = num_workers - self.best_epoch = 0 - self.best_metric = math.inf - logger.info(f'Creating new RO model') - self.ro_net = torch.jit.script(MLP(self.train_set.get_feature_dim(), 128)) + self.ro_net = torch.jit.script(MLP(train_set.get_feature_dim(), 128)) if 'file_system' in torch.multiprocessing.get_all_sharing_strategies(): logger.debug('Setting multiprocessing tensor sharing strategy to file_system') From 145516b1d0926260ffaed9982b1d515c67f0c7b6 Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Fri, 3 Mar 2023 14:19:11 +0100 Subject: [PATCH 08/68] remove metric ignore code in progress bar --- kraken/ketos/pretrain.py | 2 +- kraken/ketos/ro.py | 1 - kraken/lib/ro/model.py | 42 ++++++++++++++++++++++------------------ 3 files changed, 24 insertions(+), 21 deletions(-) diff --git a/kraken/ketos/pretrain.py b/kraken/ketos/pretrain.py index aeebecca0..1c4bc3d0e 100644 --- a/kraken/ketos/pretrain.py +++ b/kraken/ketos/pretrain.py @@ -279,7 +279,7 @@ def pretrain(ctx, batch_size, pad, output, spec, load, freq, quit, epochs, min_epochs=hyper_params['min_epochs'], enable_progress_bar=True if not ctx.meta['verbose'] else False, deterministic=ctx.meta['deterministic'], - pb_ignored_metrics=(), + failed_sample_threshold=failed_sample_threshold, **val_check_interval) trainer.fit(model, datamodule=data_module) diff --git a/kraken/ketos/ro.py b/kraken/ketos/ro.py index 9d1222077..d4ff1a409 100644 --- a/kraken/ketos/ro.py +++ b/kraken/ketos/ro.py @@ -211,7 +211,6 @@ def rotrain(ctx, output, load, freq, quit, epochs, min_epochs, lag, model = ROModel(hyper_params, output=output, - model=load, training_data=ground_truth, evaluation_data=evaluation_files, partition=partition, diff --git a/kraken/lib/ro/model.py b/kraken/lib/ro/model.py index e5a482acd..f8e607291 100644 --- a/kraken/lib/ro/model.py +++ b/kraken/lib/ro/model.py @@ -26,7 +26,9 @@ import pytorch_lightning as pl from os import PathLike -from typing import Dict, Optional, Sequence, Union, Any, Literal +from dataclasses import dataclass, field +from torch.nn import Module +from typing import Dict, Optional, Sequence, Union, Any, Literal, List from kraken.lib import vgsl, default_specs, layers from kraken.lib.dataset import ROSet @@ -38,12 +40,26 @@ logger = logging.getLogger(__name__) +@dataclass +class DummyVGSLModel: + hyper_params: Dict[str, int] = field(default_factory=dict) + user_metadata: Dict[str, List] = field(default_factory=dict) + one_channel_mode: Literal['1', 'L'] = '1' + ptl_module: Module = None + model_type: str = 'unknown' + + def __post_init__(self): + self.hyper_params: Dict[str, int] = {'completed_epochs': 0} + self.user_metadata: Dict[str, List] = {'accuracy': [], 'metrics': []} + + def save_model(self, filename): + self.ptl_module.save_checkpoint(filename) + class ROModel(pl.LightningModule): def __init__(self, hyper_params: Dict[str, Any] = None, output: str = 'model', - model: Optional[Union[PathLike, str]] = None, training_data: Union[Sequence[Union[PathLike, str]], Sequence[Dict[str, Any]]] = None, evaluation_data: Optional[Union[Sequence[Union[PathLike, str]], Sequence[Dict[str, Any]]]] = None, partition: Optional[float] = 0.9, @@ -69,20 +85,6 @@ def __init__(self, """ super().__init__() hyper_params_ = default_specs.READING_ORDER_HYPER_PARAMS - if model: - logger.info(f'Loading existing model from {model} ') - self.nn = vgsl.TorchVGSLModel.load_model(model) - - if self.nn.model_type not in [None, 'segmentation']: - raise ValueError(f'Model {model} is of type {self.nn.model_type} while `segmentation` is expected.') - - if load_hyper_parameters: - hp = self.nn.hyper_params - else: - hp = {} - hyper_params_.update(hp) - else: - self.ro_net = None if hyper_params: hyper_params_.update(hyper_params) @@ -111,7 +113,6 @@ def __init__(self, logger.info(f'Training set {len(self.train_set)} lines, validation set ' f'{len(self.val_set)} lines') - self.model = model self.output = output self.criterion = torch.nn.BCELoss() @@ -124,7 +125,7 @@ def __init__(self, logger.debug('Setting multiprocessing tensor sharing strategy to file_system') torch.multiprocessing.set_sharing_strategy('file_system') - logger.info('Encoding training set') + self.nn = DummyVGSLModel(ptl_module=self) def forward(self, x): return self.ro_net(x) @@ -133,7 +134,7 @@ def validation_step(self, batch, batch_idx): x, y = batch['sample'], batch['target'] yhat = self.ro_net(x) loss = self.criterion(yhat.squeeze(), y) - self.log('loss', loss) + self.log('val_metric', loss) return loss def training_step(self, batch, batch_idx): @@ -160,3 +161,6 @@ def val_dataloader(self): batch_size=self.hparams.batch_size, num_workers=self.num_workers, pin_memory=True) + + def save_checkpoint(self, filename): + self.trainer.save_checkpoint(filename) From 46d2c96a401891f6952695f4b30853da7a83a55a Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Fri, 3 Mar 2023 15:17:14 +0100 Subject: [PATCH 09/68] make checkpoint loading work for RO training --- kraken/ketos/ro.py | 29 ++++++---- kraken/lib/ro/model.py | 26 ++++++--- kraken/lib/train.py | 123 +++++++++++++++++++++++------------------ 3 files changed, 107 insertions(+), 71 deletions(-) diff --git a/kraken/ketos/ro.py b/kraken/ketos/ro.py index d4ff1a409..b9ae52f55 100644 --- a/kraken/ketos/ro.py +++ b/kraken/ketos/ro.py @@ -209,16 +209,25 @@ def rotrain(ctx, output, load, freq, quit, epochs, min_epochs, lag, else: val_check_interval = {'val_check_interval': hyper_params['freq']} - model = ROModel(hyper_params, - output=output, - training_data=ground_truth, - evaluation_data=evaluation_files, - partition=partition, - num_workers=workers, - load_hyper_parameters=load_hyper_parameters, - format_type=format_type, - level=level, - reading_order=reading_order) + if load: + model = ROModel.load_from_checkpoint(load, + training_data=ground_truth, + evaluation_data=evaluation_files, + partition=partition, + num_workers=workers, + load_hyper_parameters=load_hyper_parameters, + format_type=format_type) + else: + model = ROModel(hyper_params, + output=output, + training_data=ground_truth, + evaluation_data=evaluation_files, + partition=partition, + num_workers=workers, + load_hyper_parameters=load_hyper_parameters, + format_type=format_type, + level=level, + reading_order=reading_order) message(f'Training RO on following {level} types:') for k, v in model.train_set.dataset.class_mapping.items(): diff --git a/kraken/lib/ro/model.py b/kraken/lib/ro/model.py index f8e607291..c878f4a83 100644 --- a/kraken/lib/ro/model.py +++ b/kraken/lib/ro/model.py @@ -84,11 +84,9 @@ def __init__(self, **kwargs: Setup parameters, i.e. CLI parameters of the train() command. """ super().__init__() - hyper_params_ = default_specs.READING_ORDER_HYPER_PARAMS - + self.hyper_params = default_specs.READING_ORDER_HYPER_PARAMS if hyper_params: - hyper_params_.update(hyper_params) - self.save_hyperparameters(hyper_params_) + self.hyper_params.update(hyper_params) if not evaluation_data: np.random.shuffle(training_data) @@ -118,6 +116,9 @@ def __init__(self, self.num_workers = num_workers + self.best_epoch = -1 + self.best_metric = torch.inf + logger.info(f'Creating new RO model') self.ro_net = torch.jit.script(MLP(train_set.get_feature_dim(), 128)) @@ -127,6 +128,8 @@ def __init__(self, self.nn = DummyVGSLModel(ptl_module=self) + self.save_hyperparameters() + def forward(self, x): return self.ro_net(x) @@ -137,6 +140,15 @@ def validation_step(self, batch, batch_idx): self.log('val_metric', loss) return loss + def validation_epoch_end(self, outputs): + val_metric = np.mean([x.cpu() for x in outputs]) + if val_metric < self.best_metric: + logger.debug(f'Updating best metric from {self.best_metric} ({self.best_epoch}) to {val_metric} ({self.current_epoch})') + self.best_epoch = self.current_epoch + self.best_metric = val_metric + logger.info(f'validation run: val_metric {val_metric}') + self.log('val_metric', val_metric, on_step=False, on_epoch=True, prog_bar=True, logger=True) + def training_step(self, batch, batch_idx): x, y = batch['sample'], batch['target'] yhat = self.ro_net(x) @@ -145,20 +157,20 @@ def training_step(self, batch, batch_idx): return loss def configure_optimizers(self): - return _configure_optimizer_and_lr_scheduler(self.hparams, + return _configure_optimizer_and_lr_scheduler(self.hparams.hyper_params, self.ro_net.parameters(), len_train_set=len(self.train_set), loss_tracking_mode='min') def train_dataloader(self): return DataLoader(self.train_set, - batch_size=self.hparams.batch_size, + batch_size=self.hyper_params['batch_size'], num_workers=self.num_workers, pin_memory=True) def val_dataloader(self): return DataLoader(self.val_set, - batch_size=self.hparams.batch_size, + batch_size=self.hyper_params['batch_size'], num_workers=self.num_workers, pin_memory=True) diff --git a/kraken/lib/train.py b/kraken/lib/train.py index be9b70a31..5f4af6e17 100644 --- a/kraken/lib/train.py +++ b/kraken/lib/train.py @@ -246,7 +246,8 @@ def __init__(self, if hyper_params: hyper_params_.update(hyper_params) - self.save_hyperparameters(hyper_params_) + self.hyper_params = hyper_params_ + self.save_hyperparameters() self.reorder = reorder self.append = append @@ -415,11 +416,11 @@ def _build_dataset(self, DatasetClass, training_data, **kwargs): - dataset = DatasetClass(normalization=self.hparams.normalization, - whitespace_normalization=self.hparams.normalize_whitespace, + dataset = DatasetClass(normalization=self.hparams.hyper_params['normalization'], + whitespace_normalization=self.hparams.hyper_params['normalize_whitespace'], reorder=self.reorder, im_transforms=self.transforms, - augmentation=self.hparams.augment, + augmentation=self.hparams.hyper_params['augment'], **kwargs) if (self.num_workers and self.num_workers > 1) and self.format_type != 'binary': @@ -434,7 +435,7 @@ def _build_dataset(self, dataset.add(**im) except KrakenInputException as e: logger.warning(str(e)) - if self.format_type == 'binary' and self.hparams.normalization: + if self.format_type == 'binary' and self.hparams.hyper_params['normalization']: logger.debug('Rebuilding dataset using unicode normalization') dataset.rebuild_alphabet() return dataset @@ -591,7 +592,7 @@ def setup(self, stage: Optional[str] = None): if self.format_type != 'path' and self.nn.seg_type == 'bbox': logger.warning('Neural network has been trained on bounding box image information but training set is polygonal.') - self.nn.hyper_params = self.hparams + self.nn.hyper_params = self.hparams.hyper_params self.nn.model_type = 'recognition' if not self.nn.seg_type: @@ -605,7 +606,7 @@ def setup(self, stage: Optional[str] = None): def train_dataloader(self): return DataLoader(self.train_set, - batch_size=self.hparams.batch_size, + batch_size=self.hparams.hyper_params['batch_size'], num_workers=self.num_workers, pin_memory=True, shuffle=True, @@ -614,7 +615,7 @@ def train_dataloader(self): def val_dataloader(self): return DataLoader(self.val_set, shuffle=False, - batch_size=self.hparams.batch_size, + batch_size=self.hparams.hyper_params['batch_size'], num_workers=self.num_workers, pin_memory=True, collate_fn=collate_sequences, @@ -622,11 +623,12 @@ def val_dataloader(self): def configure_callbacks(self): callbacks = [] - if self.hparams.quit == 'early': + if self.hparams.hyper_params['quit'] == 'early': callbacks.append(EarlyStopping(monitor='val_accuracy', mode='max', - patience=self.hparams.lag, + patience=self.hparams.hyper_params['lag'], stopping_threshold=1.0)) + return callbacks # configuration of optimizers and learning rate schedulers @@ -636,7 +638,7 @@ def configure_callbacks(self): # batch-wise learning rate warmup. In lr_scheduler_step() calls to the # scheduler are then only performed at the end of the epoch. def configure_optimizers(self): - return _configure_optimizer_and_lr_scheduler(self.hparams, + return _configure_optimizer_and_lr_scheduler(self.hparams.hyper_params, self.nn.nn.parameters(), len_train_set=len(self.train_set), loss_tracking_mode='max') @@ -648,13 +650,13 @@ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure): # linear warmup between 0 and the initial learning rate `lrate` in `warmup` # steps. - if self.hparams.warmup and self.trainer.global_step < self.hparams.warmup: - lr_scale = min(1.0, float(self.trainer.global_step + 1) / self.hparams.warmup) + if self.hparams.hyper_params.warmup and self.trainer.global_step < self.hparams.hyper_params.warmup: + lr_scale = min(1.0, float(self.trainer.global_step + 1) / self.hparams.hyper_params.warmup) for pg in optimizer.param_groups: - pg["lr"] = lr_scale * self.hparams.lrate + pg["lr"] = lr_scale * self.hparams.hyper_params.lrate - def lr_scheduler_step(self, scheduler, metric): - if not self.hparams.warmup or self.trainer.global_step >= self.hparams.warmup: + def lr_scheduler_step(self, scheduler, optimizer_idx, metric): + if not self.hparams.hyper_params.warmup or self.trainer.global_step >= self.hparams.hyper_params.warmup: # step OneCycleLR each batch if not in warmup phase if isinstance(scheduler, lr_scheduler.OneCycleLR): scheduler.step() @@ -759,7 +761,8 @@ def __init__(self, hyper_params_.update(hyper_params) validate_hyper_parameters(hyper_params_) - self.save_hyperparameters(hyper_params_) + self.hyper_params = hyper_params_ + self.save_hyperparameters() if not training_data: raise ValueError('No training data provided. Please add some.') @@ -768,7 +771,7 @@ def __init__(self, height, width, channels, - self.hparams.padding, + self.hparams.hyper_params.padding, valid_norm=False, force_binarization=force_binarization) @@ -795,10 +798,10 @@ def __init__(self, merge_baselines = None train_set = BaselineSet(training_data, - line_width=self.hparams.line_width, + line_width=self.hparams.hyper_params.line_width, im_transforms=transforms, mode=format_type, - augmentation=self.hparams.augment, + augmentation=self.hparams.hyper_params.augment, valid_baselines=valid_baselines, merge_baselines=merge_baselines, valid_regions=valid_regions, @@ -810,7 +813,7 @@ def __init__(self, if evaluation_data: val_set = BaselineSet(evaluation_data, - line_width=self.hparams.line_width, + line_width=self.hparams.hyper_params.line_width, im_transforms=transforms, mode=format_type, augmentation=False, @@ -1035,10 +1038,10 @@ def val_dataloader(self): def configure_callbacks(self): callbacks = [] - if self.hparams.quit == 'early': + if self.hparams.hyper_params['quit'] == 'early': callbacks.append(EarlyStopping(monitor='val_mean_iu', mode='max', - patience=self.hparams.lag, + patience=self.hparams.hyper_params['lag'], stopping_threshold=1.0)) return callbacks @@ -1050,7 +1053,7 @@ def configure_callbacks(self): # batch-wise learning rate warmup. In lr_scheduler_step() calls to the # scheduler are then only performed at the end of the epoch. def configure_optimizers(self): - return _configure_optimizer_and_lr_scheduler(self.hparams, + return _configure_optimizer_and_lr_scheduler(self.hparams.hyper_params, self.nn.nn.parameters(), len_train_set=len(self.train_set), loss_tracking_mode='max') @@ -1061,13 +1064,13 @@ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure): # linear warmup between 0 and the initial learning rate `lrate` in `warmup` # steps. - if self.hparams.warmup and self.trainer.global_step < self.hparams.warmup: - lr_scale = min(1.0, float(self.trainer.global_step + 1) / self.hparams.warmup) + if self.hparams.hyper_params['warmup'] and self.trainer.global_step < self.hparams.hyper_params['warmup']: + lr_scale = min(1.0, float(self.trainer.global_step + 1) / self.hparams.hyper_params['warmup']) for pg in optimizer.param_groups: - pg["lr"] = lr_scale * self.hparams.lrate + pg["lr"] = lr_scale * self.hparams.hyper_params['lrate'] - def lr_scheduler_step(self, scheduler, metric): - if not self.hparams.warmup or self.trainer.global_step >= self.hparams.warmup: + def lr_scheduler_step(self, scheduler, optimizer_idx, metric): + if not self.hparams.hyper_params['warmup'] or self.trainer.global_step >= self.hparams.hyper_params['warmup']: # step OneCycleLR each batch if not in warmup phase if isinstance(scheduler, lr_scheduler.OneCycleLR): scheduler.step() @@ -1077,51 +1080,63 @@ def lr_scheduler_step(self, scheduler, metric): def _configure_optimizer_and_lr_scheduler(hparams, params, len_train_set=None, loss_tracking_mode='max'): + optimizer = hparams.get("optimizer") + lrate = hparams.get("lrate") + momentum = hparams.get("momentum") + weight_decay = hparams.get("weight_decay") + schedule = hparams.get("schedule") + gamma = hparams.get("gamma") + step_size = hparams.get("step_size") + rop_factor = hparams.get("rop_factor") + rop_patience = hparams.get("rop_patience") + epochs = hparams.get("epochs") + completed_epochs = hparams.get("completed_epochs") + # XXX: Warmup is not configured here because it needs to be manually done in optimizer_step() - logger.debug(f'Constructing {hparams.optimizer} optimizer (lr: {hparams.lrate}, momentum: {hparams.momentum})') - if hparams.optimizer == 'Adam': - optim = torch.optim.Adam(params, lr=hparams.lrate, weight_decay=hparams.weight_decay) + logger.debug(f'Constructing {optimizer} optimizer (lr: {lrate}, momentum: {momentum})') + if optimizer == 'Adam': + optim = torch.optim.Adam(params, lr=lrate, weight_decay=weight_decay) else: - optim = getattr(torch.optim, hparams.optimizer)(params, - lr=hparams.lrate, - momentum=hparams.momentum, - weight_decay=hparams.weight_decay) + optim = getattr(torch.optim, optimizer)(params, + lr=lrate, + momentum=momentum, + weight_decay=weight_decay) lr_sched = {} - if hparams.schedule == 'exponential': - lr_sched = {'scheduler': lr_scheduler.ExponentialLR(optim, hparams.gamma, last_epoch=hparams.completed_epochs-1), + if schedule == 'exponential': + lr_sched = {'scheduler': lr_scheduler.ExponentialLR(optim, gamma, last_epoch=completed_epochs-1), 'interval': 'step'} - elif hparams.schedule == 'cosine': - lr_sched = {'scheduler': lr_scheduler.CosineAnnealingLR(optim, hparams.gamma, last_epoch=hparams.completed_epochs-1), + elif schedule == 'cosine': + lr_sched = {'scheduler': lr_scheduler.CosineAnnealingLR(optim, gamma, last_epoch=completed_epochs-1), 'interval': 'step'} - elif hparams.schedule == 'step': - lr_sched = {'scheduler': lr_scheduler.StepLR(optim, hparams.step_size, hparams.gamma, last_epoch=hparams.completed_epochs-1), + elif schedule == 'step': + lr_sched = {'scheduler': lr_scheduler.StepLR(optim, step_size, gamma, last_epoch=completed_epochs-1), 'interval': 'step'} - elif hparams.schedule == 'reduceonplateau': + elif schedule == 'reduceonplateau': lr_sched = {'scheduler': lr_scheduler.ReduceLROnPlateau(optim, mode=loss_tracking_mode, - factor=hparams.rop_factor, - patience=hparams.rop_patience), + factor=rop_factor, + patience=rop_patience), 'interval': 'step'} - elif hparams.schedule == '1cycle': - if hparams.epochs <= 0: + elif schedule == '1cycle': + if epochs <= 0: raise ValueError('1cycle learning rate scheduler selected but ' 'number of epochs is less than 0 ' - f'({hparams.epochs}).') - last_epoch = hparams.completed_epochs*len_train_set if hparams.completed_epochs else -1 + f'({epochs}).') + last_epoch = completed_epochs*len_train_set if completed_epochs else -1 lr_sched = {'scheduler': lr_scheduler.OneCycleLR(optim, - max_lr=hparams.lrate, - epochs=hparams.epochs, + max_lr=lrate, + epochs=epochs, steps_per_epoch=len_train_set, last_epoch=last_epoch), 'interval': 'step'} - elif hparams.schedule != 'constant': - raise ValueError(f'Unsupported learning rate scheduler {hparams.schedule}.') + elif schedule != 'constant': + raise ValueError(f'Unsupported learning rate scheduler {schedule}.') ret = {'optimizer': optim} if lr_sched: ret['lr_scheduler'] = lr_sched - if hparams.schedule == 'reduceonplateau': + if schedule == 'reduceonplateau': lr_sched['monitor'] = 'val_metric' lr_sched['strict'] = False lr_sched['reduce_on_plateau'] = True From c13b72c22a9eb07d4887ae3687628c5a5db482ca Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Fri, 3 Mar 2023 16:07:28 +0100 Subject: [PATCH 10/68] add batch_size arg to ro cli again --- kraken/ketos/ro.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/kraken/ketos/ro.py b/kraken/ketos/ro.py index b9ae52f55..a58ce8d89 100644 --- a/kraken/ketos/ro.py +++ b/kraken/ketos/ro.py @@ -39,6 +39,8 @@ @click.command('rotrain') @click.pass_context +@click.option('-B', '--batch-size', show_default=True, type=click.INT, + default=RECOGNITION_PRETRAIN_HYPER_PARAMS['batch_size'], help='batch sample size') @click.option('-o', '--output', show_default=True, type=click.Path(), default='model', help='Output model file') @click.option('-i', '--load', show_default=True, type=click.Path(exists=True, readable=True), help='Load existing file to continue training') @@ -139,7 +141,7 @@ @click.option('--reading-order', show_default=True, default=None, help='Select reading order to train. Defaults to `line_implicit`/`region_implicit`') @click.argument('ground_truth', nargs=-1, callback=_expand_gt, type=click.Path(exists=False, dir_okay=False)) -def rotrain(ctx, output, load, freq, quit, epochs, min_epochs, lag, +def rotrain(ctx, batch_size, output, load, freq, quit, epochs, min_epochs, lag, min_delta, device, precision, optimizer, lrate, momentum, weight_decay, warmup, schedule, gamma, step_size, sched_patience, cos_max, partition, training_files, evaluation_files, workers, @@ -169,7 +171,8 @@ def rotrain(ctx, output, load, freq, quit, epochs, min_epochs, lag, # populate hyperparameters from command line args hyper_params = READING_ORDER_HYPER_PARAMS.copy() - hyper_params.update({'freq': freq, + hyper_params.update({'batch_size': batch_size, + 'freq': freq, 'quit': quit, 'epochs': epochs, 'min_epochs': min_epochs, From 92dfb074d841eba8d29351a7d61f0bfef3d63c19 Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Fri, 3 Mar 2023 16:08:44 +0100 Subject: [PATCH 11/68] arrgh --- kraken/ketos/ro.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kraken/ketos/ro.py b/kraken/ketos/ro.py index a58ce8d89..27e4851be 100644 --- a/kraken/ketos/ro.py +++ b/kraken/ketos/ro.py @@ -40,7 +40,7 @@ @click.command('rotrain') @click.pass_context @click.option('-B', '--batch-size', show_default=True, type=click.INT, - default=RECOGNITION_PRETRAIN_HYPER_PARAMS['batch_size'], help='batch sample size') + default=READING_ORDER_HYPER_PARAMS['batch_size'], help='batch sample size') @click.option('-o', '--output', show_default=True, type=click.Path(), default='model', help='Output model file') @click.option('-i', '--load', show_default=True, type=click.Path(exists=True, readable=True), help='Load existing file to continue training') From e1a932b95d7086d396b03339abcfefdd3563537d Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Fri, 3 Mar 2023 16:27:11 +0100 Subject: [PATCH 12/68] sketch of neural RO decoder --- kraken/lib/segmentation.py | 79 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) diff --git a/kraken/lib/segmentation.py b/kraken/lib/segmentation.py index 6f6459ac5..784cb4ceb 100644 --- a/kraken/lib/segmentation.py +++ b/kraken/lib/segmentation.py @@ -817,6 +817,85 @@ def is_in_region(line, region) -> bool: return region.contains(l_obj) +def neural_reading_order(lines: Sequence[Tuple[List[Tuple[int, int]], List[Tuple[int, int]]]], + im_size: Tuple[int, int], + model): + """ + Given a list of baselines and regions, calculates the correct reading order + and applies it to the input. + + Args: + lines: List of tuples containing the baseline and its polygonization. + model: torch Module for + + Returns: + A reordered input. + """ + # construct all possible pairs + h, w = im_size + features = [] + for i in lines: + for j in lines: + if i == j and len(children) != 1: + continue + line_coords_i = np.array(i) / (w, h) + line_center_i = np.mean(line_coords_i, axis=0) + line_coords_j = np.array(j) / (w, h) + line_center_j = np.mean(line_coords_j, axis=0) + features.append(torch.cat((cl_i, + torch.tensor(line_center_i, dtype=torch.float), # lin + torch.tensor(line_coords_i[0, :], dtype=torch.float), + torch.tensor(line_coords_i[-1, :], dtype=torch.float), + cl_j, + torch.tensor(line_center_j, dtype=torch.float), # lin + torch.tensor(line_coords_j[0, :], dtype=torch.float), + torch.tensor(line_coords_j[-1, :], dtype=torch.float)))) + features = torch.cat(features) + output = model(features) + order = torch.zeros((len(lines), len(lines))) + idx = 0 + for i in enumerate(lines): + for j in enumerate(lines): + order[i, j] = output[idx] + idx += 1 + # decode order relation matrix + path = _greedy_order_decoder(order) + return ordered_lines + + +def _greedy_order_decoder(P): + """ + A greedy decoder of order-relation matrix. For each position in the + reading order we select the most probable one, then move to the next + position. Most probable for position: + z^{\star}_t = \argmax_{(s,\nu) \ni z^{\star}} + \prod_{(s',\nu') \in z^\star}{\tilde{P}(Y=1\mid s',s)} + \times \prod_{\substack{(s'',\nu'') \ni z^\star\\ + s'' \ne s}}{\tilde{P}(r=0\mid s'',s)}, 1\le t \le n + """ + A = P + torch.finfo(torch.float).eps + N = P.shape[0] + A = (A + (1-A).T)/2 + for i in range(A.shape[0]): + A[i,i] = torch.finfo(torch.float).eps + best_path = [] + #--- use log(p(R\mid s',s)) to shift multiplication to sum + lP = torch.log(A) + for i in range(N): + lP[i,i] = 0 + for t in range(N): + #print(lP) + #print("----------------------") + for i in range(N): + idx = torch.argmax(lP.sum(axis=1)) + if idx not in best_path: + best_path.append(idx) + lP[idx,:] = lP[:,idx] + lP[:,idx] = 0 + break + return best_path + + def scale_regions(regions: Sequence[Tuple[List[int], List[int]]], scale: Union[float, Tuple[float, float]]) -> Sequence[Tuple[List, List]]: """ From 3507a8ea83de1332adafbe062f1d7cb64e8e2e5d Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Mon, 6 Mar 2023 23:19:51 +0100 Subject: [PATCH 13/68] Use spearman footrule distance as evaluation metric for RO training --- kraken/lib/dataset/__init__.py | 2 +- kraken/lib/dataset/ro.py | 107 +++++++++++++++++++++++++++++- kraken/lib/ro/model.py | 115 ++++++++++++++++++++++++++------- kraken/lib/segmentation.py | 15 ++--- kraken/lib/train.py | 10 +-- 5 files changed, 206 insertions(+), 43 deletions(-) diff --git a/kraken/lib/dataset/__init__.py b/kraken/lib/dataset/__init__.py index 5388cfe6b..c6710d24c 100644 --- a/kraken/lib/dataset/__init__.py +++ b/kraken/lib/dataset/__init__.py @@ -17,5 +17,5 @@ """ from .recognition import ArrowIPCRecognitionDataset, PolygonGTDataset, GroundTruthDataset # NOQA from .segmentation import BaselineSet # NOQA -from .ro import ROSet #NOQA +from .ro import PairWiseROSet, PageWiseROSet #NOQA from .utils import ImageInputTransforms, collate_sequences, global_align, compute_confusions # NOQA diff --git a/kraken/lib/dataset/ro.py b/kraken/lib/dataset/ro.py index 88bb32b73..8bf7d6b72 100644 --- a/kraken/lib/dataset/ro.py +++ b/kraken/lib/dataset/ro.py @@ -43,9 +43,11 @@ logger = logging.getLogger(__name__) -class ROSet(Dataset): +class PairWiseROSet(Dataset): """ Dataset for training a reading order determination model. + + Returns random pairs of lines from the same page. """ def __init__(self, files: Sequence[Union[PathLike, str]] = None, mode: Optional[Literal['alto', 'page', 'xml']] = 'path', @@ -144,3 +146,106 @@ def get_feature_dim(self): def __len__(self): return self._num_pairs + + +class PageWiseROSet(Dataset): + """ + Dataset for training a reading order determination model. + + Returns all lines from the same page. + """ + def __init__(self, files: Sequence[Union[PathLike, str]] = None, + mode: Optional[Literal['alto', 'page', 'xml']] = 'path', + level: Literal['regions', 'baselines'] = 'baselines', + ro_id: Optional[str] = None, + class_mapping: Optional[Dict[str, int]] = None): + """ + Samples pairs lines/regions from XML files for training a reading order + model . + + Args: + mode: Either alto, page, xml, None. In alto, page, and xml + mode the baseline paths and image data is retrieved from an + ALTO/PageXML file. In `None` mode data is iteratively added + through the `add` method. + ro_id: ID of the reading order to sample from. Defaults to + `line_implicit`/`region_implicit`. + """ + super().__init__() + + self.failed_samples = [] + if class_mapping: + self.class_mapping = class_mapping + self.num_classes = len(class_mapping) + 1 + else: + self.num_classes = 1 + self.class_mapping = {} + + self.data = [] + + if mode in ['alto', 'page', 'xml']: + for file in files: + try: + doc = XMLPage(file, filetype=mode) + for tag in doc.tags: + if tag not in self.class_mapping: + self.class_mapping[tag] = self.num_classes + self.num_classes += 1 + except KrakenInputException as e: + files.pop(file) + logger.warning(e) + continue + for file in files: + try: + doc = XMLPage(file, filetype=mode) + if level == 'baselines': + if not ro_id: + ro_id = 'line_implicit' + order = doc.get_sorted_lines(ro_id) + elif level == 'regions': + if not ro_id: + ro_id = 'region_implicit' + order = doc.get_sorted_regions(ro_id) + else: + raise ValueError(f'Invalid RO type {level}') + # traverse RO and substitute features. + h,w = Image.open(doc.imagename).size + sorted_lines = [] + for line in order: + line_coords = np.array(line['baseline']) / (w, h) + line_center = np.mean(line_coords, axis=0) + cl = torch.zeros(self.num_classes, dtype=torch.float) + # if class is not in class mapping default to None class (idx 0) + cl[self.class_mapping.get(line['tags']['type'], 0)] = 1 + line_data = {'type': line['tags']['type'], + 'features': torch.cat((cl, # one hot encoded line type + torch.tensor(line_center, dtype=torch.float), # line center + torch.tensor(line_coords[0, :], dtype=torch.float), # start_point coord + torch.tensor(line_coords[-1, :], dtype=torch.float), # end point coord) + )) + } + sorted_lines.append(line_data) + self.data.append(sorted_lines) + except KrakenInputException as e: + logger.warning(e) + continue + else: + raise Exception('invalid dataset mode') + + def __getitem__(self, idx): + xs = [] + ys = [] + for i in range(len(self.data[idx])): + for j in range(len(self.data[idx])): + if i == j and len(self.data[idx]) != 1: + continue + xs.append(torch.cat((self.data[idx][i]['features'], + self.data[idx][j]['features']))) + ys.append(torch.tensor(0 if i >= j else 1, dtype=torch.float)) + return {'sample': torch.stack(xs), 'target': torch.stack(ys), 'num_lines': len(self.data[idx])} + + def get_feature_dim(self): + return 2 * self.num_classes + 12 + + def __len__(self): + return len(self.data) diff --git a/kraken/lib/ro/model.py b/kraken/lib/ro/model.py index c878f4a83..8cbcabf58 100644 --- a/kraken/lib/ro/model.py +++ b/kraken/lib/ro/model.py @@ -26,13 +26,17 @@ import pytorch_lightning as pl from os import PathLike +from torch.optim import lr_scheduler from dataclasses import dataclass, field from torch.nn import Module from typing import Dict, Optional, Sequence, Union, Any, Literal, List +from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor + from kraken.lib import vgsl, default_specs, layers -from kraken.lib.dataset import ROSet +from kraken.lib.dataset import PairWiseROSet, PageWiseROSet from kraken.lib.train import _configure_optimizer_and_lr_scheduler +from kraken.lib.segmentation import _greedy_order_decoder from kraken.lib.ro.layers import MLP from torch.utils.data import DataLoader, random_split, Subset @@ -56,6 +60,10 @@ def save_model(self, filename): self.ptl_module.save_checkpoint(filename) +def spearman_footrule_distance(s, t): + return (s - t).abs().sum() / (0.5 * (len(s) ** 2 - (len(s) % 2))) + + class ROModel(pl.LightningModule): def __init__(self, hyper_params: Dict[str, Any] = None, @@ -92,16 +100,16 @@ def __init__(self, np.random.shuffle(training_data) training_data = training_data[:int(partition*len(training_data))] evaluation_data = training_data[int(partition*len(training_data)):] - train_set = ROSet(training_data, - mode=format_type, - level=level, - ro_id=reading_order) + train_set = PairWiseROSet(training_data, + mode=format_type, + level=level, + ro_id=reading_order) self.train_set = Subset(train_set, range(len(train_set))) - val_set = ROSet(evaluation_data, - mode=format_type, - class_mapping=train_set.class_mapping, - level=level, - ro_id=reading_order) + val_set = PageWiseROSet(evaluation_data, + mode=format_type, + class_mapping=train_set.class_mapping, + level=level, + ro_id=reading_order) self.val_set = Subset(val_set, range(len(val_set))) if len(self.train_set) == 0 or len(self.val_set) == 0: @@ -134,20 +142,34 @@ def forward(self, x): return self.ro_net(x) def validation_step(self, batch, batch_idx): - x, y = batch['sample'], batch['target'] - yhat = self.ro_net(x) - loss = self.criterion(yhat.squeeze(), y) - self.log('val_metric', loss) - return loss + xs, ys, num_lines = batch['sample'], batch['target'], batch['num_lines'] + yhat = self.ro_net(xs).squeeze() + order = torch.zeros((num_lines, num_lines)) + idx = 0 + for i in range(num_lines): + for j in range(num_lines): + if i != j: + order[i, j] = yhat[idx] + idx += 1 + path = _greedy_order_decoder(order) + spearman_dist = spearman_footrule_distance(torch.tensor(range(num_lines)), path) + self.log('val_spearman', spearman_dist) + loss = self.criterion(yhat, ys.squeeze()) + self.log('val_loss', loss) + return {'val_spearman': spearman_dist, 'val_loss': loss} def validation_epoch_end(self, outputs): - val_metric = np.mean([x.cpu() for x in outputs]) + val_metric = np.mean([x['val_spearman'].cpu() for x in outputs]) + val_loss = np.mean([x['val_loss'].cpu() for x in outputs]) + if val_metric < self.best_metric: logger.debug(f'Updating best metric from {self.best_metric} ({self.best_epoch}) to {val_metric} ({self.current_epoch})') self.best_epoch = self.current_epoch self.best_metric = val_metric - logger.info(f'validation run: val_metric {val_metric}') - self.log('val_metric', val_metric, on_step=False, on_epoch=True, prog_bar=True, logger=True) + logger.info(f'validation run: val_spearman {val_metric} val_loss {val_loss}') + self.log('val_spearman', val_metric, on_step=False, on_epoch=True, prog_bar=True, logger=True) + self.log('val_metric', val_metric, on_step=False, on_epoch=True, prog_bar=False, logger=True) + self.log('val_loss', val_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True) def training_step(self, batch, batch_idx): x, y = batch['sample'], batch['target'] @@ -156,12 +178,6 @@ def training_step(self, batch, batch_idx): self.log('loss', loss) return loss - def configure_optimizers(self): - return _configure_optimizer_and_lr_scheduler(self.hparams.hyper_params, - self.ro_net.parameters(), - len_train_set=len(self.train_set), - loss_tracking_mode='min') - def train_dataloader(self): return DataLoader(self.train_set, batch_size=self.hyper_params['batch_size'], @@ -170,9 +186,58 @@ def train_dataloader(self): def val_dataloader(self): return DataLoader(self.val_set, - batch_size=self.hyper_params['batch_size'], + batch_size=1, num_workers=self.num_workers, pin_memory=True) def save_checkpoint(self, filename): self.trainer.save_checkpoint(filename) + + def configure_callbacks(self): + callbacks = [] + if self.hparams.hyper_params['quit'] == 'early': + callbacks.append(EarlyStopping(monitor='val_metric', + mode='min', + patience=self.hparams.hyper_params['lag'], + stopping_threshold=0.0)) + if self.hparams.hyper_params['pl_logger']: + callbacks.append(LearningRateMonitor(logging_interval='step')) + return callbacks + + # configuration of optimizers and learning rate schedulers + # -------------------------------------------------------- + # + # All schedulers are created internally with a frequency of step to enable + # batch-wise learning rate warmup. In lr_scheduler_step() calls to the + # scheduler are then only performed at the end of the epoch. + def configure_optimizers(self): + return _configure_optimizer_and_lr_scheduler(self.hparams.hyper_params, + self.ro_net.parameters(), + len_train_set=len(self.train_set), + loss_tracking_mode='min') + + def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, + optimizer_closure, on_tpu=False, using_native_amp=False, + using_lbfgs=False): + # update params + optimizer.step(closure=optimizer_closure) + + # linear warmup between 0 and the initial learning rate `lrate` in `warmup` + # steps. + if self.hparams.hyper_params['warmup'] and self.trainer.global_step < self.hparams.hyper_params['warmup']: + lr_scale = min(1.0, float(self.trainer.global_step + 1) / self.hparams.hyper_params['warmup']) + for pg in optimizer.param_groups: + pg["lr"] = lr_scale * self.hparams.hyper_params['lrate'] + + def lr_scheduler_step(self, scheduler, optimizer_idx, metric): + if not self.hparams.hyper_params['warmup'] or self.trainer.global_step >= self.hparams.hyper_params['warmup']: + # step OneCycleLR each batch if not in warmup phase + if isinstance(scheduler, lr_scheduler.OneCycleLR): + scheduler.step() + # step every other scheduler epoch-wise + elif self.trainer.is_last_batch: + if metric is None: + scheduler.step() + else: + scheduler.step(metric) + diff --git a/kraken/lib/segmentation.py b/kraken/lib/segmentation.py index 784cb4ceb..96aeb5f26 100644 --- a/kraken/lib/segmentation.py +++ b/kraken/lib/segmentation.py @@ -16,6 +16,7 @@ Processing for baseline segmenter output """ import PIL +import torch import logging import numpy as np import shapely.geometry as geom @@ -33,7 +34,7 @@ from skimage import draw, filters from skimage.graph import MCP_Connect -from skimage.filters import apply_hysteresis_threshold, sobel +from skimage.filters import sobel from skimage.measure import approximate_polygon, subdivide_polygon, regionprops, label from skimage.morphology import skeletonize from skimage.transform import PiecewiseAffineTransform, SimilarityTransform, AffineTransform, warp @@ -50,7 +51,6 @@ logger = logging.getLogger('kraken') __all__ = ['reading_order', - 'denoising_hysteresis_thresh', 'vectorize_lines', 'calculate_polygonal_environment', 'polygonal_reading_order', @@ -131,11 +131,6 @@ def _visit(k): return L -def denoising_hysteresis_thresh(im, low, high, sigma): - im = gaussian_filter(im, sigma) - return apply_hysteresis_threshold(im, low, high) - - def moore_neighborhood(current, backtrack): operations = np.array([[-1, 0], [-1, 1], [0, 1], [1, 1], [1, 0], [1, -1], [0, -1], [-1, -1]]) @@ -879,13 +874,11 @@ def _greedy_order_decoder(P): for i in range(A.shape[0]): A[i,i] = torch.finfo(torch.float).eps best_path = [] - #--- use log(p(R\mid s',s)) to shift multiplication to sum + # use log(p(R\mid s',s)) to shift multiplication to sum lP = torch.log(A) for i in range(N): lP[i,i] = 0 for t in range(N): - #print(lP) - #print("----------------------") for i in range(N): idx = torch.argmax(lP.sum(axis=1)) if idx not in best_path: @@ -893,7 +886,7 @@ def _greedy_order_decoder(P): lP[idx,:] = lP[:,idx] lP[:,idx] = 0 break - return best_path + return torch.tensor(best_path) def scale_regions(regions: Sequence[Tuple[List[int], List[int]]], diff --git a/kraken/lib/train.py b/kraken/lib/train.py index 5f4af6e17..253075253 100644 --- a/kraken/lib/train.py +++ b/kraken/lib/train.py @@ -650,13 +650,13 @@ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure): # linear warmup between 0 and the initial learning rate `lrate` in `warmup` # steps. - if self.hparams.hyper_params.warmup and self.trainer.global_step < self.hparams.hyper_params.warmup: - lr_scale = min(1.0, float(self.trainer.global_step + 1) / self.hparams.hyper_params.warmup) + if self.hparams.hyper_params['warmup'] and self.trainer.global_step < self.hparams.hyper_params['warmup']: + lr_scale = min(1.0, float(self.trainer.global_step + 1) / self.hparams.hyper_params['warmup']) for pg in optimizer.param_groups: - pg["lr"] = lr_scale * self.hparams.hyper_params.lrate + pg["lr"] = lr_scale * self.hparams.hyper_params['lrate'] def lr_scheduler_step(self, scheduler, optimizer_idx, metric): - if not self.hparams.hyper_params.warmup or self.trainer.global_step >= self.hparams.hyper_params.warmup: + if not self.hparams.hyper_params['warmup'] or self.trainer.global_step >= self.hparams.hyper_params['warmup']: # step OneCycleLR each batch if not in warmup phase if isinstance(scheduler, lr_scheduler.OneCycleLR): scheduler.step() @@ -889,7 +889,7 @@ def on_validation_epoch_end(self): self.log('val_mean_acc', mean_accuracy, on_step=False, on_epoch=True, prog_bar=True, logger=True) self.log('val_mean_iu', mean_iu, on_step=False, on_epoch=True, prog_bar=True, logger=True) self.log('val_freq_iu', freq_iu, on_step=False, on_epoch=True, prog_bar=True, logger=True) - self.log('val_metric', mean_iu, on_step=False, on_epoch=True, prog_bar=True, logger=True) + self.log('val_metric', mean_iu, on_step=False, on_epoch=True, prog_bar=False, logger=True) self.val_px_accuracy.reset() self.val_mean_accuracy.reset() From 72caf7b404127c485fcf569421c62e8102121689 Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Mon, 6 Mar 2023 23:39:13 +0100 Subject: [PATCH 14/68] use original implementation hidden size --- kraken/lib/ro/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kraken/lib/ro/model.py b/kraken/lib/ro/model.py index 8cbcabf58..1187a242b 100644 --- a/kraken/lib/ro/model.py +++ b/kraken/lib/ro/model.py @@ -128,7 +128,7 @@ def __init__(self, self.best_metric = torch.inf logger.info(f'Creating new RO model') - self.ro_net = torch.jit.script(MLP(train_set.get_feature_dim(), 128)) + self.ro_net = torch.jit.script(MLP(train_set.get_feature_dim(), train_set.get_feature_dim() * 2)) if 'file_system' in torch.multiprocessing.get_all_sharing_strategies(): logger.debug('Setting multiprocessing tensor sharing strategy to file_system') From 583d36fe2f6eb862cbeaf6b1152857f0416e8b10 Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Wed, 8 Mar 2023 16:41:50 +0100 Subject: [PATCH 15/68] compute loss with logits --- kraken/lib/ro/layers.py | 4 +--- kraken/lib/ro/model.py | 13 +++++++------ 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/kraken/lib/ro/layers.py b/kraken/lib/ro/layers.py index c86cb0711..93c284f9a 100644 --- a/kraken/lib/ro/layers.py +++ b/kraken/lib/ro/layers.py @@ -18,10 +18,8 @@ def __init__(self, feature_size: int, hidden_size: int): self.fc1 = nn.Linear(feature_size, hidden_size) self.relu = nn.ReLU() self.fc2 = nn.Linear(hidden_size, 1) - self.sigmoid = nn.Sigmoid() def forward(self, x): x = self.fc1(x) x = self.relu(x) - x = self.fc2(x) - return self.sigmoid(x) + return self.fc2(x) diff --git a/kraken/lib/ro/model.py b/kraken/lib/ro/model.py index 1187a242b..6dec0817f 100644 --- a/kraken/lib/ro/model.py +++ b/kraken/lib/ro/model.py @@ -120,7 +120,7 @@ def __init__(self, f'{len(self.val_set)} lines') self.output = output - self.criterion = torch.nn.BCELoss() + self.criterion = torch.nn.BCEWithLogitsLoss() self.num_workers = num_workers @@ -139,11 +139,12 @@ def __init__(self, self.save_hyperparameters() def forward(self, x): - return self.ro_net(x) + return F.sigmoid(self.ro_net(x)) def validation_step(self, batch, batch_idx): xs, ys, num_lines = batch['sample'], batch['target'], batch['num_lines'] - yhat = self.ro_net(xs).squeeze() + logits = self.ro_net(xs).squeeze() + yhat = F.sigmoid(logits) order = torch.zeros((num_lines, num_lines)) idx = 0 for i in range(num_lines): @@ -154,7 +155,7 @@ def validation_step(self, batch, batch_idx): path = _greedy_order_decoder(order) spearman_dist = spearman_footrule_distance(torch.tensor(range(num_lines)), path) self.log('val_spearman', spearman_dist) - loss = self.criterion(yhat, ys.squeeze()) + loss = self.criterion(probits, ys.squeeze()) self.log('val_loss', loss) return {'val_spearman': spearman_dist, 'val_loss': loss} @@ -173,8 +174,8 @@ def validation_epoch_end(self, outputs): def training_step(self, batch, batch_idx): x, y = batch['sample'], batch['target'] - yhat = self.ro_net(x) - loss = self.criterion(yhat.squeeze(), y) + logits = self.ro_net(x) + loss = self.criterion(logits.squeeze(), y) self.log('loss', loss) return loss From 845335c5ee630d70ccf4c842d5fdfb22c5241313 Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Wed, 8 Mar 2023 16:42:54 +0100 Subject: [PATCH 16/68] logits not probits --- kraken/lib/ro/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kraken/lib/ro/model.py b/kraken/lib/ro/model.py index 6dec0817f..3dff4c828 100644 --- a/kraken/lib/ro/model.py +++ b/kraken/lib/ro/model.py @@ -155,7 +155,7 @@ def validation_step(self, batch, batch_idx): path = _greedy_order_decoder(order) spearman_dist = spearman_footrule_distance(torch.tensor(range(num_lines)), path) self.log('val_spearman', spearman_dist) - loss = self.criterion(probits, ys.squeeze()) + loss = self.criterion(logits, ys.squeeze()) self.log('val_loss', loss) return {'val_spearman': spearman_dist, 'val_loss': loss} From 7af9a721ffe4ad03b1f286648bb58936f8a1c324 Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Thu, 9 Mar 2023 12:32:02 +0100 Subject: [PATCH 17/68] s/h,w/w,h/g Fix normalization factors for RO datasets --- kraken/lib/dataset/ro.py | 4 ++-- kraken/lib/segmentation.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/kraken/lib/dataset/ro.py b/kraken/lib/dataset/ro.py index 8bf7d6b72..b7c6d681d 100644 --- a/kraken/lib/dataset/ro.py +++ b/kraken/lib/dataset/ro.py @@ -105,7 +105,7 @@ def __init__(self, files: Sequence[Union[PathLike, str]] = None, else: raise ValueError(f'Invalid RO type {level}') # traverse RO and substitute features. - h,w = Image.open(doc.imagename).size + w, h = Image.open(doc.imagename).size sorted_lines = [] for line in order: line_coords = np.array(line['baseline']) / (w, h) @@ -209,7 +209,7 @@ def __init__(self, files: Sequence[Union[PathLike, str]] = None, else: raise ValueError(f'Invalid RO type {level}') # traverse RO and substitute features. - h,w = Image.open(doc.imagename).size + w, h = Image.open(doc.imagename).size sorted_lines = [] for line in order: line_coords = np.array(line['baseline']) / (w, h) diff --git a/kraken/lib/segmentation.py b/kraken/lib/segmentation.py index 96aeb5f26..c5c6e5e4b 100644 --- a/kraken/lib/segmentation.py +++ b/kraken/lib/segmentation.py @@ -821,7 +821,7 @@ def neural_reading_order(lines: Sequence[Tuple[List[Tuple[int, int]], List[Tuple Args: lines: List of tuples containing the baseline and its polygonization. - model: torch Module for + model: torch Module for Returns: A reordered input. From c51e28d56b43364826b52e6181e3198d1e00b3a7 Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Thu, 23 Mar 2023 11:28:25 +0100 Subject: [PATCH 18/68] some small new parser fixes --- kraken/lib/xml.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/kraken/lib/xml.py b/kraken/lib/xml.py index ebfd5c877..16f253e57 100644 --- a/kraken/lib/xml.py +++ b/kraken/lib/xml.py @@ -480,7 +480,7 @@ def _parse_pointstype(coords: str) -> Sequence[Tuple[float, float]]: class XMLPage(object): - type: Literal['baselines', 'bbox'] == 'baselines' + type: Literal['baselines', 'bbox'] = 'baselines' base_dir: Optional[Literal['L', 'R']] = None imagename: PathLike = None _orders: Dict[str, Dict[str, Any]] = None @@ -945,7 +945,7 @@ def get_lines_by_tag(self, key, value): return {k: v for k, v in self._lines.items() if v['tags'].get(key) == value} def get_lines_by_split(self, split: Literal['train', 'validation', 'test']): - return {k: v for k, v in self._lines.items() if v['tags'].get(key) == split} + return {k: v for k, v in self._lines.items() if v['tags'].get('split') == split} @property def tags(self): From 898a1edfb44f528078a599f4f5a5895b2c2dfaac Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Thu, 23 Mar 2023 11:30:05 +0100 Subject: [PATCH 19/68] more small fixes in lib/segmentation.py --- kraken/lib/segmentation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/kraken/lib/segmentation.py b/kraken/lib/segmentation.py index c5c6e5e4b..3f3bee7e7 100644 --- a/kraken/lib/segmentation.py +++ b/kraken/lib/segmentation.py @@ -831,7 +831,7 @@ def neural_reading_order(lines: Sequence[Tuple[List[Tuple[int, int]], List[Tuple features = [] for i in lines: for j in lines: - if i == j and len(children) != 1: + if i == j and len(lines) != 1: continue line_coords_i = np.array(i) / (w, h) line_center_i = np.mean(line_coords_i, axis=0) @@ -855,7 +855,7 @@ def neural_reading_order(lines: Sequence[Tuple[List[Tuple[int, int]], List[Tuple idx += 1 # decode order relation matrix path = _greedy_order_decoder(order) - return ordered_lines + return path def _greedy_order_decoder(P): From 7be2aa02b1f1ef4333ababb1b4110ddc10591ed5 Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Mon, 3 Apr 2023 13:20:42 +0200 Subject: [PATCH 20/68] decoder work --- kraken/lib/segmentation.py | 30 +++++++++++++++++++----------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/kraken/lib/segmentation.py b/kraken/lib/segmentation.py index 3f3bee7e7..5e1f0b96b 100644 --- a/kraken/lib/segmentation.py +++ b/kraken/lib/segmentation.py @@ -51,6 +51,7 @@ logger = logging.getLogger('kraken') __all__ = ['reading_order', + 'neural_reading_order', 'vectorize_lines', 'calculate_polygonal_environment', 'polygonal_reading_order', @@ -833,9 +834,14 @@ def neural_reading_order(lines: Sequence[Tuple[List[Tuple[int, int]], List[Tuple for j in lines: if i == j and len(lines) != 1: continue - line_coords_i = np.array(i) / (w, h) + num_classes = len(model.class_mapping) + 1 + cl_i = torch.zeros(num_classes, dtype=torch.float) + cl_j = torch.zeros(num_classes, dtype=torch.float) + cl_i[model.class_mapping.get(i['tags']['type'], 0)] = 1 + cl_j[model.class_mapping.get(j['tags']['type'], 0)] = 1 + line_coords_i = np.array(i['baseline']) / (w, h) line_center_i = np.mean(line_coords_i, axis=0) - line_coords_j = np.array(j) / (w, h) + line_coords_j = np.array(j['baseline']) / (w, h) line_center_j = np.mean(line_coords_j, axis=0) features.append(torch.cat((cl_i, torch.tensor(line_center_i, dtype=torch.float), # lin @@ -853,7 +859,7 @@ def neural_reading_order(lines: Sequence[Tuple[List[Tuple[int, int]], List[Tuple for j in enumerate(lines): order[i, j] = output[idx] idx += 1 - # decode order relation matrix + # decode order relation matrix path = _greedy_order_decoder(order) return path @@ -863,28 +869,30 @@ def _greedy_order_decoder(P): A greedy decoder of order-relation matrix. For each position in the reading order we select the most probable one, then move to the next position. Most probable for position: - z^{\star}_t = \argmax_{(s,\nu) \ni z^{\star}} - \prod_{(s',\nu') \in z^\star}{\tilde{P}(Y=1\mid s',s)} - \times \prod_{\substack{(s'',\nu'') \ni z^\star\\ - s'' \ne s}}{\tilde{P}(r=0\mid s'',s)}, 1\le t \le n + + .. math:: + z^{\\star}_t = \\argmax_{(s,\\nu) \\ni z^{\\star}} + \\prod_{(s',\\nu') \\in z^\\star}{\\tilde{P}(Y=1\\mid s',s)} + \\times \\prod_{\\substack{(s'',\\nu'') \\ni z^\\star\\ + s'' \\ne s}}{\\tilde{P}(r=0\\mid s'',s)}, 1\\le t \\le n """ A = P + torch.finfo(torch.float).eps N = P.shape[0] A = (A + (1-A).T)/2 for i in range(A.shape[0]): - A[i,i] = torch.finfo(torch.float).eps + A[i, i] = torch.finfo(torch.float).eps best_path = [] # use log(p(R\mid s',s)) to shift multiplication to sum lP = torch.log(A) for i in range(N): - lP[i,i] = 0 + lP[i, i] = 0 for t in range(N): for i in range(N): idx = torch.argmax(lP.sum(axis=1)) if idx not in best_path: best_path.append(idx) - lP[idx,:] = lP[:,idx] - lP[:,idx] = 0 + lP[idx, :] = lP[:, idx] + lP[:, idx] = 0 break return torch.tensor(best_path) From 1c283a99fc57a06a63c1598fb216ed1abebe2c00 Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Mon, 3 Apr 2023 13:20:59 +0200 Subject: [PATCH 21/68] add cls mapping to model params --- kraken/lib/ro/model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/kraken/lib/ro/model.py b/kraken/lib/ro/model.py index 3dff4c828..6bfc4f07f 100644 --- a/kraken/lib/ro/model.py +++ b/kraken/lib/ro/model.py @@ -105,6 +105,7 @@ def __init__(self, level=level, ro_id=reading_order) self.train_set = Subset(train_set, range(len(train_set))) + self.class_mapping = train_set.class_mapping val_set = PageWiseROSet(evaluation_data, mode=format_type, class_mapping=train_set.class_mapping, From 15820b4aab86d8aeec32b4af89dd8d0b0795f9f4 Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Tue, 4 Apr 2023 13:06:29 +0200 Subject: [PATCH 22/68] coreml serialization of RO model --- kraken/lib/ro/layers.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/kraken/lib/ro/layers.py b/kraken/lib/ro/layers.py index 93c284f9a..fb63aad46 100644 --- a/kraken/lib/ro/layers.py +++ b/kraken/lib/ro/layers.py @@ -23,3 +23,36 @@ def forward(self, x): x = self.fc1(x) x = self.relu(x) return self.fc2(x) + + def deserialize(self, name: str, spec) -> None: + """ + Sets the weights of an initialized module from a CoreML protobuf spec. + """ + # extract 1st linear projection parameters + lin = [x for x in spec.neuralNetwork.layers if x.name == '{}_mlp_lin_1'.format(name)][0].innerProduct + weights = torch.Tensor(lin.weights.floatValue).resize_as_(self.fc1.weight.data) + bias = torch.Tensor(lin.bias.floatValue) + self.fc1.weight = torch.nn.Parameter(weights) + self.fc1.bias = torch.nn.Parameter(bias) + # extract 2nd linear projection parameters + lin = [x for x in spec.neuralNetwork.layers if x.name == '{}_mlp_lin_2'.format(name)][0].innerProduct + weights = torch.Tensor(lin.weights.floatValue).resize_as_(self.fc2.weight.data) + bias = torch.Tensor(lin.bias.floatValue) + self.fc2.weight = torch.nn.Parameter(weights) + self.fc2.bias = torch.nn.Parameter(bias) + + def serialize(self, name: str, input: str, builder): + """ + Serializes the module using a NeuralNetworkBuilder. + """ + builder.add_inner_product(f'{name}_mlp_lin_1', self.fc1.weight.data.numpy(), + self.fc1.bias.data.numpy(), + self.feature_size, self.hidden_size, + has_bias=True, input_name=input, output_name=f'{name}_mlp_lin_1') + builder.add_activation(f'{name}_mlp_lin_1_relu', 'RELU', f'{name}_mlp_lin_1', f'{name}_mlp_lin_1_relu') + builder.add_inner_product(f'{name}_mlp_lin_1', self.fc2.weight.data.numpy(), + self.fc2.bias.data.numpy(), + self.hidden_size, 1, + has_bias=True, input_name=f'{name}_mlp_lin_1_relu', output_name=f'{name}_mlp_lin_2') + return name + From b04492831ade4e4139edd16775881878c2e332b7 Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Tue, 4 Apr 2023 13:28:48 +0200 Subject: [PATCH 23/68] some linter cleanup --- kraken/lib/xml.py | 49 +++++++++++++++++++++++------------------------ 1 file changed, 24 insertions(+), 25 deletions(-) diff --git a/kraken/lib/xml.py b/kraken/lib/xml.py index 16f253e57..7b78082b3 100644 --- a/kraken/lib/xml.py +++ b/kraken/lib/xml.py @@ -26,9 +26,9 @@ from PIL import Image from typing import Union, Dict, Any, Sequence, Tuple, Literal, Optional, List -from os import PathLike from collections import defaultdict from kraken.lib.segmentation import calculate_polygonal_environment +from kraken.lib.exceptions import KrakenInputException logger = logging.getLogger(__name__) @@ -478,6 +478,7 @@ def _parse_pointstype(coords: str) -> Sequence[Tuple[float, float]]: data['tags'] = False return data + class XMLPage(object): type: Literal['baselines', 'bbox'] = 'baselines' @@ -545,10 +546,6 @@ def _parse_alto(self): y_min = int(float(ps.get('VPOS'))) width = int(float(ps.get('WIDTH'))) height = int(float(ps.get('HEIGHT'))) - page_boundary = [(x_min, y_min), - (x_min, y_min + height), - (x_min + width, y_min + height), - (x_min + width, y_min)] # parse tagrefs cls_map = {} @@ -643,7 +640,6 @@ def _parse_alto(self): self._regions = region_data - if len(self._tag_set) > 1: self.has_tags = True else: @@ -839,26 +835,28 @@ def _parse_page(self): # parse explicit reading orders if they exist ro_el = doc.find('.//{*}ReadingOrder') if ro_el is not None: - reading_orders = ro_el.getchildren() - # UnorderedGroup at top-level => treated as multiple reading orders - if len(reading_orders) == 1 and reading_orders[0].tag.endswith('UnorderedGroup'): + reading_orders = ro_el.getchildren() + # UnorderedGroup at top-level => treated as multiple reading orders + if len(reading_orders) == 1 and reading_orders[0].tag.endswith('UnorderedGroup'): reading_orders = reading_orders.getchildren() - def _parse_group(el): - _ro = [] - if el.tag.endswith('UnorderedGroup'): - _ro = [_parse_group(x) for x in el.iterchildren()] - is_total = False - elif el.tag.endswith('OrderedGroup'): - _ro.extend(_parse_group(x) for x in el.iterchildren()) - else: - return el.get('regionRef') - return _ro - - for ro in reading_orders: - is_total = True - self._orders[ro.get('id')] = {'order': _parse_group(ro), - 'is_total': is_total, - 'description': ro.get('caption') if ro.get('caption') else ''} + + def _parse_group(el): + + _ro = [] + if el.tag.endswith('UnorderedGroup'): + _ro = [_parse_group(x) for x in el.iterchildren()] + is_total = False + elif el.tag.endswith('OrderedGroup'): + _ro.extend(_parse_group(x) for x in el.iterchildren()) + else: + return el.get('regionRef') + return _ro + + for ro in reading_orders: + is_total = True + self._orders[ro.get('id')] = {'order': _parse_group(ro), + 'is_total': is_total, + 'description': ro.get('caption') if ro.get('caption') else ''} if len(self._tag_set) > 1: self.has_tags = True @@ -885,6 +883,7 @@ def get_sorted_lines(self, ro='line_implicit'): """ if ro not in self.reading_orders: raise ValueError(f'Unknown reading order {ro}') + def _traverse_ro(el): _ro = [] if isinstance(el, list): From 1326d09394ec7b20e795526aea6400e4f976a3c0 Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Tue, 4 Apr 2023 13:29:17 +0200 Subject: [PATCH 24/68] remove unused failed_sample_threshold --- kraken/ketos/pretrain.py | 1 - 1 file changed, 1 deletion(-) diff --git a/kraken/ketos/pretrain.py b/kraken/ketos/pretrain.py index 1c4bc3d0e..ea5a14e19 100644 --- a/kraken/ketos/pretrain.py +++ b/kraken/ketos/pretrain.py @@ -279,7 +279,6 @@ def pretrain(ctx, batch_size, pad, output, spec, load, freq, quit, epochs, min_epochs=hyper_params['min_epochs'], enable_progress_bar=True if not ctx.meta['verbose'] else False, deterministic=ctx.meta['deterministic'], - failed_sample_threshold=failed_sample_threshold, **val_check_interval) trainer.fit(model, datamodule=data_module) From 98faf22582e65322557b99b6a66cc32b722c22ab Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Tue, 4 Apr 2023 13:30:16 +0200 Subject: [PATCH 25/68] more import fixes --- kraken/lib/dataset/ro.py | 2 +- kraken/lib/ro/layers.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/kraken/lib/dataset/ro.py b/kraken/lib/dataset/ro.py index b7c6d681d..31dee62ec 100644 --- a/kraken/lib/dataset/ro.py +++ b/kraken/lib/dataset/ro.py @@ -36,7 +36,7 @@ from kraken.lib.exceptions import KrakenInputException -__all__ = ['BaselineSet'] +__all__ = ['PairWiseROSet', 'PageWiseROSet'] import logging diff --git a/kraken/lib/ro/layers.py b/kraken/lib/ro/layers.py index fb63aad46..7c15c0868 100644 --- a/kraken/lib/ro/layers.py +++ b/kraken/lib/ro/layers.py @@ -1,6 +1,7 @@ """ Layers for VGSL models """ +import torch from torch import nn # all tensors are ordered NCHW, the "feature" dimension is C, so the output of From f36310bc1f24e1a7c2be1b93dcc9662aa966f34a Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Tue, 4 Apr 2023 14:26:48 +0200 Subject: [PATCH 26/68] syntax fix xml tests --- tests/test_xml.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_xml.py b/tests/test_xml.py index cee436ad8..611959d2e 100644 --- a/tests/test_xml.py +++ b/tests/test_xml.py @@ -27,7 +27,6 @@ def test_page_parsing(self): doc = xml.XMLPage(self.page_doc, filetype='page') self.assertEqual(len(doc.baselines), 97) self.assertEqual(len([item for x in doc.regions.values() for item in x]), 4) - self.assertEqual( def test_alto_parsing(self): """ From a7f1140c5af14084a3c363a21838da410cff856e Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Tue, 4 Apr 2023 14:39:17 +0200 Subject: [PATCH 27/68] xml parsing tests files --- tests/resources/bsb00084914_00007.xml | 1074 +++++++++++++++++++++++++ tests/resources/cPAS-2000.xml | 410 ++++++++++ 2 files changed, 1484 insertions(+) create mode 100644 tests/resources/bsb00084914_00007.xml create mode 100644 tests/resources/cPAS-2000.xml diff --git a/tests/resources/bsb00084914_00007.xml b/tests/resources/bsb00084914_00007.xml new file mode 100644 index 000000000..311751ad1 --- /dev/null +++ b/tests/resources/bsb00084914_00007.xml @@ -0,0 +1,1074 @@ + + + + pixel + + bsb00084914_00007.jpg + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tests/resources/cPAS-2000.xml b/tests/resources/cPAS-2000.xml new file mode 100644 index 000000000..d9f844121 --- /dev/null +++ b/tests/resources/cPAS-2000.xml @@ -0,0 +1,410 @@ + + + + TRP + 2018-12-24T11:28:19+07:00 + 2019-02-05T09:16:48Z + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + From 778c20b0b7d7919aef5f1767169ab7d9a2dd7eaf Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Sun, 16 Apr 2023 13:17:33 +0200 Subject: [PATCH 28/68] lightning 2.0 changes to RO code --- kraken/lib/progress.py | 2 +- kraken/lib/ro/model.py | 21 ++++++++++++--------- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/kraken/lib/progress.py b/kraken/lib/progress.py index 0284fa865..344864d01 100644 --- a/kraken/lib/progress.py +++ b/kraken/lib/progress.py @@ -128,7 +128,7 @@ def _init_progress(self, trainer): def _get_train_description(self, current_epoch: int) -> str: return f"stage {current_epoch}/" \ - f"{self.trainer.max_epochs if self.trainer.model.hparams['quit'] == 'fixed' else '∞'}" + f"{self.trainer.max_epochs if self.trainer.model.hparams.hyper_params['quit'] == 'fixed' else '∞'}" @dataclass class RichProgressBarTheme: diff --git a/kraken/lib/ro/model.py b/kraken/lib/ro/model.py index 6bfc4f07f..b066bf565 100644 --- a/kraken/lib/ro/model.py +++ b/kraken/lib/ro/model.py @@ -137,6 +137,9 @@ def __init__(self, self.nn = DummyVGSLModel(ptl_module=self) + self.val_losses = [] + self.val_spearman = [] + self.save_hyperparameters() def forward(self, x): @@ -157,12 +160,14 @@ def validation_step(self, batch, batch_idx): spearman_dist = spearman_footrule_distance(torch.tensor(range(num_lines)), path) self.log('val_spearman', spearman_dist) loss = self.criterion(logits, ys.squeeze()) - self.log('val_loss', loss) - return {'val_spearman': spearman_dist, 'val_loss': loss} + self.val_losses.append(loss.cpu()) + self.val_spearman.append(spearman_dist.cpu()) - def validation_epoch_end(self, outputs): - val_metric = np.mean([x['val_spearman'].cpu() for x in outputs]) - val_loss = np.mean([x['val_loss'].cpu() for x in outputs]) + def on_validation_epoch_end(self): + val_metric = np.mean(self.val_spearman) + val_loss = np.mean(self.val_losses) + self.val_spearman.clear() + self.val_losses.clear() if val_metric < self.best_metric: logger.debug(f'Updating best metric from {self.best_metric} ({self.best_epoch}) to {val_metric} ({self.current_epoch})') @@ -218,9 +223,7 @@ def configure_optimizers(self): len_train_set=len(self.train_set), loss_tracking_mode='min') - def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, - optimizer_closure, on_tpu=False, using_native_amp=False, - using_lbfgs=False): + def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure): # update params optimizer.step(closure=optimizer_closure) @@ -231,7 +234,7 @@ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, for pg in optimizer.param_groups: pg["lr"] = lr_scale * self.hparams.hyper_params['lrate'] - def lr_scheduler_step(self, scheduler, optimizer_idx, metric): + def lr_scheduler_step(self, scheduler, metric): if not self.hparams.hyper_params['warmup'] or self.trainer.global_step >= self.hparams.hyper_params['warmup']: # step OneCycleLR each batch if not in warmup phase if isinstance(scheduler, lr_scheduler.OneCycleLR): From c733b85b333198213b354c494ab01b889c7ddc98 Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Sun, 16 Apr 2023 15:42:16 +0200 Subject: [PATCH 29/68] Code for adding RO models to segmentation models --- kraken/ketos/ro.py | 32 ++++++++++++++++++++++++++++++++ kraken/lib/ro/layers.py | 29 ++++++++++++++++++++++------- kraken/lib/ro/model.py | 2 +- kraken/lib/vgsl.py | 21 ++++++++++++++++++++- 4 files changed, 75 insertions(+), 9 deletions(-) diff --git a/kraken/ketos/ro.py b/kraken/ketos/ro.py index 27e4851be..2854ea7fd 100644 --- a/kraken/ketos/ro.py +++ b/kraken/ketos/ro.py @@ -258,3 +258,35 @@ def rotrain(ctx, batch_size, output, load, freq, quit, epochs, min_epochs, lag, logger.info('Moving best model {0}_{1}.mlmodel ({2}) to {0}_best.mlmodel'.format( output, model.best_epoch, model.best_metric)) shutil.copy(f'{output}_{model.best_epoch}.mlmodel', f'{output}_best.mlmodel') + + +@click.command('roadd') +@click.pass_context +@click.option('-o', '--output', show_default=True, type=click.Path(), default='combined_seg.mlmodel', help='Combined output model file') +@click.option('-r', '--ro-model', show_default=True, type=click.Path(exists=True, readable=True), help='Reading order model to load into segmentation model') +@click.option('-i', '--seg-model', show_default=True, type=click.Path(exists=True, readable=True), help='Segmentation model to load') +def rotrain(ctx, output, ro_model, seg_model): + """ + Combines a reading order model with a segmentation model. + """ + from kraken.lib import vgsl + from kraken.lib.ro import ROModel + from kraken.lib.train import KrakenTrainer + + message(f'Adding {ro_model} reading order model to {seg_model}.') + ro_net = ROModel.load_from_checkpoint(ro_model) + message('Line classes known to RO model:') + for k, v in ro_net.class_mapping.items(): + message(f' {k}\t{v}') + seg_net = vgsl.TorchVGSLModel.load_model(seg_model) + if seg_net.model_type != 'segmentation': + raise click.UsageError(f'Model {seg_model} is invalid {seg_net.model_type} model (expected `segmentation`).') + message('Line classes known to segmentation model:') + for k, v in seg_net.user_metadata['class_mapping']['baselines'].items(): + message(f' {k}\t{v}') + if ro_net.class_mapping.keys() != seg_net.user_metadata['class_mapping']['baselines'].keys(): + raise click.UsageError(f'Model {seg_model} and {ro_model} class mappings mismatch.') + + seg_net.aux_layers = {'ro_model': ro_net.ro_net} + message(f'Saving combined model to {output}') + seg_net.save_model(output) diff --git a/kraken/lib/ro/layers.py b/kraken/lib/ro/layers.py index 7c15c0868..87fffc078 100644 --- a/kraken/lib/ro/layers.py +++ b/kraken/lib/ro/layers.py @@ -4,6 +4,8 @@ import torch from torch import nn +from typing import Tuple + # all tensors are ordered NCHW, the "feature" dimension is C, so the output of # an LSTM will be put into C same as the filters of a CNN. @@ -19,24 +21,38 @@ def __init__(self, feature_size: int, hidden_size: int): self.fc1 = nn.Linear(feature_size, hidden_size) self.relu = nn.ReLU() self.fc2 = nn.Linear(hidden_size, 1) + self.feature_size = feature_size + self.hidden_size = hidden_size def forward(self, x): x = self.fc1(x) x = self.relu(x) return self.fc2(x) + def get_shape(self, input: Tuple[int, int, int, int]) -> Tuple[int, int, int, int]: + """ + Calculates the output shape from input 4D tuple NCHW. + """ + return input + + def get_spec(self, name) -> "VGSLBlock": + """ + Generates a VGSL spec block from the layer instance. + """ + return f'[1,0,0,1 RO{{{name}}}{self.feature_size},{self.hidden_size}]' + def deserialize(self, name: str, spec) -> None: """ Sets the weights of an initialized module from a CoreML protobuf spec. """ # extract 1st linear projection parameters - lin = [x for x in spec.neuralNetwork.layers if x.name == '{}_mlp_lin_1'.format(name)][0].innerProduct + lin = [x for x in spec.neuralNetwork.layers if x.name == '{}_mlp_lin_0'.format(name)][0].innerProduct weights = torch.Tensor(lin.weights.floatValue).resize_as_(self.fc1.weight.data) bias = torch.Tensor(lin.bias.floatValue) self.fc1.weight = torch.nn.Parameter(weights) self.fc1.bias = torch.nn.Parameter(bias) # extract 2nd linear projection parameters - lin = [x for x in spec.neuralNetwork.layers if x.name == '{}_mlp_lin_2'.format(name)][0].innerProduct + lin = [x for x in spec.neuralNetwork.layers if x.name == '{}_mlp_lin_1'.format(name)][0].innerProduct weights = torch.Tensor(lin.weights.floatValue).resize_as_(self.fc2.weight.data) bias = torch.Tensor(lin.bias.floatValue) self.fc2.weight = torch.nn.Parameter(weights) @@ -46,14 +62,13 @@ def serialize(self, name: str, input: str, builder): """ Serializes the module using a NeuralNetworkBuilder. """ - builder.add_inner_product(f'{name}_mlp_lin_1', self.fc1.weight.data.numpy(), + builder.add_inner_product(f'{name}_mlp_lin_0', self.fc1.weight.data.numpy(), self.fc1.bias.data.numpy(), self.feature_size, self.hidden_size, - has_bias=True, input_name=input, output_name=f'{name}_mlp_lin_1') - builder.add_activation(f'{name}_mlp_lin_1_relu', 'RELU', f'{name}_mlp_lin_1', f'{name}_mlp_lin_1_relu') + has_bias=True, input_name=input, output_name=f'{name}_mlp_lin_0') + builder.add_activation(f'{name}_mlp_lin_0_relu', 'RELU', f'{name}_mlp_lin_0', f'{name}_mlp_lin_0_relu') builder.add_inner_product(f'{name}_mlp_lin_1', self.fc2.weight.data.numpy(), self.fc2.bias.data.numpy(), self.hidden_size, 1, - has_bias=True, input_name=f'{name}_mlp_lin_1_relu', output_name=f'{name}_mlp_lin_2') + has_bias=True, input_name=f'{name}_mlp_lin_0_relu', output_name=f'{name}_mlp_lin_1') return name - diff --git a/kraken/lib/ro/model.py b/kraken/lib/ro/model.py index b066bf565..4fa957144 100644 --- a/kraken/lib/ro/model.py +++ b/kraken/lib/ro/model.py @@ -129,7 +129,7 @@ def __init__(self, self.best_metric = torch.inf logger.info(f'Creating new RO model') - self.ro_net = torch.jit.script(MLP(train_set.get_feature_dim(), train_set.get_feature_dim() * 2)) + self.ro_net = MLP(train_set.get_feature_dim(), train_set.get_feature_dim() * 2) if 'file_system' in torch.multiprocessing.get_all_sharing_strategies(): logger.debug('Setting multiprocessing tensor sharing strategy to file_system') diff --git a/kraken/lib/vgsl.py b/kraken/lib/vgsl.py index 9f745ceb4..ad762da77 100644 --- a/kraken/lib/vgsl.py +++ b/kraken/lib/vgsl.py @@ -137,7 +137,7 @@ def __init__(self, spec: str) -> None: self.build_dropout, self.build_maxpool, self.build_conv, self.build_output, self.build_reshape, self.build_wav2vec2, self.build_groupnorm, self.build_series, - self.build_parallel] + self.build_parallel, self.build_ro] self.codec = None # type: Optional[PytorchCodec] self.criterion = None # type: Any self.nn = layers.MultiParamSequential() @@ -577,6 +577,25 @@ def build_wav2vec2(self, f'{mask_prob}, negative samples {num_negatives}') return fn.get_shape(input), [VGSLBlock(blocks[idx], m.group('type'), m.group('name'), self.idx)], fn + def build_ro(self, + input: Tuple[int, int, int, int], + blocks: List[str], + idx: int) -> Union[Tuple[None, None, None], Tuple[Tuple[int, int, int, int], str, Callable]]: + """ + Builds a RO determination layer. + """ + pattern = re.compile(r'(?PRO)(?P{\w+})(?P\d+),(?P\d+)') + m = pattern.match(blocks[idx]) + if not m: + return None, None, None + feature_size = int(m.group('feature_size')) + hidden_size = int(m.group('hidden_size')) + from kraken.lib import ro + fn = ro.layers.MLP(feature_size, hidden_size) + self.idx += 1 + logger.debug(f'{self.idx}\t\tro\tfeatures {feature_size}, hidden_size {hidden_size}') + return fn.get_shape(input), [VGSLBlock(blocks[idx], m.group('type'), m.group('name'), self.idx)], fn + def build_conv(self, input: Tuple[int, int, int, int], blocks: List[str], From d4947afa40c0dfab016919023371a83c9bcd3191 Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Sun, 16 Apr 2023 18:41:09 +0200 Subject: [PATCH 30/68] Working inference --- kraken/blla.py | 29 +++++++++++++++++++++++++---- kraken/ketos/__init__.py | 3 ++- kraken/ketos/ro.py | 3 ++- kraken/lib/ro/layers.py | 1 + kraken/lib/segmentation.py | 33 +++++++++++++++++++++------------ 5 files changed, 51 insertions(+), 18 deletions(-) diff --git a/kraken/blla.py b/kraken/blla.py index 84d7fafca..19b3fa0ba 100644 --- a/kraken/blla.py +++ b/kraken/blla.py @@ -29,6 +29,7 @@ import torch.nn.functional as F import torchvision.transforms as tf +from functools import partial from typing import Optional, Dict, Callable, Union, List, Any, Tuple from scipy.ndimage import gaussian_filter @@ -38,6 +39,7 @@ from kraken.lib.util import is_bitonal, get_im_str from kraken.lib.exceptions import KrakenInputException, KrakenInvalidModelException from kraken.lib.segmentation import (polygonal_reading_order, + neural_reading_order, vectorize_lines, vectorize_regions, scale_polygonal_lines, calculate_polygonal_environment, @@ -73,8 +75,6 @@ def compute_segmentation_map(im: PIL.Image.Image, Raises: KrakenInputException: When given an invalid mask. """ - im_str = get_im_str(im) - logger.info(f'Segmenting {im_str}') if model.input[1] == 1 and model.one_channel_mode == '1' and not is_bitonal(im): logger.warning('Running binary model on non-binary input image ' @@ -250,7 +250,7 @@ def vec_lines(heatmap: torch.Tensor, def segment(im: PIL.Image.Image, text_direction: str = 'horizontal-lr', mask: Optional[np.ndarray] = None, - reading_order_fn: Callable = polygonal_reading_order, + reading_order_fn: Optional[Callable] = None, model: Union[List[vgsl.TorchVGSLModel], vgsl.TorchVGSLModel] = None, device: str = 'cpu', raise_on_error: bool = False, @@ -271,7 +271,15 @@ def segment(im: PIL.Image.Image, detection. reading_order_fn: Function to determine the reading order. Has to accept a list of tuples (baselines, polygon) and a - text direction (`lr` or `rl`). + text direction (`lr` or `rl`). If None is given it + defaults to either + :func:`kraken.lib.segmentation.polygonal_reading_order` + or + :func:`kraken.lib.segmentation.neural_reading_order` + depending on the presence of a neural reading order + net in the segmentation model. If multiple + segmentation models are given and more than one + contains an RO net the first one will be used. model: One or more TorchVGSLModel containing a segmentation model. If none is given a default model will be loaded. device: The target device to run the neural network on. @@ -313,6 +321,18 @@ def segment(im: PIL.Image.Image, if isinstance(model, vgsl.TorchVGSLModel): model = [model] + # determine which reading order function to use + if not reading_order_fn: + reading_order_fn = polygonal_reading_order + for x in model: + if 'ro_model' in x.aux_layers: + logger.info(f'Using reading order model found in segmentation model {x}.') + reading_order_fn = partial(neural_reading_order, + model=x.aux_layers['ro_model'], + im_size=im.size, + class_mapping=x.user_metadata['ro_class_mapping']) + break + for nn in model: if nn.model_type != 'segmentation': raise KrakenInvalidModelException(f'Invalid model type {nn.model_type} for {nn}') @@ -340,6 +360,7 @@ def segment(im: PIL.Image.Image, # convert back to net scale suppl_obj = scale_regions(suppl_obj, 1/rets['scale']) line_regs = scale_regions(line_regs, 1/rets['scale']) + lines = vec_lines(**rets, regions=line_regs, reading_order_fn=reading_order_fn, diff --git a/kraken/ketos/__init__.py b/kraken/ketos/__init__.py index 4b7087dc4..a8ddfe9a2 100644 --- a/kraken/ketos/__init__.py +++ b/kraken/ketos/__init__.py @@ -34,7 +34,7 @@ from .repo import publish from .segmentation import segtrain, segtest from .transcription import extract, transcription -from .ro import rotrain +from .ro import rotrain, roadd APP_NAME = 'kraken' @@ -78,6 +78,7 @@ def cli(ctx, verbose, seed, deterministic): cli.add_command(segtest) cli.add_command(publish) cli.add_command(rotrain) +cli.add_command(roadd) # deprecated commands cli.add_command(line_generator) diff --git a/kraken/ketos/ro.py b/kraken/ketos/ro.py index 2854ea7fd..1dcc0856f 100644 --- a/kraken/ketos/ro.py +++ b/kraken/ketos/ro.py @@ -265,7 +265,7 @@ def rotrain(ctx, batch_size, output, load, freq, quit, epochs, min_epochs, lag, @click.option('-o', '--output', show_default=True, type=click.Path(), default='combined_seg.mlmodel', help='Combined output model file') @click.option('-r', '--ro-model', show_default=True, type=click.Path(exists=True, readable=True), help='Reading order model to load into segmentation model') @click.option('-i', '--seg-model', show_default=True, type=click.Path(exists=True, readable=True), help='Segmentation model to load') -def rotrain(ctx, output, ro_model, seg_model): +def roadd(ctx, output, ro_model, seg_model): """ Combines a reading order model with a segmentation model. """ @@ -288,5 +288,6 @@ def rotrain(ctx, output, ro_model, seg_model): raise click.UsageError(f'Model {seg_model} and {ro_model} class mappings mismatch.') seg_net.aux_layers = {'ro_model': ro_net.ro_net} + seg_net.user_metadata['ro_class_mapping'] = ro_net.class_mapping message(f'Saving combined model to {output}') seg_net.save_model(output) diff --git a/kraken/lib/ro/layers.py b/kraken/lib/ro/layers.py index 87fffc078..a18f3de1e 100644 --- a/kraken/lib/ro/layers.py +++ b/kraken/lib/ro/layers.py @@ -23,6 +23,7 @@ def __init__(self, feature_size: int, hidden_size: int): self.fc2 = nn.Linear(hidden_size, 1) self.feature_size = feature_size self.hidden_size = hidden_size + self.class_mapping = None def forward(self, x): x = self.fc1(x) diff --git a/kraken/lib/segmentation.py b/kraken/lib/segmentation.py index 5e1f0b96b..104dd00b9 100644 --- a/kraken/lib/segmentation.py +++ b/kraken/lib/segmentation.py @@ -20,6 +20,7 @@ import logging import numpy as np import shapely.geometry as geom +import torch.nn.functional as F from collections import defaultdict @@ -814,8 +815,11 @@ def is_in_region(line, region) -> bool: def neural_reading_order(lines: Sequence[Tuple[List[Tuple[int, int]], List[Tuple[int, int]]]], - im_size: Tuple[int, int], - model): + text_direction: str = 'lr', + regions: Optional[Sequence[List[Tuple[int, int]]]] = None, + im_size: Tuple[int, int] = None, + model: 'TorchVGSLModel' = None, + class_mapping: Dict[str, int] = None) -> Sequence[Tuple[List[Tuple[int, int]], List[Tuple[int, int]]]]: """ Given a list of baselines and regions, calculates the correct reading order and applies it to the input. @@ -834,14 +838,14 @@ def neural_reading_order(lines: Sequence[Tuple[List[Tuple[int, int]], List[Tuple for j in lines: if i == j and len(lines) != 1: continue - num_classes = len(model.class_mapping) + 1 + num_classes = len(class_mapping) + 1 cl_i = torch.zeros(num_classes, dtype=torch.float) cl_j = torch.zeros(num_classes, dtype=torch.float) - cl_i[model.class_mapping.get(i['tags']['type'], 0)] = 1 - cl_j[model.class_mapping.get(j['tags']['type'], 0)] = 1 - line_coords_i = np.array(i['baseline']) / (w, h) + cl_i[class_mapping.get(i[0], 0)] = 1 + cl_j[class_mapping.get(j[0], 0)] = 1 + line_coords_i = np.array(i[1]) / (w, h) line_center_i = np.mean(line_coords_i, axis=0) - line_coords_j = np.array(j['baseline']) / (w, h) + line_coords_j = np.array(j[1]) / (w, h) line_center_j = np.mean(line_coords_j, axis=0) features.append(torch.cat((cl_i, torch.tensor(line_center_i, dtype=torch.float), # lin @@ -851,17 +855,22 @@ def neural_reading_order(lines: Sequence[Tuple[List[Tuple[int, int]], List[Tuple torch.tensor(line_center_j, dtype=torch.float), # lin torch.tensor(line_coords_j[0, :], dtype=torch.float), torch.tensor(line_coords_j[-1, :], dtype=torch.float)))) - features = torch.cat(features) - output = model(features) + features = torch.stack(features) + output = F.sigmoid(model(features)) + order = torch.zeros((len(lines), len(lines))) idx = 0 - for i in enumerate(lines): - for j in enumerate(lines): + for i in range(len(lines)): + for j in range(len(lines)): + if i == j and len(lines) != 1: + continue order[i, j] = output[idx] idx += 1 # decode order relation matrix path = _greedy_order_decoder(order) - return path + # reorder lines + lines = [lines[idx] for idx in path] + return lines def _greedy_order_decoder(P): From 9aca246f875f48f7bae8641de3a0d82cf469ef50 Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Fri, 21 Apr 2023 11:40:25 +0200 Subject: [PATCH 31/68] Segmentation data class in blla/pageseg --- kraken/blla.py | 133 +++++++++++++++++++++---------------- kraken/kraken.py | 26 ++++---- kraken/lib/segmentation.py | 47 ++++++++----- kraken/lib/xml.py | 2 +- kraken/pageseg.py | 11 +-- kraken/serialization.py | 12 ++-- kraken/transcribe.py | 2 - 7 files changed, 131 insertions(+), 102 deletions(-) diff --git a/kraken/blla.py b/kraken/blla.py index 19b3fa0ba..767ec6867 100644 --- a/kraken/blla.py +++ b/kraken/blla.py @@ -29,8 +29,7 @@ import torch.nn.functional as F import torchvision.transforms as tf -from functools import partial -from typing import Optional, Dict, Callable, Union, List, Any, Tuple +from typing import Optional, Dict, Callable, Union, List, Any, Tuple, Literal from scipy.ndimage import gaussian_filter from skimage.filters import sobel @@ -38,7 +37,8 @@ from kraken.lib import vgsl, dataset from kraken.lib.util import is_bitonal, get_im_str from kraken.lib.exceptions import KrakenInputException, KrakenInvalidModelException -from kraken.lib.segmentation import (polygonal_reading_order, +from kraken.lib.segmentation import (Segmentation, + polygonal_reading_order, neural_reading_order, vectorize_lines, vectorize_regions, scale_polygonal_lines, @@ -163,7 +163,6 @@ def vec_lines(heatmap: torch.Tensor, cls_map: Dict[str, Dict[str, int]], scale: float, text_direction: str = 'horizontal-lr', - reading_order_fn: Callable = polygonal_reading_order, regions: List[np.ndarray] = None, scal_im: np.ndarray = None, suppl_obj: List[np.ndarray] = None, @@ -181,7 +180,6 @@ def vec_lines(heatmap: torch.Tensor, scale: Scaling factor between heatmap and unscaled input image. text_direction: Text directions used as hints in the reading order algorithm. - reading_order_fn: Reading order calculation function. regions: Regions to be used as boundaries during polygonization and atomic blocks during reading order determination for lines contained within. @@ -242,15 +240,13 @@ def vec_lines(heatmap: torch.Tensor, logger.debug('Scaling vectorized lines') sc = scale_polygonal_lines([x[1:] for x in lines], scale) lines = list(zip([x[0] for x in lines], [x[0] for x in sc], [x[1] for x in sc])) - logger.debug('Reordering baselines') - lines = reading_order_fn(lines=lines, regions=regions, text_direction=text_direction[-2:]) return [{'tags': {'type': bl_type}, 'baseline': bl, 'boundary': pl} for bl_type, bl, pl in lines] def segment(im: PIL.Image.Image, - text_direction: str = 'horizontal-lr', + text_direction: Literal['horizontal-lr', 'horizontal-rl', 'vertical-lr', 'vertical-rl'] = 'horizontal-lr', mask: Optional[np.ndarray] = None, - reading_order_fn: Optional[Callable] = None, + reading_order_fn: Callable = polygonal_reading_order, model: Union[List[vgsl.TorchVGSLModel], vgsl.TorchVGSLModel] = None, device: str = 'cpu', raise_on_error: bool = False, @@ -271,15 +267,7 @@ def segment(im: PIL.Image.Image, detection. reading_order_fn: Function to determine the reading order. Has to accept a list of tuples (baselines, polygon) and a - text direction (`lr` or `rl`). If None is given it - defaults to either - :func:`kraken.lib.segmentation.polygonal_reading_order` - or - :func:`kraken.lib.segmentation.neural_reading_order` - depending on the presence of a neural reading order - net in the segmentation model. If multiple - segmentation models are given and more than one - contains an RO net the first one will be used. + text direction (`lr` or `rl`). model: One or more TorchVGSLModel containing a segmentation model. If none is given a default model will be loaded. device: The target device to run the neural network on. @@ -288,31 +276,36 @@ def segment(im: PIL.Image.Image, autocast: Runs the model with automatic mixed precision Returns: - A dictionary containing the text direction and under the key 'lines' a - list of reading order sorted baselines (polylines) and their respective - polygonal boundaries. The last and first point of each boundary polygon - are connected. + A :class:`kraken.lib.blla.Segmentation` class containing reading order + sorted baselines (polylines) and their respective polygonal boundaries. + The format of the line and region records is shown below. The last and + first point of each boundary polygon are connected. .. code-block:: :force: - {'text_direction': '$dir', - 'type': 'baseline', - 'lines': [ - {'baseline': [[x0, y0], [x1, y1], ..., [x_n, y_n]], 'boundary': [[x0, y0, x1, y1], ... [x_m, y_m]]}, - {'baseline': [[x0, ...]], 'boundary': [[x0, ...]]} - ] - 'regions': [ - {'region': [[x0, y0], [x1, y1], ..., [x_n, y_n]], 'type': 'image'}, - {'region': [[x0, ...]], 'type': 'text'} - ] - } + 'lines': [ + {'baseline': [[x0, y0], [x1, y1], ..., [x_n, y_n]], 'boundary': [[x0, y0, x1, y1], ... [x_m, y_m]]}, + {'baseline': [[x0, ...]], 'boundary': [[x0, ...]]} + ] + 'regions': [ + {'region': [[x0, y0], [x1, y1], ..., [x_n, y_n]], 'type': 'image'}, + {'region': [[x0, ...]], 'type': 'text'} + ] Raises: KrakenInvalidModelException: if the given model is not a valid segmentation model. KrakenInputException: if the mask is not bitonal or does not match the image size. + + Notes: + Multi-model operation is most useful for combining one or more region + detection models and one text line model. Detected lines from all + models are simply combined without any merging or duplicate detection + so the chance of the same line appearing multiple times in the output + are high. In addition, neural reading order determination is disabled + when more than one model outputs lines. """ if model is None: logger.info('No segmentation model given. Loading default model.') @@ -321,18 +314,6 @@ def segment(im: PIL.Image.Image, if isinstance(model, vgsl.TorchVGSLModel): model = [model] - # determine which reading order function to use - if not reading_order_fn: - reading_order_fn = polygonal_reading_order - for x in model: - if 'ro_model' in x.aux_layers: - logger.info(f'Using reading order model found in segmentation model {x}.') - reading_order_fn = partial(neural_reading_order, - model=x.aux_layers['ro_model'], - im_size=im.size, - class_mapping=x.user_metadata['ro_class_mapping']) - break - for nn in model: if nn.model_type != 'segmentation': raise KrakenInvalidModelException(f'Invalid model type {nn.model_type} for {nn}') @@ -342,6 +323,12 @@ def segment(im: PIL.Image.Image, im_str = get_im_str(im) logger.info(f'Segmenting {im_str}') + lines = [] + order = None + regions = {} + multi_lines = False + # flag to indicate that multiple models produced line output -> disable + # neural reading order for net in model: if 'topline' in net.user_metadata: loc = {None: 'center', @@ -349,11 +336,12 @@ def segment(im: PIL.Image.Image, False: 'bottom'}[net.user_metadata['topline']] logger.debug(f'Baseline location: {loc}') rets = compute_segmentation_map(im, mask, net, device, autocast=autocast) - regions = vec_regions(**rets) + _regions = vec_regions(**rets) + # flatten regions for line ordering/fetch bounding regions line_regs = [] suppl_obj = [] - for cls, regs in regions.items(): + for cls, regs in _regions.items(): line_regs.extend(regs) if rets['bounding_regions'] is not None and cls in rets['bounding_regions']: suppl_obj.extend(regs) @@ -361,21 +349,48 @@ def segment(im: PIL.Image.Image, suppl_obj = scale_regions(suppl_obj, 1/rets['scale']) line_regs = scale_regions(line_regs, 1/rets['scale']) - lines = vec_lines(**rets, - regions=line_regs, - reading_order_fn=reading_order_fn, - text_direction=text_direction, - suppl_obj=suppl_obj, - topline=net.user_metadata['topline'] if 'topline' in net.user_metadata else False, - raise_on_error=raise_on_error) + _lines = vec_lines(**rets, + regions=line_regs, + text_direction=text_direction, + suppl_obj=suppl_obj, + topline=net.user_metadata['topline'] if 'topline' in net.user_metadata else False, + raise_on_error=raise_on_error) + + if 'ro_model' in net.aux_layers: + logger.info(f'Using reading order model found in segmentation model {net}.') + _order = neural_reading_order(lines=_lines, + regions=regions, + text_direction=text_direction[-2:], + model=net.aux_layers['ro_model'], + im_size=im.size, + class_mapping=net.user_metadata['ro_class_mapping']) + else: + _order = None + + if _lines and lines or multi_lines: + multi_lines = True + order = None + logger.warning('Multiple models produced line output. This is ' + 'likely unintended. Suppressing neural reading ' + 'order.') + else: + order = _order + + lines.extend(_lines) + + # reorder lines + logger.debug(f'Reordering baselines with main RO function {reading_order_fn}.') + basic_lo = reading_order_fn(lines=lines, regions=regions, text_direction=text_direction[-2:]) + lines = [lines[idx] for idx in basic_lo] if len(rets['cls_map']['baselines']) > 1: script_detection = True else: script_detection = False - return {'text_direction': text_direction, - 'type': 'baselines', - 'lines': lines, - 'regions': regions, - 'script_detection': script_detection} + return Segmentation(text_direction=text_direction, + type='baselines', + lines=lines, + regions=regions, + script_detection=script_detection, + line_orders=[order]) diff --git a/kraken/kraken.py b/kraken/kraken.py index 6788cdeef..a45e47f75 100644 --- a/kraken/kraken.py +++ b/kraken/kraken.py @@ -21,6 +21,7 @@ import os import warnings import logging +import dataclasses import pkg_resources from typing import Dict, Union, List, cast, Any, IO, Callable @@ -57,15 +58,9 @@ def message(msg: str, **styles) -> None: def get_input_parser(type_str: str) -> Callable[[str], Dict[str, Any]]: - if type_str == 'alto': - from kraken.lib.xml import parse_alto - return parse_alto - elif type_str == 'page': - from kraken.lib.xml import parse_page - return parse_page - elif type_str == 'xml': - from kraken.lib.xml import parse_xml - return parse_xml + if type_str in ['alto', 'page', 'xml']: + from kraken.lib.xml import XMLPage + return XMLPage elif type_str == 'image': return Image.open @@ -78,7 +73,7 @@ def binarizer(threshold, zoom, escale, border, perc, range, low, high, input, ou ctx = click.get_current_context() if ctx.meta['first_process']: if ctx.meta['input_format_type'] != 'image': - input = get_input_parser(ctx.meta['input_format_type'])(input)['image'] + input = get_input_parser(ctx.meta['input_format_type'])(input).imagename ctx.meta['first_process'] = False else: raise click.UsageError('Binarization has to be the initial process.') @@ -131,7 +126,7 @@ def segmenter(legacy, model, text_direction, scale, maxcolseps, black_colseps, if ctx.meta['first_process']: if ctx.meta['input_format_type'] != 'image': - input = get_input_parser(ctx.meta['input_format_type'])(input)['image'] + input = get_input_parser(ctx.meta['input_format_type'])(input).imagename ctx.meta['first_process'] = False if 'base_image' not in ctx.meta: @@ -179,7 +174,7 @@ def segmenter(legacy, model, text_direction, scale, maxcolseps, black_colseps, else: with click.open_file(output, 'w') as fp: fp = cast(IO[Any], fp) - json.dump(res, fp) + json.dump(dataclasses.asdict(res), fp) message('\u2713', fg='green') @@ -203,7 +198,12 @@ def recognizer(model, pad, no_segmentation, bidi_reordering, tags_ignore, input, if doc['base_dir'] and bidi_reordering is True: message(f'Setting base text direction for BiDi reordering to {doc["base_dir"]} (from XML input file)') bidi_reordering = doc['base_dir'] - bounds = doc + bounds = {'text_direction': 'horizontal-lr', + 'tags': True, + 'lines': doc.get_sorted_lines(), + 'regions': doc.get_sorted_regions(), + 'type': 'baselines', + 'image': doc.imagename} try: im = Image.open(ctx.meta['base_image']) except IOError as e: diff --git a/kraken/lib/segmentation.py b/kraken/lib/segmentation.py index 104dd00b9..20a14e973 100644 --- a/kraken/lib/segmentation.py +++ b/kraken/lib/segmentation.py @@ -22,6 +22,7 @@ import shapely.geometry as geom import torch.nn.functional as F +from dataclasses import dataclass from collections import defaultdict from PIL import Image @@ -40,7 +41,7 @@ from skimage.morphology import skeletonize from skimage.transform import PiecewiseAffineTransform, SimilarityTransform, AffineTransform, warp -from typing import List, Tuple, Union, Dict, Any, Sequence, Optional +from typing import List, Tuple, Union, Dict, Any, Sequence, Optional, Literal from kraken.lib import default_specs from kraken.lib.exceptions import KrakenInputException @@ -59,10 +60,21 @@ 'scale_polygonal_lines', 'scale_regions', 'compute_polygon_section', - 'extract_polygons'] + 'extract_polygons', + 'Segmentation'] -def reading_order(lines: Sequence[Tuple[slice, slice]], text_direction: str = 'lr') -> np.ndarray: +@dataclass +class Segmentation: + type: Literal['baselines', 'bbox'] + text_direction: Literal['horizontal-lr', 'horizontal-rl', 'vertical-lr', 'vertical-rl'] + script_detection: bool + lines: List + regions: Dict[str, List] + line_orders: List[List[int]] + + +def reading_order(lines: Sequence[Tuple[slice, slice]], text_direction: Literal['lr', 'rl'] = 'lr') -> np.ndarray: """Given the list of lines (a list of 2D slices), computes the partial reading order. The output is a binary 2D array such that order[i,j] is true if line i comes before line j @@ -737,9 +749,9 @@ def calculate_polygonal_environment(im: PIL.Image.Image = None, return polygons -def polygonal_reading_order(lines: Sequence[Tuple[List[Tuple[int, int]], List[Tuple[int, int]]]], - text_direction: str = 'lr', - regions: Optional[Sequence[List[Tuple[int, int]]]] = None) -> Sequence[Tuple[List[Tuple[int, int]], List[Tuple[int, int]]]]: +def polygonal_reading_order(lines: Sequence[Dict], + text_direction: Literal['lr', 'rl'] = 'lr', + regions: Optional[Sequence[List[Tuple[int, int]]]] = None) -> Sequence[int]: """ Given a list of baselines and regions, calculates the correct reading order and applies it to the input. @@ -752,8 +764,10 @@ def polygonal_reading_order(lines: Sequence[Tuple[List[Tuple[int, int]], List[Tu Can be 'lr' or 'rl' Returns: - A reordered input. + The indices of the ordered input. """ + lines = [(line['tags']['type'], line['baseline'], line['boundary']) for line in lines] + bounds = [] if regions is not None: r = [geom.Polygon(reg) for reg in regions] @@ -789,13 +803,13 @@ def polygonal_reading_order(lines: Sequence[Tuple[List[Tuple[int, int]], List[Tu lsort = topsort(order) sidz = sorted(indizes.keys()) lsort = [sidz[i] for i in lsort] - ordered_lines = [] + ordered_idxs = [] for i in lsort: if indizes[i][0] == 'line': - ordered_lines.append(indizes[i][1]) + ordered_idxs.append(i) else: - ordered_lines.extend(lines[x] for x in intra_region_order[indizes[i][1]]) - return ordered_lines + ordered_idxs.extend(intra_region_order[indizes[i][1]]) + return ordered_idxs def is_in_region(line, region) -> bool: @@ -814,12 +828,12 @@ def is_in_region(line, region) -> bool: return region.contains(l_obj) -def neural_reading_order(lines: Sequence[Tuple[List[Tuple[int, int]], List[Tuple[int, int]]]], +def neural_reading_order(lines: Sequence[Dict], text_direction: str = 'lr', regions: Optional[Sequence[List[Tuple[int, int]]]] = None, im_size: Tuple[int, int] = None, model: 'TorchVGSLModel' = None, - class_mapping: Dict[str, int] = None) -> Sequence[Tuple[List[Tuple[int, int]], List[Tuple[int, int]]]]: + class_mapping: Dict[str, int] = None) -> Sequence[int]: """ Given a list of baselines and regions, calculates the correct reading order and applies it to the input. @@ -829,8 +843,9 @@ def neural_reading_order(lines: Sequence[Tuple[List[Tuple[int, int]], List[Tuple model: torch Module for Returns: - A reordered input. + The indices of the ordered input. """ + lines = [(line['tags']['type'], line['baseline'], line['boundary']) for line in lines] # construct all possible pairs h, w = im_size features = [] @@ -868,9 +883,7 @@ def neural_reading_order(lines: Sequence[Tuple[List[Tuple[int, int]], List[Tuple idx += 1 # decode order relation matrix path = _greedy_order_decoder(order) - # reorder lines - lines = [lines[idx] for idx in path] - return lines + return path def _greedy_order_decoder(P): diff --git a/kraken/lib/xml.py b/kraken/lib/xml.py index 7b78082b3..135c50560 100644 --- a/kraken/lib/xml.py +++ b/kraken/lib/xml.py @@ -32,7 +32,7 @@ logger = logging.getLogger(__name__) -__all__ = ['parse_xml', 'parse_page', 'parse_alto', 'preparse_xml_data'] +__all__ = ['XMLPage', 'parse_xml', 'parse_page', 'parse_alto', 'preparse_xml_data'] # fallback mapping between PAGE region types and tags page_regions = {'TextRegion': 'text', diff --git a/kraken/pageseg.py b/kraken/pageseg.py index 3f52ba339..355af4f5d 100644 --- a/kraken/pageseg.py +++ b/kraken/pageseg.py @@ -29,7 +29,7 @@ from kraken.lib import morph, sl from kraken.lib.util import pil2array, is_bitonal, get_im_str from kraken.lib.exceptions import KrakenInputException -from kraken.lib.segmentation import reading_order, topsort +from kraken.lib.segmentation import reading_order, topsort, Segmentation __all__ = ['segment'] @@ -424,6 +424,9 @@ def segment(im, pad = (pad, pad) lines = [(max(x[0]-pad[0], 0), x[1], min(x[2]+pad[1], im.size[0]), x[3]) for x in lines] - return {'text_direction': text_direction, - 'boxes': rotate_lines(lines, 360-angle, offset).tolist(), - 'script_detection': False} + return Segmentation(text_direction=text_direction, + type='bbox', + regions=None, + line_orders=None, + lines=rotate_lines(lines, 360-angle, offset).tolist(), + script_detection=False) diff --git a/kraken/serialization.py b/kraken/serialization.py index c741923da..da644c948 100644 --- a/kraken/serialization.py +++ b/kraken/serialization.py @@ -25,7 +25,7 @@ from kraken.rpred import BaselineOCRRecord, BBoxOCRRecord, ocr_record from kraken.lib.util import make_printable -from kraken.lib.segmentation import is_in_region +from kraken.lib.segmentation import is_in_region, Segmentation from typing import Union, List, Tuple, Iterable, Optional, Sequence, Dict, Any, Literal @@ -246,7 +246,7 @@ def _load_template(name): return tmpl.render(page=page, metadata=metadata) -def serialize_segmentation(segresult: Dict[str, Any], +def serialize_segmentation(segresult: Segmentation, image_name: Union[PathLike, str] = None, image_size: Tuple[int, int] = (0, 0), template: Union[PathLike, str] = 'alto', @@ -266,18 +266,18 @@ def serialize_segmentation(segresult: Dict[str, Any], Returns: (str) rendered template. """ - if 'type' in segresult and segresult['type'] == 'baselines': - records = [BaselineOCRRecord('', (), (), bl) for bl in segresult['lines']] + if segresult.type == 'baselines': + records = [BaselineOCRRecord('', (), (), bl) for bl in segresult.lines] else: records = [] - for line in segresult['boxes']: + for line in segresult.lines: xmin, xmax = min(line[::2]), max(line[::2]) ymin, ymax = min(line[1::2]), max(line[1::2]) records.append(BBoxOCRRecord('', (), (), ((xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin)))) return serialize(records, image_name=image_name, image_size=image_size, - regions=segresult['regions'] if 'regions' in segresult else None, + regions=segresult.regions, template=template, template_source=template_source, processing_steps=processing_steps) diff --git a/kraken/transcribe.py b/kraken/transcribe.py index e9a8067b8..5b39ee2f7 100644 --- a/kraken/transcribe.py +++ b/kraken/transcribe.py @@ -18,8 +18,6 @@ from kraken.lib.exceptions import KrakenInputException from kraken.lib.util import get_im_str -from typing import List - from jinja2 import Environment, PackageLoader from io import BytesIO From cbc67cd5925236419b992e091908f89226dc1d58 Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Fri, 21 Apr 2023 12:02:33 +0200 Subject: [PATCH 32/68] Switch over rpred.* to Segmentation class --- kraken/rpred.py | 99 +++++++++++++++++++++++-------------------------- 1 file changed, 46 insertions(+), 53 deletions(-) diff --git a/kraken/rpred.py b/kraken/rpred.py index ae80ffa63..33375cc12 100644 --- a/kraken/rpred.py +++ b/kraken/rpred.py @@ -30,7 +30,7 @@ from kraken.lib.util import get_im_str, is_bitonal from kraken.lib.models import TorchSeqRecognizer -from kraken.lib.segmentation import extract_polygons, compute_polygon_section +from kraken.lib.segmentation import extract_polygons, compute_polygon_section, Segmentation from kraken.lib.exceptions import KrakenInputException from kraken.lib.dataset import ImageInputTransforms @@ -401,7 +401,7 @@ class mm_rpred(object): def __init__(self, nets: Dict[str, TorchSeqRecognizer], im: Image.Image, - bounds: dict, + bounds: Segmentation, pad: int = 16, bidi_reordering: Union[bool, str] = True, tags_ignore: Optional[List[str]] = None) -> Generator[ocr_record, None, None]: @@ -413,20 +413,18 @@ def __init__(self, these lines. Args: - nets (dict): A dict mapping tag values to TorchSegRecognizer - objects. Recommended to be an defaultdict. - im (PIL.Image.Image): Image to extract text from - bounds (dict): A dictionary containing a 'boxes' entry - with a list of lists of coordinates (script, (x0, y0, - x1, y1)) of a text line in the image and an entry - 'text_direction' containing - 'horizontal-lr/rl/vertical-lr/rl'. - pad (int): Extra blank padding to the left and right of text line - bidi_reordering (bool|str): Reorder classes in the ocr_record according to - the Unicode bidirectional algorithm for - correct display. Set to L|R to - override default text direction. - tags_ignore (list): List of tag values to ignore during recognition + nets: A dict mapping tag values to TorchSegRecognizer objects. + Recommended to be an defaultdict. + im: Image to extract text from + bounds: A Segmentation data class containing either bounding box or + baseline type segmentation. + pad: Extra blank padding to the left and right of text line + bidi_reordering: Reorder classes in the ocr_record according to the + Unicode bidirectional algorithm for correct + display. Set to L|R to override default text + direction. + tags_ignore: List of tag values to ignore during recognition + Yields: An ocr_record containing the recognized text, absolute character positions, and confidence values for each character. @@ -445,36 +443,35 @@ def __init__(self, if not tags_ignore: tags_ignore = [] - if ('type' in bounds and bounds['type'] not in seg_types) or len(seg_types) > 1: + if bounds.type not in seg_types or len(seg_types) > 1: logger.warning(f'Recognizers with segmentation types {seg_types} will be ' - f'applied to segmentation of type {bounds["type"] if "type" in bounds else None}. ' + f'applied to segmentation of type {bounds.type}. ' f'This will likely result in severely degraded performace') one_channel_modes = set(recognizer.nn.one_channel_mode for recognizer in nets.values()) if '1' in one_channel_modes and len(one_channel_modes) > 1: raise KrakenInputException('Mixing binary and non-binary recognition models is not supported.') elif '1' in one_channel_modes and not is_bitonal(im): logger.warning('Running binary models on non-binary input image ' - '(mode {}). This will result in severely degraded ' - 'performance'.format(im.mode)) - if 'type' in bounds and bounds['type'] == 'baselines': + f'(mode {im.mode}). This will result in severely degraded ' + 'performance') + + self.len = len(bounds.lines) + self.line_iter = iter(bounds.lines) + + if bounds.type == 'baselines': valid_norm = False - self.len = len(bounds['lines']) - self.seg_key = 'lines' self.next_iter = self._recognize_baseline_line - self.line_iter = iter(bounds['lines']) tags = set() - for x in bounds['lines']: + for x in bounds.lines: tags.update(x['tags'].values()) else: valid_norm = True - self.len = len(bounds['boxes']) self.seg_key = 'boxes' self.next_iter = self._recognize_box_line - self.line_iter = iter(bounds['boxes']) - tags = set(x[0] for line in bounds['boxes'] for x in line) + tags = set(x[0] for line in bounds.lines for x in line) im_str = get_im_str(im) - logger.info('Running {} multi-script recognizers on {} with {} lines'.format(len(nets), im_str, self.len)) + logger.info(f'Running {len(nets)} multi-script recognizers on {im_str} with {self.len} lines') filtered_tags = [] miss = [] @@ -486,12 +483,12 @@ def __init__(self, tags = filtered_tags if miss: - raise KrakenInputException('Missing models for tags {}'.format(set(miss))) + raise KrakenInputException(f'Missing models for tags {set(miss)}') # build dictionary for line preprocessing self.ts = {} for tag in tags: - logger.debug('Loading line transforms for {}'.format(tag)) + logger.debug(f'Loading line transforms for {tag}') network = nets[tag] batch, channels, height, width = network.nn.input self.ts[tag] = ImageInputTransforms(batch, height, width, channels, (pad, 0), valid_norm) @@ -554,7 +551,7 @@ def _recognize_box_line(self, line): conf = [] for _, start, end, c in preds: - if self.bounds['text_direction'].startswith('horizontal'): + if self.bounds.text_direction.startswith('horizontal'): xmin = coords[0] + self._scale_val(start, 0, self.box.size[0]) xmax = coords[0] + self._scale_val(end, 0, self.box.size[0]) pos.append([[xmin, coords[1]], [xmin, coords[3]], [xmax, coords[3]], [xmax, coords[1]]]) @@ -631,7 +628,7 @@ def _recognize_baseline_line(self, line): def __next__(self): bound = self.bounds - bound[self.seg_key] = [next(self.line_iter)] + setattr(bound, self.seg_key, [next(self.line_iter)]) return self.next_iter(bound) def __iter__(self): @@ -646,39 +643,35 @@ def _scale_val(self, val, min_val, max_val): def rpred(network: TorchSeqRecognizer, im: Image.Image, - bounds: dict, + bounds: Segmentation, pad: int = 16, bidi_reordering: Union[bool, str] = True) -> Generator[ocr_record, None, None]: """ Uses a TorchSeqRecognizer and a segmentation to recognize text Args: - network (kraken.lib.models.TorchSeqRecognizer): A TorchSegRecognizer - object - im (PIL.Image.Image): Image to extract text from - bounds (dict): A dictionary containing a 'boxes' entry with a list of - coordinates (x0, y0, x1, y1) of a text line in the image - and an entry 'text_direction' containing - 'horizontal-lr/rl/vertical-lr/rl'. - pad (int): Extra blank padding to the left and right of text line. - Auto-disabled when expected network inputs are incompatible - with padding. - bidi_reordering (bool|str): Reorder classes in the ocr_record according to - the Unicode bidirectional algorithm for correct - display. Set to L|R to change base text - direction. + network: A TorchSegRecognizer object + im: Image to extract text from + bounds: A Segmentation class instance containing either a baseline or bbox segmentation. + pad: Extra blank padding to the left and right of text line. + Auto-disabled when expected network inputs are incompatible with + padding. + bidi_reordering: Reorder classes in the ocr_record according to the + Unicode bidirectional algorithm for correct display. + Set to L|R to change base text direction. + Yields: An ocr_record containing the recognized text, absolute character positions, and confidence values for each character. """ bounds = copy.deepcopy(bounds) - if 'boxes' in bounds: - boxes = bounds['boxes'] + if bounds.type == 'bbox': + boxes = bounds.lines rewrite_boxes = [] for box in boxes: rewrite_boxes.append([('default', box)]) - bounds['boxes'] = rewrite_boxes - bounds['script_detection'] = True + bounds.lines = rewrite_boxes + bounds.script_detection = True return mm_rpred(defaultdict(lambda: network), im, bounds, pad, bidi_reordering) @@ -693,4 +686,4 @@ def _resolve_tags_to_model(tags: Sequence[Dict[str, str]], return tag, model_map[tag] if default: return next(tags.values()), default - raise KrakenInputException('No model for tags {}'.format(tags)) + raise KrakenInputException(f'No model for tags {tags}') From 14f98c69715aaa5dcb33cf5f75a31b0aa2e3a404 Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Mon, 24 Apr 2023 11:44:36 +0200 Subject: [PATCH 33/68] Use XMLPage in dataset/ro.py --- kraken/lib/dataset/ro.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kraken/lib/dataset/ro.py b/kraken/lib/dataset/ro.py index 31dee62ec..e6c6bf678 100644 --- a/kraken/lib/dataset/ro.py +++ b/kraken/lib/dataset/ro.py @@ -32,7 +32,7 @@ from torch.utils.data import Dataset from typing import Dict, List, Tuple, Sequence, Callable, Any, Union, Literal, Optional -from kraken.lib.xml import parse_alto, parse_page, parse_xml, XMLPage +from kraken.lib.xml import XMLPage from kraken.lib.exceptions import KrakenInputException From d4193e01e5bcd8216b5d99b80e1816a6dba207ec Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Mon, 24 Apr 2023 11:45:55 +0200 Subject: [PATCH 34/68] XMLPage in dataset/segmentation.py --- kraken/lib/dataset/segmentation.py | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/kraken/lib/dataset/segmentation.py b/kraken/lib/dataset/segmentation.py index 075507cbf..faa0536ce 100644 --- a/kraken/lib/dataset/segmentation.py +++ b/kraken/lib/dataset/segmentation.py @@ -33,7 +33,7 @@ from skimage.draw import polygon -from kraken.lib.xml import parse_alto, parse_page, parse_xml +from kraken.lib.xml import XMLPage from kraken.lib.exceptions import KrakenInputException @@ -103,32 +103,25 @@ def __init__(self, imgs: Sequence[Union[PathLike, str]] = None, self.valid_baselines = valid_baselines self.valid_regions = valid_regions if mode in ['alto', 'page', 'xml']: - if mode == 'alto': - fn = parse_alto - elif mode == 'page': - fn = parse_page - elif mode == 'xml': - fn = parse_xml im_paths = [] self.targets = [] for img in imgs: try: - data = fn(img) - im_paths.append(data['image']) + data = XMLPage(img) + im_paths.append(data.imagename) lines = defaultdict(list) - for line in data['lines']: + for line in data.get_sorted_lines(): if valid_baselines is None or set(line['tags'].values()).intersection(valid_baselines): tags = set(line['tags'].values()).intersection(valid_baselines) if valid_baselines else line['tags'].values() for tag in tags: lines[self.mbl_dict.get(tag, tag)].append(line['baseline']) self.class_stats['baselines'][self.mbl_dict.get(tag, tag)] += 1 regions = defaultdict(list) - for k, v in data['regions'].items(): + for k, v in data.regions.items(): if valid_regions is None or k in valid_regions: regions[self.mreg_dict.get(k, k)].extend(v) self.class_stats['regions'][self.mreg_dict.get(k, k)] += len(v) - data['regions'] = regions - self.targets.append({'baselines': lines, 'regions': data['regions']}) + self.targets.append({'baselines': lines, 'regions': regions}) except KrakenInputException as e: logger.warning(e) continue From 9b34c09f14f79054c37b4d0a04adb727fccf43e4 Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Tue, 25 Apr 2023 12:48:22 +0200 Subject: [PATCH 35/68] extract_polygons with dataclass --- kraken/lib/segmentation.py | 24 ++++++++---------------- 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/kraken/lib/segmentation.py b/kraken/lib/segmentation.py index 20a14e973..d821d9ce6 100644 --- a/kraken/lib/segmentation.py +++ b/kraken/lib/segmentation.py @@ -67,6 +67,7 @@ @dataclass class Segmentation: type: Literal['baselines', 'bbox'] + imagename: str text_direction: Literal['horizontal-lr', 'horizontal-rl', 'vertical-lr', 'vertical-rl'] script_detection: bool lines: List @@ -1046,29 +1047,20 @@ def compute_polygon_section(baseline: Sequence[Tuple[int, int]], return tuple(o) -def extract_polygons(im: Image.Image, bounds: Dict[str, Any]) -> Image.Image: +def extract_polygons(im: Image.Image, bounds: Segmentation) -> Image.Image: """ Yields the subimages of image im defined in the list of bounding polygons with baselines preserving order. Args: im: Input image - bounds: A list of dicts in baseline:: - - {'type': 'baselines', - 'lines': [{'baseline': [[x_0, y_0], ... [x_n, y_n]], - 'boundary': [[x_0, y_0], ... [x_n, y_n]]}, - ....] - } - - or bounding box format:: - - {'boxes': [[x_0, y_0, x_1, y_1], ...], 'text_direction': 'horizontal-lr'} + bounds: A Segmentation class containing a boundig box or baseline + segmentation. Yields: The extracted subimage """ - if 'type' in bounds and bounds['type'] == 'baselines': + if bounds.type == 'baselines': # select proper interpolation scheme depending on shape if im.mode == '1': order = 0 @@ -1077,7 +1069,7 @@ def extract_polygons(im: Image.Image, bounds: Dict[str, Any]) -> Image.Image: order = 1 im = np.array(im) - for line in bounds['lines']: + for line in bounds.lines: if line['boundary'] is None: raise KrakenInputException('No boundary given for line') pl = np.array(line['boundary']) @@ -1166,11 +1158,11 @@ def extract_polygons(im: Image.Image, bounds: Dict[str, Any]) -> Image.Image: i = Image.fromarray(o.astype('uint8')) yield i.crop(i.getbbox()), line else: - if bounds['text_direction'].startswith('vertical'): + if bounds.text_direction.startswith('vertical'): angle = 90 else: angle = 0 - for box in bounds['boxes']: + for box in bounds.lines: if isinstance(box, tuple): box = list(box) if (box < [0, 0, 0, 0] or box[::2] >= [im.size[0], im.size[0]] or From 09591567e60f8fa0776748af33fbb52c11d8e9b1 Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Thu, 11 May 2023 13:08:50 +0200 Subject: [PATCH 36/68] Add new container classes --- kraken/containers.py | 417 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 417 insertions(+) create mode 100644 kraken/containers.py diff --git a/kraken/containers.py b/kraken/containers.py new file mode 100644 index 000000000..5110ad659 --- /dev/null +++ b/kraken/containers.py @@ -0,0 +1,417 @@ + +import PIL.Image + +from typing import Literal, List, Dict, Sequence, Union, Optional, Tuple +from dataclasses import dataclass, asdict +from abc import ABC, abstractmethod + + +@dataclass +class BaselineLine: + """ + """ + id: str + baseline: List[Tuple[int, int]] + boundary: List[Tuple[int, int]] + text: Optional[str] = None + base_dir: Optional[Literal['L', 'R']] = None + type: str = 'baselines' + image: Optional[PIL.Image.Image] = None + + +@dataclass +class BBoxLine: + """ + """ + id: str + bbox: Tuple[Tuple[int, int], + Tuple[int, int], + Tuple[int, int], + Tuple[int, int]] + text: Optional[str] = None + base_dir: Optional[Literal['L', 'R']] = None + type: str = 'bbox' + image: Optional[PIL.Image.Image] = None + + +@dataclass +class Segmentation: + """ + + """ + type: Literal['baselines', 'bbox'] + imagename: str + text_direction: Literal['horizontal-lr', 'horizontal-rl', 'vertical-lr', 'vertical-rl'] + script_detection: bool + lines: Sequence[Union[BaselineLine, BBoxLine]] + regions: Dict[str, List] + line_orders: List[List[int]] + + +class ocr_record(ABC): + """ + A record object containing the recognition result of a single line + """ + base_dir = None + + def __init__(self, + prediction: str, + cuts: Sequence[Union[Tuple[int, int], Tuple[Tuple[int, int], + Tuple[int, int], + Tuple[int, int], + Tuple[int, int]]]], + confidences: Sequence[float], + display_order: bool = True) -> None: + self._prediction = prediction + self._cuts = cuts + self._confidences = confidences + self._display_order = display_order + + @property + @abstractmethod + def type(self): + pass + + def __len__(self) -> int: + return len(self._prediction) + + def __str__(self) -> str: + return self._prediction + + @property + def prediction(self) -> str: + return self._prediction + + @property + def cuts(self) -> Sequence: + return self._cuts + + @property + def confidences(self) -> List[float]: + return self._confidences + + def __iter__(self): + self.idx = -1 + return self + + @abstractmethod + def __next__(self) -> Tuple[str, + Union[Sequence[Tuple[int, int]], + Tuple[Tuple[int, int], + Tuple[int, int], + Tuple[int, int], + Tuple[int, int]]], + float]: + pass + + @abstractmethod + def __getitem__(self, key: Union[int, slice]): + pass + + @abstractmethod + def display_order(self, base_dir) -> 'ocr_record': + pass + + @abstractmethod + def logical_order(self, base_dir) -> 'ocr_record': + pass + + +class BaselineOCRRecord(ocr_record, BaselineLine): + """ + A record object containing the recognition result of a single line in + baseline format. + + Attributes: + type: 'baselines' to indicate a baseline record + prediction: The text predicted by the network as one continuous string. + cuts: The absolute bounding polygons for each code point in prediction + as a list of tuples [(x0, y0), (x1, y2), ...]. + confidences: A list of floats indicating the confidence value of each + code point. + + Notes: + When slicing the record the behavior of the cuts is changed from + earlier versions of kraken. Instead of returning per-character bounding + polygons a single polygons section of the line bounding polygon + starting at the first and extending to the last code point emitted by + the network is returned. This aids numerical stability when computing + aggregated bounding polygons such as for words. Individual code point + bounding polygons are still accessible through the `cuts` attribute or + by iterating over the record code point by code point. + """ + type = 'baselines' + + def __init__(self, prediction: str, + cuts: Sequence[Tuple[int, int]], + confidences: Sequence[float], + line: BaselineLine, + base_dir: Optional[Literal['L', 'R']] = None, + display_order: bool = True) -> None: + if line.type != 'baselines': + raise TypeError('Invalid argument type (non-baseline line)') + BaselineLine.__init__(self, **asdict(line)) + self._line_base_dir = self.base_dir + self.base_dir = base_dir + ocr_record.__init__(self, prediction, cuts, confidences, display_order) + + def __repr__(self) -> str: + return f'pred: {self.prediction} baseline: {self.baseline} boundary: {self.boundary} confidences: {self.confidences}' + + def __next__(self) -> Tuple[str, int, float]: + if self.idx + 1 < len(self): + self.idx += 1 + return (self.prediction[self.idx], + compute_polygon_section(self.baseline, + self.line, + self.cuts[self.idx][0], + self.cuts[self.idx][1]), + self.confidences[self.idx]) + else: + raise StopIteration + + def _get_raw_item(self, key: int): + if key < 0: + key += len(self) + if key >= len(self): + raise IndexError('Index (%d) is out of range' % key) + return (self.prediction[key], + self._cuts[key], + self.confidences[key]) + + def __getitem__(self, key: Union[int, slice]): + if isinstance(key, slice): + recs = [self._get_raw_item(i) for i in range(*key.indices(len(self)))] + prediction = ''.join([x[0] for x in recs]) + flat_offsets = sum((tuple(x[1]) for x in recs), ()) + cut = compute_polygon_section(self.baseline, + self.line, + min(flat_offsets), + max(flat_offsets)) + confidence = np.mean([x[2] for x in recs]) + return (prediction, cut, confidence) + elif isinstance(key, int): + pred, cut, confidence = self._get_raw_item(key) + return (pred, + compute_polygon_section(self.baseline, self.line, cut[0], cut[1]), + confidence) + else: + raise TypeError('Invalid argument type') + + @property + def cuts(self) -> Sequence[Tuple[int, int]]: + return tuple([compute_polygon_section(self.baseline, self.line, cut[0], cut[1]) for cut in self._cuts]) + + def logical_order(self, base_dir: Optional[str] = None) -> 'BaselineOCRRecord': + """ + Returns the OCR record in Unicode logical order, i.e. in the order the + characters in the line would be read by a human. + + Args: + base_dir: An optional string defining the base direction (also + called paragraph direction) for the BiDi algorithm. Valid + values are 'L' or 'R'. If None is given the default + auto-resolution will be used. + """ + if self._display_order: + return self._reorder(base_dir) + else: + return self + + def display_order(self, base_dir: Optional[str] = None) -> 'BaselineOCRRecord': + """ + Returns the OCR record in Unicode display order, i.e. ordered from left + to right inside the line. + + Args: + base_dir: An optional string defining the base direction (also + called paragraph direction) for the BiDi algorithm. Valid + values are 'L' or 'R'. If None is given the default + auto-resolution will be used. + """ + if self._display_order: + return self + else: + return self._reorder(base_dir) + + def _reorder(self, base_dir: Optional[str] = None) -> 'BaselineOCRRecord': + """ + Reorder the record using the BiDi algorithm. + """ + storage = bd.get_empty_storage() + + if base_dir not in ('L', 'R'): + base_level = bd.get_base_level(self._prediction) + else: + base_level = {'L': 0, 'R': 1}[base_dir] + + storage['base_level'] = base_level + storage['base_dir'] = ('L', 'R')[base_level] + bd.get_embedding_levels(self._prediction, storage) + bd.explicit_embed_and_overrides(storage) + bd.resolve_weak_types(storage) + bd.resolve_neutral_types(storage, False) + bd.resolve_implicit_levels(storage, False) + for i, j in enumerate(zip(self._prediction, self._cuts, self._confidences)): + storage['chars'][i]['record'] = j + bd.reorder_resolved_levels(storage, False) + bd.apply_mirroring(storage, False) + prediction = '' + cuts = [] + confidences = [] + for ch in storage['chars']: + # code point may have been mirrored + prediction = prediction + ch['ch'] + cuts.append(ch['record'][1]) + confidences.append(ch['record'][2]) + line = BaselineLine(id=self.id, + baseline=self.baseline, + boundary=self.boundary, + text=self.text, + base_dir=self._line_base_dir, + image=self.image) + rec = BaselineOCRRecord(prediction=prediction, + cuts=cuts, + confidences=confidences, + line=line, + base_dir=base_dir, + display_order=not self._display_order) + return rec + + +class BBoxOCRRecord(ocr_record, BBoxLine): + """ + A record object containing the recognition result of a single line in + bbox format. + """ + type = 'bbox' + + def __init__(self, prediction: str, + cuts: Sequence[Tuple[Tuple[int, int], + Tuple[int, int], + Tuple[int, int], + Tuple[int, int]]], + confidences: Sequence[float], + line: BBoxLine, + base_dir: Optional['L', 'R'], + display_order: bool = True) -> None: + if line.type != 'bbox': + raise TypeError('Invalid argument type (non-bbox line)') + BBoxLine.__init__(self, **asdict(line)) + self._line_base_dir = self.base_dir + self.base_dir = base_dir + ocr_record.__init__(self, prediction, cuts, confidences, display_order) + + def __repr__(self) -> str: + return f'pred: {self.prediction} line: {self.line} confidences: {self.confidences}' + + def __next__(self) -> Tuple[str, int, float]: + if self.idx + 1 < len(self): + self.idx += 1 + return (self.prediction[self.idx], + self.cuts[self.idx], + self.confidences[self.idx]) + else: + raise StopIteration + + def _get_raw_item(self, key: int): + if key < 0: + key += len(self) + if key >= len(self): + raise IndexError('Index (%d) is out of range' % key) + return (self.prediction[key], + self.cuts[key], + self.confidences[key]) + + def __getitem__(self, key: Union[int, slice]): + if isinstance(key, slice): + recs = [self._get_raw_item(i) for i in range(*key.indices(len(self)))] + prediction = ''.join([x[0] for x in recs]) + box = [x[1] for x in recs] + flat_box = [point for pol in box for point in pol] + flat_box = [x for point in flat_box for x in point] + min_x, max_x = min(flat_box[::2]), max(flat_box[::2]) + min_y, max_y = min(flat_box[1::2]), max(flat_box[1::2]) + cut = ((min_x, min_y), (max_x, min_y), (max_x, max_y), (min_x, max_y)) + confidence = np.mean([x[2] for x in recs]) + return (prediction, cut, confidence) + elif isinstance(key, int): + return self._get_raw_item(key) + else: + raise TypeError('Invalid argument type') + + def logical_order(self, base_dir: Optional[Literal['L', 'R']] = None) -> 'BBoxOCRRecord': + """ + Returns the OCR record in Unicode logical order, i.e. in the order the + characters in the line would be read by a human. + + Args: + base_dir: An optional string defining the base direction (also + called paragraph direction) for the BiDi algorithm. Valid + values are 'L' or 'R'. If None is given the default + auto-resolution will be used. + """ + if self._display_order: + return self._reorder(base_dir) + else: + return self + + def display_order(self, base_dir: Optional[Literal['L', 'R']] = None) -> 'BBoxOCRRecord': + """ + Returns the OCR record in Unicode display order, i.e. ordered from left + to right inside the line. + + Args: + base_dir: An optional string defining the base direction (also + called paragraph direction) for the BiDi algorithm. Valid + values are 'L' or 'R'. If None is given the default + auto-resolution will be used. + """ + if self._display_order: + return self + else: + return self._reorder(base_dir) + + def _reorder(self, base_dir: Optional[Literal['L', 'R']] = None) -> 'BBoxOCRRecord': + storage = bd.get_empty_storage() + + if base_dir not in ('L', 'R'): + base_level = bd.get_base_level(self.prediction) + else: + base_level = {'L': 0, 'R': 1}[base_dir] + + storage['base_level'] = base_level + storage['base_dir'] = ('L', 'R')[base_level] + + bd.get_embedding_levels(self.prediction, storage) + bd.explicit_embed_and_overrides(storage) + bd.resolve_weak_types(storage) + bd.resolve_neutral_types(storage, False) + bd.resolve_implicit_levels(storage, False) + for i, j in enumerate(zip(self.prediction, self.cuts, self.confidences)): + storage['chars'][i]['record'] = j + bd.reorder_resolved_levels(storage, False) + bd.apply_mirroring(storage, False) + prediction = '' + cuts = [] + confidences = [] + for ch in storage['chars']: + # code point may have been mirrored + prediction = prediction + ch['ch'] + cuts.append(ch['record'][1]) + confidences.append(ch['record'][2]) + # carry over whole line information + line = BBoxLine(id=self.id, + bbox=self.bbox, + text=self.text, + base_dir=self._line_base_dir, + image=self.image) + rec = BBoxOCRRecord(prediction=prediction, + cuts=cuts, + confidences=confidences, + line=line, + base_dir=base_dir, + display_order=not self._display_order) + return rec + + From 5afb82847354576b71dddff123bfba952b5a387e Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Thu, 11 May 2023 13:09:13 +0200 Subject: [PATCH 37/68] better _to_ptl_device --- kraken/ketos/util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kraken/ketos/util.py b/kraken/ketos/util.py index e71b53505..009190027 100644 --- a/kraken/ketos/util.py +++ b/kraken/ketos/util.py @@ -54,7 +54,7 @@ def message(msg, **styles): def to_ptl_device(device: str) -> Tuple[str, Optional[List[int]]]: - if any([device == x for x in ['cpu', 'mps']]): + if device in ['cpu', 'mps']]): return device, 'auto' elif any([device.startswith(x) for x in ['tpu', 'cuda', 'hpu', 'ipu']]): dev, idx = device.split(':') From 24a0b85bf318ab921f85935b7b3b502db63e13d2 Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Thu, 11 May 2023 13:10:14 +0200 Subject: [PATCH 38/68] strip container classes from rpred --- kraken/rpred.py | 356 +----------------------------------------------- 1 file changed, 2 insertions(+), 354 deletions(-) diff --git a/kraken/rpred.py b/kraken/rpred.py index 33375cc12..40bcba765 100644 --- a/kraken/rpred.py +++ b/kraken/rpred.py @@ -28,6 +28,7 @@ from collections import defaultdict from typing import List, Tuple, Optional, Generator, Union, Dict, Sequence +from kraken.containers import BaselineOCRRecord, BBoxOCRRecord, ocr_record from kraken.lib.util import get_im_str, is_bitonal from kraken.lib.models import TorchSeqRecognizer from kraken.lib.segmentation import extract_polygons, compute_polygon_section, Segmentation @@ -36,364 +37,11 @@ import copy -__all__ = ['ocr_record', 'BaselineOCRRecord', 'BBoxOCRRecord', 'mm_rpred', 'rpred'] +__all__ = ['mm_rpred', 'rpred'] logger = logging.getLogger(__name__) -class ocr_record(ABC): - """ - A record object containing the recognition result of a single line - """ - base_dir = None - - def __init__(self, - prediction: str, - cuts: Sequence[Union[Tuple[int, int], Tuple[Tuple[int, int], - Tuple[int, int], - Tuple[int, int], - Tuple[int, int]]]], - confidences: Sequence[float], - display_order: bool = True) -> None: - self._prediction = prediction - self._cuts = cuts - self._confidences = confidences - self._display_order = display_order - - @property - @abstractmethod - def type(self): - pass - - def __len__(self) -> int: - return len(self._prediction) - - def __str__(self) -> str: - return self._prediction - - @property - def prediction(self) -> str: - return self._prediction - - @property - def cuts(self) -> Sequence: - return self._cuts - - @property - def confidences(self) -> List[float]: - return self._confidences - - def __iter__(self): - self.idx = -1 - return self - - @abstractmethod - def __next__(self) -> Tuple[str, - Union[Sequence[Tuple[int, int]], - Tuple[Tuple[int, int], - Tuple[int, int], - Tuple[int, int], - Tuple[int, int]]], - float]: - pass - - @abstractmethod - def __getitem__(self, key: Union[int, slice]): - pass - - @abstractmethod - def display_order(self, base_dir) -> 'ocr_record': - pass - - @abstractmethod - def logical_order(self, base_dir) -> 'ocr_record': - pass - - -class BaselineOCRRecord(ocr_record): - """ - A record object containing the recognition result of a single line in - baseline format. - - Attributes: - type: 'baselines' to indicate a baseline record - prediction: The text predicted by the network as one continuous string. - cuts: The absolute bounding polygons for each code point in prediction - as a list of tuples [(x0, y0), (x1, y2), ...]. - confidences: A list of floats indicating the confidence value of each - code point. - - Notes: - When slicing the record the behavior of the cuts is changed from - earlier versions of kraken. Instead of returning per-character bounding - polygons a single polygons section of the line bounding polygon - starting at the first and extending to the last code point emitted by - the network is returned. This aids numerical stability when computing - aggregated bounding polygons such as for words. Individual code point - bounding polygons are still accessible through the `cuts` attribute or - by iterating over the record code point by code point. - """ - type = 'baselines' - - def __init__(self, prediction: str, - cuts: Sequence[Tuple[int, int]], - confidences: Sequence[float], - line: Dict[str, List], - display_order: bool = True) -> None: - super().__init__(prediction, cuts, confidences, display_order) - if 'baseline' not in line: - raise TypeError('Invalid argument type (non-baseline line)') - self.tags = None if 'tags' not in line else line['tags'] - self.line = line['boundary'] - self.baseline = line['baseline'] - - def __repr__(self) -> str: - return f'pred: {self.prediction} baseline: {self.baseline} boundary: {self.line} confidences: {self.confidences}' - - def __next__(self) -> Tuple[str, int, float]: - if self.idx + 1 < len(self): - self.idx += 1 - return (self.prediction[self.idx], - compute_polygon_section(self.baseline, - self.line, - self.cuts[self.idx][0], - self.cuts[self.idx][1]), - self.confidences[self.idx]) - else: - raise StopIteration - - def _get_raw_item(self, key: int): - if key < 0: - key += len(self) - if key >= len(self): - raise IndexError('Index (%d) is out of range' % key) - return (self.prediction[key], - self._cuts[key], - self.confidences[key]) - - def __getitem__(self, key: Union[int, slice]): - if isinstance(key, slice): - recs = [self._get_raw_item(i) for i in range(*key.indices(len(self)))] - prediction = ''.join([x[0] for x in recs]) - flat_offsets = sum((tuple(x[1]) for x in recs), ()) - cut = compute_polygon_section(self.baseline, - self.line, - min(flat_offsets), - max(flat_offsets)) - confidence = np.mean([x[2] for x in recs]) - return (prediction, cut, confidence) - - elif isinstance(key, int): - pred, cut, confidence = self._get_raw_item(key) - return (pred, - compute_polygon_section(self.baseline, self.line, cut[0], cut[1]), - confidence) - else: - raise TypeError('Invalid argument type') - - @property - def cuts(self) -> Sequence[Tuple[int, int]]: - return tuple([compute_polygon_section(self.baseline, self.line, cut[0], cut[1]) for cut in self._cuts]) - - def logical_order(self, base_dir: Optional[str] = None) -> 'BaselineOCRRecord': - """ - Returns the OCR record in Unicode logical order, i.e. in the order the - characters in the line would be read by a human. - - Args: - base_dir: An optional string defining the base direction (also - called paragraph direction) for the BiDi algorithm. Valid - values are 'L' or 'R'. If None is given the default - auto-resolution will be used. - """ - if self._display_order: - return self._reorder(base_dir) - else: - return self - - def display_order(self, base_dir: Optional[str] = None) -> 'BaselineOCRRecord': - """ - Returns the OCR record in Unicode display order, i.e. ordered from left - to right inside the line. - - Args: - base_dir: An optional string defining the base direction (also - called paragraph direction) for the BiDi algorithm. Valid - values are 'L' or 'R'. If None is given the default - auto-resolution will be used. - """ - if self._display_order: - return self - else: - return self._reorder(base_dir) - - def _reorder(self, base_dir: Optional[str] = None) -> 'BaselineOCRRecord': - """ - Reorder the record using the BiDi algorithm. - """ - storage = bd.get_empty_storage() - - if base_dir not in ('L', 'R'): - base_level = bd.get_base_level(self._prediction) - else: - base_level = {'L': 0, 'R': 1}[base_dir] - - storage['base_level'] = base_level - storage['base_dir'] = ('L', 'R')[base_level] - - bd.get_embedding_levels(self._prediction, storage) - bd.explicit_embed_and_overrides(storage) - bd.resolve_weak_types(storage) - bd.resolve_neutral_types(storage, False) - bd.resolve_implicit_levels(storage, False) - for i, j in enumerate(zip(self._prediction, self._cuts, self._confidences)): - storage['chars'][i]['record'] = j - bd.reorder_resolved_levels(storage, False) - bd.apply_mirroring(storage, False) - prediction = '' - cuts = [] - confidences = [] - for ch in storage['chars']: - # code point may have been mirrored - prediction = prediction + ch['ch'] - cuts.append(ch['record'][1]) - confidences.append(ch['record'][2]) - line = {'boundary': self.line, 'baseline': self.baseline} - rec = BaselineOCRRecord(prediction, cuts, confidences, line) - rec.tags = self.tags - rec.base_dir = base_dir - rec._display_order = not self._display_order - return rec - - -class BBoxOCRRecord(ocr_record): - """ - A record object containing the recognition result of a single line in - bbox format. - """ - type = 'box' - - def __init__(self, prediction: str, - cuts: Sequence[Tuple[Tuple[int, int], - Tuple[int, int], - Tuple[int, int], - Tuple[int, int]]], - confidences: Sequence[float], - line: Tuple[Tuple[int, int], - Tuple[int, int], - Tuple[int, int], - Tuple[int, int]], - display_order: bool = True) -> None: - super().__init__(prediction, cuts, confidences, display_order) - if 'baseline' in line: - raise TypeError('Invalid argument type (baseline line)') - self.line = line - - def __repr__(self) -> str: - return f'pred: {self.prediction} line: {self.line} confidences: {self.confidences}' - - def __next__(self) -> Tuple[str, int, float]: - if self.idx + 1 < len(self): - self.idx += 1 - return (self.prediction[self.idx], - self.cuts[self.idx], - self.confidences[self.idx]) - else: - raise StopIteration - - def _get_raw_item(self, key: int): - if key < 0: - key += len(self) - if key >= len(self): - raise IndexError('Index (%d) is out of range' % key) - return (self.prediction[key], - self.cuts[key], - self.confidences[key]) - - def __getitem__(self, key: Union[int, slice]): - if isinstance(key, slice): - recs = [self._get_raw_item(i) for i in range(*key.indices(len(self)))] - prediction = ''.join([x[0] for x in recs]) - box = [x[1] for x in recs] - flat_box = [point for pol in box for point in pol] - flat_box = [x for point in flat_box for x in point] - min_x, max_x = min(flat_box[::2]), max(flat_box[::2]) - min_y, max_y = min(flat_box[1::2]), max(flat_box[1::2]) - cut = ((min_x, min_y), (max_x, min_y), (max_x, max_y), (min_x, max_y)) - confidence = np.mean([x[2] for x in recs]) - return (prediction, cut, confidence) - elif isinstance(key, int): - return self._get_raw_item(key) - else: - raise TypeError('Invalid argument type') - - def logical_order(self, base_dir: Optional[str] = None) -> 'BBoxOCRRecord': - """ - Returns the OCR record in Unicode logical order, i.e. in the order the - characters in the line would be read by a human. - - Args: - base_dir: An optional string defining the base direction (also - called paragraph direction) for the BiDi algorithm. Valid - values are 'L' or 'R'. If None is given the default - auto-resolution will be used. - """ - if self._display_order: - return self._reorder(base_dir) - else: - return self - - def display_order(self, base_dir: Optional[str] = None) -> 'BBoxOCRRecord': - """ - Returns the OCR record in Unicode display order, i.e. ordered from left - to right inside the line. - - Args: - base_dir: An optional string defining the base direction (also - called paragraph direction) for the BiDi algorithm. Valid - values are 'L' or 'R'. If None is given the default - auto-resolution will be used. - """ - if self._display_order: - return self - else: - return self._reorder(base_dir) - - def _reorder(self, base_dir: Optional[str] = None) -> 'BBoxOCRRecord': - storage = bd.get_empty_storage() - - if base_dir not in ('L', 'R'): - base_level = bd.get_base_level(self.prediction) - else: - base_level = {'L': 0, 'R': 1}[base_dir] - - storage['base_level'] = base_level - storage['base_dir'] = ('L', 'R')[base_level] - - bd.get_embedding_levels(self.prediction, storage) - bd.explicit_embed_and_overrides(storage) - bd.resolve_weak_types(storage) - bd.resolve_neutral_types(storage, False) - bd.resolve_implicit_levels(storage, False) - for i, j in enumerate(zip(self.prediction, self.cuts, self.confidences)): - storage['chars'][i]['record'] = j - bd.reorder_resolved_levels(storage, False) - bd.apply_mirroring(storage, False) - prediction = '' - cuts = [] - confidences = [] - for ch in storage['chars']: - # code point may have been mirrored - prediction = prediction + ch['ch'] - cuts.append(ch['record'][1]) - confidences.append(ch['record'][2]) - # carry over whole line information - rec = BBoxOCRRecord(prediction, cuts, confidences, self.line) - rec.base_dir = base_dir - rec._display_order = not self._display_order - return rec - - class mm_rpred(object): """ Multi-model version of kraken.rpred.rpred From 66127aaf6e1ad3b55f1cac82e9caf9e061fddea6 Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Thu, 11 May 2023 13:19:47 +0200 Subject: [PATCH 39/68] Use new container classes in XMLPage --- kraken/containers.py | 7 + kraken/lib/xml.py | 452 ++----------------------------------------- 2 files changed, 24 insertions(+), 435 deletions(-) diff --git a/kraken/containers.py b/kraken/containers.py index 5110ad659..a07848f92 100644 --- a/kraken/containers.py +++ b/kraken/containers.py @@ -5,6 +5,13 @@ from dataclasses import dataclass, asdict from abc import ABC, abstractmethod +__all__ = ['BaselineLine', + 'BBoxLine', + 'Segmentation', + 'ocr_record', + 'BaselineOCRRecord', + 'BBoxOCRRecord'] + @dataclass class BaselineLine: diff --git a/kraken/lib/xml.py b/kraken/lib/xml.py index 135c50560..ec55772e6 100644 --- a/kraken/lib/xml.py +++ b/kraken/lib/xml.py @@ -27,12 +27,14 @@ from typing import Union, Dict, Any, Sequence, Tuple, Literal, Optional, List from collections import defaultdict +from kraken.containers import BaselineLine from kraken.lib.segmentation import calculate_polygonal_environment from kraken.lib.exceptions import KrakenInputException logger = logging.getLogger(__name__) -__all__ = ['XMLPage', 'parse_xml', 'parse_page', 'parse_alto', 'preparse_xml_data'] +__all__ = ['XMLPage'] + # fallback mapping between PAGE region types and tags page_regions = {'TextRegion': 'text', @@ -58,427 +60,6 @@ 'ComposedBlock': 'composed'} -def preparse_xml_data(filenames: Sequence[Union[str, PathLike]], - format_type: str = 'xml', - repolygonize: bool = False) -> Dict[str, Any]: - """ - Loads training data from a set of xml files. - - Extracts line information from Page/ALTO xml files for training of - recognition models. - - Args: - filenames: List of XML files. - format_type: Either `page`, `alto` or `xml` for autodetermination. - repolygonize: (Re-)calculates polygon information using the kraken - algorithm. - - Returns: - A list of dicts {'text': text, 'baseline': [[x0, y0], ...], 'boundary': - [[x0, y0], ...], 'image': PIL.Image}. - """ - training_pairs = [] - if format_type == 'xml': - parse_fn = parse_xml - elif format_type == 'alto': - parse_fn = parse_alto - elif format_type == 'page': - parse_fn = parse_page - else: - raise ValueError(f'invalid format {format_type} for preparse_xml_data') - - for fn in filenames: - try: - data = parse_fn(fn) - except ValueError as e: - logger.warning(e) - continue - try: - with open(data['image'], 'rb') as fp: - Image.open(fp) - except FileNotFoundError as e: - logger.warning(f'Could not open file {e.filename} in {fn}') - continue - if repolygonize: - logger.info('repolygonizing {} lines in {}'.format(len(data['lines']), data['image'])) - data['lines'] = _repolygonize(data['image'], data['lines']) - for line in data['lines']: - training_pairs.append({'image': data['image'], **line}) - return training_pairs - - -def _repolygonize(im: Image.Image, lines: Sequence[Dict[str, Any]]): - """ - Helper function taking an output of the lib.xml parse_* functions and - recalculating the contained polygonization. - - Args: - im (Image.Image): Input image - lines (list): List of dicts [{'boundary': [[x0, y0], ...], 'baseline': [[x0, y0], ...], 'text': 'abcvsd'}, {...] - - Returns: - A data structure `lines` with a changed polygonization. - """ - im = Image.open(im).convert('L') - polygons = calculate_polygonal_environment(im, [x['baseline'] for x in lines]) - return [{'boundary': polygon, - 'baseline': orig['baseline'], - 'text': orig['text'], - 'script': orig['script']} for orig, polygon in zip(lines, polygons)] - - -def parse_xml(filename: Union[str, PathLike]) -> Dict[str, Any]: - """ - Parses either a PageXML or ALTO file with autodetermination of the file - format. - - Args: - filename: path to an XML file. - - Returns: - A dict:: - - {'image': impath, - 'lines': [{'boundary': [[x0, y0], ...], - 'baseline': [[x0, y0], ...], - 'text': apdjfqpf', - 'tags': {'type': 'default', ...}}, - ... - {...}], - 'regions': {'region_type_0': [[[x0, y0], ...], ...], ...}} - """ - with open(filename, 'rb') as fp: - try: - doc = etree.parse(fp) - except etree.XMLSyntaxError as e: - raise KrakenInputException(f'Parsing {filename} failed: {e}') - if doc.getroot().tag.endswith('alto'): - return parse_alto(filename) - elif doc.getroot().tag.endswith('PcGts'): - return parse_page(filename) - else: - raise KrakenInputException(f'Unknown XML format in {filename}') - - -def parse_page(filename: Union[str, PathLike]) -> Dict[str, Any]: - """ - Parses a PageXML file, returns the baselines defined in it, and loads the - referenced image. - - Args: - filename: path to a PageXML file. - - Returns: - A dict:: - - {'image': impath, - 'lines': [{'boundary': [[x0, y0], ...], - 'baseline': [[x0, y0], ...], - 'text': apdjfqpf', - 'tags': {'type': 'default', ...}}, - ... - {...}], - 'regions': {'region_type_0': [[[x0, y0], ...], ...], ...}} - """ - def _parse_page_custom(s): - o = {} - s = s.strip() - l_chunks = [l_chunk for l_chunk in s.split('}') if l_chunk.strip()] - if l_chunks: - for chunk in l_chunks: - tag, vals = chunk.split('{') - tag_vals = {} - vals = [val.strip() for val in vals.split(';') if val.strip()] - for val in vals: - key, *val = val.split(':') - tag_vals[key] = ":".join(val) - o[tag.strip()] = tag_vals - return o - - def _parse_coords(coords): - points = [x for x in coords.split(' ')] - points = [int(c) for point in points for c in point.split(',')] - pts = zip(points[::2], points[1::2]) - return [k for k, g in groupby(pts)] - - with open(filename, 'rb') as fp: - base_dir = Path(filename).parent - try: - doc = etree.parse(fp) - except etree.XMLSyntaxError as e: - raise KrakenInputException('Parsing {} failed: {}'.format(filename, e)) - image = doc.find('.//{*}Page') - if image is None or image.get('imageFilename') is None: - raise KrakenInputException('No valid image filename found in PageXML file {}'.format(filename)) - try: - base_direction = {'left-to-right': 'L', - 'right-to-left': 'R', - 'top-to-bottom': 'L', - 'bottom-to-top': 'R', - None: None}[image.get('readingDirection')] - except KeyError: - logger.warning(f'Invalid value {image.get("readingDirection")} encountered in page-level reading direction.') - base_direction = None - lines = doc.findall('.//{*}TextLine') - data = {'image': base_dir.joinpath(image.get('imageFilename')), - 'lines': [], - 'type': 'baselines', - 'base_dir': base_direction, - 'regions': {}} - # find all image regions - regions = [] - for x in page_regions.keys(): - regions.extend(doc.findall('.//{{*}}{}'.format(x))) - # parse region type and coords - region_data = defaultdict(list) - for region in regions: - coords = region.find('{*}Coords') - if coords is not None and not coords.get('points').isspace() and len(coords.get('points')): - try: - coords = _parse_coords(coords.get('points')) - except Exception: - logger.warning('Region {} without coordinates'.format(region.get('id'))) - continue - else: - logger.warning('Region {} without coordinates'.format(region.get('id'))) - continue - rtype = region.get('type') - # parse transkribus-style custom field if possible - custom_str = region.get('custom') - if not rtype and custom_str: - cs = _parse_page_custom(custom_str) - if 'structure' in cs and 'type' in cs['structure']: - rtype = cs['structure']['type'] - # fall back to default region type if nothing is given - if not rtype: - rtype = page_regions[region.tag.split('}')[-1]] - region_data[rtype].append(coords) - - data['regions'] = region_data - - # parse line information - tag_set = set(('default',)) - for line in lines: - pol = line.find('./{*}Coords') - boundary = None - if pol is not None and not pol.get('points').isspace() and len(pol.get('points')): - try: - boundary = _parse_coords(pol.get('points')) - except Exception: - logger.info('TextLine {} without polygon'.format(line.get('id'))) - else: - logger.info('TextLine {} without polygon'.format(line.get('id'))) - base = line.find('./{*}Baseline') - baseline = None - if base is not None and not base.get('points').isspace() and len(base.get('points')): - try: - baseline = _parse_coords(base.get('points')) - except Exception: - logger.info('TextLine {} without baseline'.format(line.get('id'))) - continue - else: - logger.info('TextLine {} without baseline'.format(line.get('id'))) - continue - text = '' - manual_transcription = line.find('./{*}TextEquiv') - if manual_transcription is not None: - transcription = manual_transcription - else: - transcription = line - for el in transcription.findall('.//{*}Unicode'): - if el.text: - text += el.text - # retrieve line tags if custom string is set and contains - tags = {'type': 'default'} - split_type = None - custom_str = line.get('custom') - if custom_str: - cs = _parse_page_custom(custom_str) - if 'structure' in cs and 'type' in cs['structure']: - tags['type'] = cs['structure']['type'] - tag_set.add(tags['type']) - # retrieve data split if encoded in custom string. - if 'split' in cs and 'type' in cs['split'] and cs['split']['type'] in ['train', 'validation', 'test']: - split_type = cs['split']['type'] - tags['split'] = split_type - tag_set.add(split_type) - - data['lines'].append({'baseline': baseline, - 'boundary': boundary, - 'text': text, - 'split': split_type, - 'tags': tags}) - if len(tag_set) > 1: - data['script_detection'] = True - else: - data['script_detection'] = False - return data - - -def parse_alto(filename: Union[str, PathLike]) -> Dict[str, Any]: - """ - Parses an ALTO file, returns the baselines defined in it, and loads the - referenced image. - - Args: - filename: path to an ALTO file. - - Returns: - A dict:: - - {'image': impath, - 'lines': [{'boundary': [[x0, y0], ...], - 'baseline': [[x0, y0], ...], - 'text': apdjfqpf', - 'tags': {'type': 'default', ...}}, - ... - {...}], - 'regions': {'region_type_0': [[[x0, y0], ...], ...], ...}} - """ - def _parse_pointstype(coords: str) -> Sequence[Tuple[float, float]]: - """ - ALTO's PointsType is underspecified so a variety of serializations are valid: - - x0, y0 x1, y1 ... - x0 y0 x1 y1 ... - (x0, y0) (x1, y1) ... - (x0 y0) (x1 y1) ... - - Returns: - A list of tuples [(x0, y0), (x1, y1), ...] - """ - float_re = re.compile(r'[-+]?(\d+(\.\d*)?|\.\d+)([eE][-+]?\d+)?') - points = [float(point.group()) for point in float_re.finditer(coords)] - if len(points) % 2: - raise ValueError(f'Odd number of points in points sequence: {points}') - pts = zip(points[::2], points[1::2]) - return [k for k, g in groupby(pts)] - - with open(filename, 'rb') as fp: - base_dir = Path(filename).parent - try: - doc = etree.parse(fp) - except etree.XMLSyntaxError as e: - raise KrakenInputException('Parsing {} failed: {}'.format(filename, e)) - image = doc.find('.//{*}fileName') - if image is None or not image.text: - raise KrakenInputException('No valid filename found in ALTO file') - lines = doc.findall('.//{*}TextLine') - data = {'image': base_dir.joinpath(image.text), - 'lines': [], - 'type': 'baselines', - 'base_dir': None, - 'regions': {}} - # find all image regions - regions = [] - for x in alto_regions.keys(): - regions.extend(doc.findall('./{{*}}Layout/{{*}}Page/{{*}}PrintSpace/{{*}}{}'.format(x))) - # find overall dimensions to filter out dummy TextBlocks - ps = doc.find('./{*}Layout/{*}Page/{*}PrintSpace') - x_min = int(float(ps.get('HPOS'))) - y_min = int(float(ps.get('VPOS'))) - width = int(float(ps.get('WIDTH'))) - height = int(float(ps.get('HEIGHT'))) - page_boundary = [(x_min, y_min), - (x_min, y_min + height), - (x_min + width, y_min + height), - (x_min + width, y_min)] - - # parse tagrefs - cls_map = {} - tags = doc.find('.//{*}Tags') - if tags is not None: - for x in ['StructureTag', 'LayoutTag', 'OtherTag']: - for tag in tags.findall('./{{*}}{}'.format(x)): - cls_map[tag.get('ID')] = (x[:-3].lower(), tag.get('LABEL')) - # parse region type and coords - region_data = defaultdict(list) - for region in regions: - # try to find shape object - coords = region.find('./{*}Shape/{*}Polygon') - if coords is not None: - boundary = _parse_pointstype(coords.get('POINTS')) - elif (region.get('HPOS') is not None and region.get('VPOS') is not None and - region.get('WIDTH') is not None and region.get('HEIGHT') is not None): - # use rectangular definition - x_min = int(float(region.get('HPOS'))) - y_min = int(float(region.get('VPOS'))) - width = int(float(region.get('WIDTH'))) - height = int(float(region.get('HEIGHT'))) - boundary = [(x_min, y_min), - (x_min, y_min + height), - (x_min + width, y_min + height), - (x_min + width, y_min)] - else: - continue - rtype = region.get('TYPE') - # fall back to default region type if nothing is given - tagrefs = region.get('TAGREFS') - if tagrefs is not None and rtype is None: - for tagref in tagrefs.split(): - ttype, rtype = cls_map.get(tagref, (None, None)) - if rtype is not None and ttype: - break - if rtype is None: - rtype = alto_regions[region.tag.split('}')[-1]] - if boundary == page_boundary and rtype == 'text': - logger.info('Skipping TextBlock with same size as page image.') - continue - region_data[rtype].append({'id': region.get('ID'), 'boundary': boundary}) - data['regions'] = region_data - - tag_set = set(('default',)) - for line in lines: - if line.get('BASELINE') is None: - logger.info('TextLine {} without baseline'.format(line.get('ID'))) - continue - pol = line.find('./{*}Shape/{*}Polygon') - boundary = None - if pol is not None: - try: - boundary = _parse_pointstype(pol.get('POINTS')) - except ValueError: - logger.info('TextLine {} without polygon'.format(line.get('ID'))) - else: - logger.info('TextLine {} without polygon'.format(line.get('ID'))) - - baseline = None - try: - baseline = _parse_pointstype(line.get('BASELINE')) - except ValueError: - logger.info('TextLine {} without baseline'.format(line.get('ID'))) - - text = '' - for el in line.xpath(".//*[local-name() = 'String'] | .//*[local-name() = 'SP']"): - text += el.get('CONTENT') if el.get('CONTENT') else ' ' - # find line type - tags = {'type': 'default'} - split_type = None - tagrefs = line.get('TAGREFS') - if tagrefs is not None: - for tagref in tagrefs.split(): - ttype, ltype = cls_map.get(tagref, (None, None)) - if ltype is not None: - tag_set.add(ltype) - if ttype == 'other': - tags['type'] = ltype - else: - tags[ttype] = ltype - if ltype in ['train', 'validation', 'test']: - split_type = ltype - data['lines'].append({'baseline': baseline, - 'boundary': boundary, - 'text': text, - 'tags': tags, - 'split': split_type}) - - if len(tag_set) > 1: - data['tags'] = True - else: - data['tags'] = False - return data - - class XMLPage(object): type: Literal['baselines', 'bbox'] = 'baselines' @@ -629,12 +210,13 @@ def _parse_alto(self): tags[ttype] = ltype if ltype in ['train', 'validation', 'test']: split_type = ltype - self._lines[line.get('ID')] = {'baseline': baseline, - 'boundary': boundary, - 'text': text, - 'tags': tags, - 'split': split_type, - 'region': region_id} + self._lines[line.get('ID')] = BaselineLine(id=line.get('ID'), + baseline=baseline, + boundary=boundary, + text=text, + tags=tags, + split=split_type, + region=region_id) # register implicit reading order self._orders['line_implicit']['order'].append(line.get('ID')) @@ -804,13 +386,13 @@ def _parse_page(self): else: tmp_transkribus_line_order[int(reg_cus['readingOrder']['index'])].append((int(cs['readingOrder']['index']), line.get('id'))) - self._lines[line.get('id')] = {'baseline': baseline, - 'boundary': boundary, - 'text': text, - 'split': split_type, - 'tags': tags, - 'region': region.get('id')} - + self._lines[line.get('id')] = BaselineLine(id=line.get('id'), + baseline=baseline, + boundary=boundary, + text=text, + tags=tags, + split=split_type, + region=region.get('id')) # register implicit reading order self._orders['line_implicit']['order'].append(line.get('id')) From 9a787019fffa3eedbbb30cf6dd3250b097af0507 Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Tue, 16 May 2023 13:51:13 +0200 Subject: [PATCH 40/68] Use new container classes in blla.segment --- kraken/blla.py | 46 +++++++++++++++++++++++++++----------- kraken/containers.py | 30 +++++++++++++++++++++---- kraken/lib/segmentation.py | 16 ++----------- 3 files changed, 61 insertions(+), 31 deletions(-) diff --git a/kraken/blla.py b/kraken/blla.py index 767ec6867..442b34079 100644 --- a/kraken/blla.py +++ b/kraken/blla.py @@ -21,6 +21,7 @@ """ import PIL +import uuid import torch import logging import numpy as np @@ -37,8 +38,7 @@ from kraken.lib import vgsl, dataset from kraken.lib.util import is_bitonal, get_im_str from kraken.lib.exceptions import KrakenInputException, KrakenInvalidModelException -from kraken.lib.segmentation import (Segmentation, - polygonal_reading_order, +from kraken.lib.segmentation import (polygonal_reading_order, neural_reading_order, vectorize_lines, vectorize_regions, scale_polygonal_lines, @@ -134,7 +134,7 @@ def compute_segmentation_map(im: PIL.Image.Image, 'scal_im': scal_im} -def vec_regions(heatmap: torch.Tensor, cls_map: Dict, scale: float, **kwargs) -> Dict[str, List[List[Tuple[int, int]]]]: +def vec_regions(heatmap: torch.Tensor, cls_map: Dict, scale: float, **kwargs) -> Dict[str, List[Region]]: """ Computes regions from a stack of heatmaps, a class mapping, and scaling factor. @@ -154,8 +154,8 @@ def vec_regions(heatmap: torch.Tensor, cls_map: Dict, scale: float, **kwargs) -> for region_type, idx in cls_map['regions'].items(): logger.debug(f'Vectorizing regions of type {region_type}') regions[region_type] = vectorize_regions(heatmap[idx]) - for reg_id, regs in regions.items(): - regions[reg_id] = scale_regions(regs, scale) + for reg_type, regs in regions.items(): + regions[reg_type] = [Region(id=uuid.uuid4(), boundary=x, tags={'type': reg_type}) for x in scale_regions(regs, scale)] return regions @@ -218,6 +218,7 @@ def vec_lines(heatmap: torch.Tensor, lines = [] reg_pols = [geom.Polygon(x) for x in regions] + line_regs = [] for bl_idx in range(len(baselines)): bl = baselines[bl_idx] mid_point = geom.LineString(bl[1]).interpolate(0.5, normalized=True) @@ -226,7 +227,6 @@ def vec_lines(heatmap: torch.Tensor, for reg_idx, reg_pol in enumerate(reg_pols): if reg_pol.contains(mid_point): suppl_obj.append(regions[reg_idx]) - pol = calculate_polygonal_environment( baselines=[bl[1]], im_feats=im_feats, @@ -239,7 +239,7 @@ def vec_lines(heatmap: torch.Tensor, logger.debug('Scaling vectorized lines') sc = scale_polygonal_lines([x[1:] for x in lines], scale) - lines = list(zip([x[0] for x in lines], [x[0] for x in sc], [x[1] for x in sc])) + lines = list(zip([x[0] for x in lines], [x[0] for x in sc], [x[1] for x in sc], line_regs)) return [{'tags': {'type': bl_type}, 'baseline': bl, 'boundary': pl} for bl_type, bl, pl in lines] @@ -250,7 +250,7 @@ def segment(im: PIL.Image.Image, model: Union[List[vgsl.TorchVGSLModel], vgsl.TorchVGSLModel] = None, device: str = 'cpu', raise_on_error: bool = False, - autocast: bool = False) -> Dict[str, Any]: + autocast: bool = False) -> Segmentation: r""" Segments a page into text lines using the baseline segmenter. @@ -337,6 +337,10 @@ def segment(im: PIL.Image.Image, logger.debug(f'Baseline location: {loc}') rets = compute_segmentation_map(im, mask, net, device, autocast=autocast) _regions = vec_regions(**rets) + for reg_key, reg_val in vec_regions(**rets).items(): + if reg_key not in regions: + regions[reg_key] = [] + regions[reg_key].extend(reg_val) # flatten regions for line ordering/fetch bounding regions line_regs = [] @@ -346,8 +350,8 @@ def segment(im: PIL.Image.Image, if rets['bounding_regions'] is not None and cls in rets['bounding_regions']: suppl_obj.extend(regs) # convert back to net scale - suppl_obj = scale_regions(suppl_obj, 1/rets['scale']) - line_regs = scale_regions(line_regs, 1/rets['scale']) + suppl_obj = scale_regions([x.boundary for x in suppl_obj], 1/rets['scale']) + line_regs = scale_regions([x.boundary for x in line_regs], 1/rets['scale']) _lines = vec_lines(**rets, regions=line_regs, @@ -359,7 +363,7 @@ def segment(im: PIL.Image.Image, if 'ro_model' in net.aux_layers: logger.info(f'Using reading order model found in segmentation model {net}.') _order = neural_reading_order(lines=_lines, - regions=regions, + regions=_regions, text_direction=text_direction[-2:], model=net.aux_layers['ro_model'], im_size=im.size, @@ -388,9 +392,25 @@ def segment(im: PIL.Image.Image, else: script_detection = False + # create objects and assign IDs + blls = [] + reg_idx = 0 + _shp_regs = {} + for reg_type, rgs in regions.items(): + for reg in rgs: + _shp_regs[reg.id] = geom.Polygon(reg.boundary) + + for idx, line in enumerate(lines): + line_regs = [] + for reg_id, reg in _shp_regs.items(): + mid_point = geom.LineString(line[1]).interpolate(0.5, normalized=True) + if reg.contains(mid_point): + line_regs.append(reg_id) + blls.append(BaselineLine(id=f'line_{idx}', baseline=line[1], boundary=line[2], tags={'type': line[0]}, regions=line_regs)) + return Segmentation(text_direction=text_direction, type='baselines', - lines=lines, - regions=regions, + lines=blls, + regions=_regs, script_detection=script_detection, line_orders=[order]) diff --git a/kraken/containers.py b/kraken/containers.py index a07848f92..3c12f2f0b 100644 --- a/kraken/containers.py +++ b/kraken/containers.py @@ -24,7 +24,9 @@ class BaselineLine: base_dir: Optional[Literal['L', 'R']] = None type: str = 'baselines' image: Optional[PIL.Image.Image] = None - + tags: Optional[Dict[str, str]] = None + split: Optional[Literal['train', 'validation', 'test'] = None + regions: Optional[List[str]] = None @dataclass class BBoxLine: @@ -39,6 +41,20 @@ class BBoxLine: base_dir: Optional[Literal['L', 'R']] = None type: str = 'bbox' image: Optional[PIL.Image.Image] = None + tags: Optional[Dict[str, str]] = None + split: Optional[Literal['train', 'validation', 'test'] = None + regions: Optional[List[str]] = None + + +@dataclass +class Region: + """ + + """ + id: str + boundary: List[Tuple[int, int]] + image: Optional[PIL.Image.Image] = None + tags: Optional[Dict[str, str]] = None @dataclass @@ -51,7 +67,7 @@ class Segmentation: text_direction: Literal['horizontal-lr', 'horizontal-rl', 'vertical-lr', 'vertical-rl'] script_detection: bool lines: Sequence[Union[BaselineLine, BBoxLine]] - regions: Dict[str, List] + regions: Dict[str, List[Region]] line_orders: List[List[int]] @@ -276,7 +292,10 @@ def _reorder(self, base_dir: Optional[str] = None) -> 'BaselineOCRRecord': boundary=self.boundary, text=self.text, base_dir=self._line_base_dir, - image=self.image) + image=self.image, + tags=self.tags, + split=self.split, + region=self.region) rec = BaselineOCRRecord(prediction=prediction, cuts=cuts, confidences=confidences, @@ -412,7 +431,10 @@ def _reorder(self, base_dir: Optional[Literal['L', 'R']] = None) -> 'BBoxOCRReco bbox=self.bbox, text=self.text, base_dir=self._line_base_dir, - image=self.image) + image=self.image, + tags=self.tags, + split=self.split, + region=self.region) rec = BBoxOCRRecord(prediction=prediction, cuts=cuts, confidences=confidences, diff --git a/kraken/lib/segmentation.py b/kraken/lib/segmentation.py index d821d9ce6..8f6b9ccff 100644 --- a/kraken/lib/segmentation.py +++ b/kraken/lib/segmentation.py @@ -22,7 +22,6 @@ import shapely.geometry as geom import torch.nn.functional as F -from dataclasses import dataclass from collections import defaultdict from PIL import Image @@ -43,6 +42,7 @@ from typing import List, Tuple, Union, Dict, Any, Sequence, Optional, Literal +from kraken.containers import Segmentation, BaselineLine, BBoxLine from kraken.lib import default_specs from kraken.lib.exceptions import KrakenInputException @@ -60,19 +60,7 @@ 'scale_polygonal_lines', 'scale_regions', 'compute_polygon_section', - 'extract_polygons', - 'Segmentation'] - - -@dataclass -class Segmentation: - type: Literal['baselines', 'bbox'] - imagename: str - text_direction: Literal['horizontal-lr', 'horizontal-rl', 'vertical-lr', 'vertical-rl'] - script_detection: bool - lines: List - regions: Dict[str, List] - line_orders: List[List[int]] + 'extract_polygons'] def reading_order(lines: Sequence[Tuple[slice, slice]], text_direction: Literal['lr', 'rl'] = 'lr') -> np.ndarray: From 83b17fae9f40ec12692f5f39533dcf3f8b742c0d Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Tue, 16 May 2023 13:55:51 +0200 Subject: [PATCH 41/68] BBoxLine/Segmentation in legacy segmenter --- kraken/pageseg.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/kraken/pageseg.py b/kraken/pageseg.py index 355af4f5d..8bf96b64c 100644 --- a/kraken/pageseg.py +++ b/kraken/pageseg.py @@ -29,7 +29,7 @@ from kraken.lib import morph, sl from kraken.lib.util import pil2array, is_bitonal, get_im_str from kraken.lib.exceptions import KrakenInputException -from kraken.lib.segmentation import reading_order, topsort, Segmentation +from kraken.lib.segmentation import reading_order, topsort, Segmentation, BBoxLine __all__ = ['segment'] @@ -423,10 +423,11 @@ def segment(im, if isinstance(pad, int): pad = (pad, pad) lines = [(max(x[0]-pad[0], 0), x[1], min(x[2]+pad[1], im.size[0]), x[3]) for x in lines] + lines = [BBoxLine(id=uuid.uuid4(), bbox=line) for line in rotate_lines(lines, 360-angle, offset).tolist()] return Segmentation(text_direction=text_direction, type='bbox', regions=None, line_orders=None, - lines=rotate_lines(lines, 360-angle, offset).tolist(), + lines=lines, script_detection=False) From c390e68c6f4455d739e3895f379aa1645cad236f Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Tue, 16 May 2023 13:56:12 +0200 Subject: [PATCH 42/68] UUIDs in blla lines --- kraken/blla.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/kraken/blla.py b/kraken/blla.py index 442b34079..84ce62086 100644 --- a/kraken/blla.py +++ b/kraken/blla.py @@ -400,13 +400,13 @@ def segment(im: PIL.Image.Image, for reg in rgs: _shp_regs[reg.id] = geom.Polygon(reg.boundary) - for idx, line in enumerate(lines): + for line in lines: line_regs = [] for reg_id, reg in _shp_regs.items(): mid_point = geom.LineString(line[1]).interpolate(0.5, normalized=True) if reg.contains(mid_point): line_regs.append(reg_id) - blls.append(BaselineLine(id=f'line_{idx}', baseline=line[1], boundary=line[2], tags={'type': line[0]}, regions=line_regs)) + blls.append(BaselineLine(id=uuid.uuid4(), baseline=line[1], boundary=line[2], tags={'type': line[0]}, regions=line_regs)) return Segmentation(text_direction=text_direction, type='baselines', From b9753de1dd181b493c810f229cb1823b493fdd60 Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Tue, 16 May 2023 14:48:56 +0200 Subject: [PATCH 43/68] Segmentation/BBoxLine/BaselineLine containers in rpred --- kraken/rpred.py | 171 +++++++++++++++++++++++------------------------- 1 file changed, 83 insertions(+), 88 deletions(-) diff --git a/kraken/rpred.py b/kraken/rpred.py index 40bcba765..3354b56b6 100644 --- a/kraken/rpred.py +++ b/kraken/rpred.py @@ -109,14 +109,13 @@ def __init__(self, if bounds.type == 'baselines': valid_norm = False self.next_iter = self._recognize_baseline_line - tags = set() - for x in bounds.lines: - tags.update(x['tags'].values()) else: valid_norm = True - self.seg_key = 'boxes' self.next_iter = self._recognize_box_line - tags = set(x[0] for line in bounds.lines for x in line) + + tags = set() + for x in bounds.lines: + tags.update(x.tags.values()) im_str = get_im_str(im) logger.info(f'Running {len(nets)} multi-script recognizers on {im_str} with {self.len} lines') @@ -149,70 +148,75 @@ def __init__(self, self.tags_ignore = tags_ignore def _recognize_box_line(self, line): - flat_box = [point for box in line['boxes'][0] for point in box[1]] + flat_box = [point for box in line.bbox for point in box] xmin, xmax = min(flat_box[::2]), max(flat_box[::2]) ymin, ymax = min(flat_box[1::2]), max(flat_box[1::2]) line_bbox = ((xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin)) prediction = '' cuts = [] confidences = [] - for tag, (box, coords) in zip(map(lambda x: x[0], line['boxes'][0]), - extract_polygons(self.im, {'text_direction': line['text_direction'], - 'boxes': map(lambda x: x[1], line['boxes'][0])})): - self.box = box - # skip if tag is set to ignore - if self.tags_ignore is not None and tag in self.tags_ignore: - logger.warning(f'Ignoring {tag} line segment.') - continue - # check if boxes are non-zero in any dimension - if 0 in box.size: - logger.warning(f'bbox {coords} with zero dimension. Emitting empty record.') - return BBoxOCRRecord('', (), (), coords) - # try conversion into tensor - try: - logger.debug('Preparing run.') - line = self.ts[tag](box) - except Exception: - logger.warning(f'Conversion of line {coords} failed. Emitting empty record..') - return BBoxOCRRecord('', (), (), coords) - - # check if line is non-zero - if line.max() == line.min(): - logger.warning('Empty run. Emitting empty record.') - return BBoxOCRRecord('', (), (), coords) - - _, net = self._resolve_tags_to_model({'type': tag}, self.nets) - - logger.debug(f'Forward pass with model {tag}.') - preds = net.predict(line.unsqueeze(0))[0] - - # calculate recognized LSTM locations of characters - logger.debug('Convert to absolute coordinates') - # calculate recognized LSTM locations of characters - # scale between network output and network input - self.net_scale = line.shape[2]/net.outputs.shape[2] - # scale between network input and original line - self.in_scale = box.size[0]/(line.shape[2]-2*self.pad) - - pred = ''.join(x[0] for x in preds) - pos = [] - conf = [] - - for _, start, end, c in preds: - if self.bounds.text_direction.startswith('horizontal'): - xmin = coords[0] + self._scale_val(start, 0, self.box.size[0]) - xmax = coords[0] + self._scale_val(end, 0, self.box.size[0]) - pos.append([[xmin, coords[1]], [xmin, coords[3]], [xmax, coords[3]], [xmax, coords[1]]]) - else: - ymin = coords[1] + self._scale_val(start, 0, self.box.size[1]) - ymax = coords[1] + self._scale_val(end, 0, self.box.size[1]) - pos.append([[coords[0], ymin], [coords[2], ymin], [coords[2], ymax], [coords[0], ymax]]) - conf.append(c) - prediction += pred - cuts.extend(pos) - confidences.extend(conf) - - rec = BBoxOCRRecord(prediction, cuts, confidences, line_bbox) + line.text_direction = self.bounds.text_direction + + if self.tags_ignore is not None: + for tag in line.tags.values(): + if tag in self.tags_ignore: + logger.info(f'Ignoring line segment with tags {line.tags} based on {tag}.') + return BaselineOCRRecord('', [], [], line) + + tag, net = self._resolve_tags_to_model(line.tags, self.nets) + + box, coords = next(extract_polygons(self.im, line)) + self.box = box + + # check if boxes are non-zero in any dimension + if 0 in box.size: + logger.warning(f'bbox {line} with zero dimension. Emitting empty record.') + return BBoxOCRRecord('', (), (), line) + # try conversion into tensor + try: + logger.debug('Preparing run.') + ts_box = self.ts[tag](box) + except Exception: + logger.warning(f'Conversion of line {line} failed. Emitting empty record..') + return BBoxOCRRecord('', (), (), line) + + # check if line is non-zero + if ts_box.max() == ts_box.min(): + logger.warning('Empty run. Emitting empty record.') + return BBoxOCRRecord('', (), (), line) + + _, net = self._resolve_tags_to_model({'type': tag}, self.nets) + + logger.debug(f'Forward pass with model {tag}.') + preds = net.predict(ts_box.unsqueeze(0))[0] + + # calculate recognized LSTM locations of characters + logger.debug('Convert to absolute coordinates') + # calculate recognized LSTM locations of characters + # scale between network output and network input + self.net_scale = ts_box.shape[2]/net.outputs.shape[2] + # scale between network input and original line + self.in_scale = box.size[0]/(ts_box.shape[2]-2*self.pad) + + pred = ''.join(x[0] for x in preds) + pos = [] + conf = [] + + for _, start, end, c in preds: + if self.bounds.text_direction.startswith('horizontal'): + xmin = coords[0] + self._scale_val(start, 0, self.box.size[0]) + xmax = coords[0] + self._scale_val(end, 0, self.box.size[0]) + pos.append([[xmin, coords[1]], [xmin, coords[3]], [xmax, coords[3]], [xmax, coords[1]]]) + else: + ymin = coords[1] + self._scale_val(start, 0, self.box.size[1]) + ymax = coords[1] + self._scale_val(end, 0, self.box.size[1]) + pos.append([[coords[0], ymin], [coords[2], ymin], [coords[2], ymax], [coords[0], ymax]]) + conf.append(c) + prediction += pred + cuts.extend(pos) + confidences.extend(conf) + + rec = BBoxOCRRecord(prediction, cuts, confidences, line) if self.bidi_reordering: logger.debug('BiDi reordering record.') return rec.logical_order(base_dir=self.bidi_reordering if self.bidi_reordering in ('L', 'R') else None) @@ -222,41 +226,41 @@ def _recognize_box_line(self, line): def _recognize_baseline_line(self, line): if self.tags_ignore is not None: - for tag in line['lines'][0]['tags'].values(): + for tag in line.tags.values(): if tag in self.tags_ignore: - logger.info(f'Ignoring line segment with tags {line["lines"][0]["tags"]} based on {tag}.') - return BaselineOCRRecord('', [], [], line['lines'][0]) + logger.info(f'Ignoring line segment with tags {line.tags} based on {tag}.') + return BaselineOCRRecord('', [], [], line) try: box, coords = next(extract_polygons(self.im, line)) except KrakenInputException as e: logger.warning(f'Extracting line failed: {e}') - return BaselineOCRRecord('', [], [], line['lines'][0]) + return BaselineOCRRecord('', [], [], line) self.box = box - tag, net = self._resolve_tags_to_model(coords['tags'], self.nets) + tag, net = self._resolve_tags_to_model(line.tags, self.nets) # check if boxes are non-zero in any dimension if 0 in box.size: - logger.warning(f'bbox {coords} with zero dimension. Emitting empty record.') - return BaselineOCRRecord('', [], [], coords) + logger.warning(f'{line} with zero dimension. Emitting empty record.') + return BaselineOCRRecord('', [], [], line) # try conversion into tensor try: - line = self.ts[tag](box) + ts_box = self.ts[tag](box) except Exception as e: logger.warning(f'Tensor conversion failed with {e}. Emitting empty record.') - return BaselineOCRRecord('', [], [], coords) + return BaselineOCRRecord('', [], [], line) # check if line is non-zero - if line.max() == line.min(): + if ts_box.max() == ts_box.min(): logger.warning('Empty line after tensor conversion. Emitting empty record.') - return BaselineOCRRecord('', [], [], coords) + return BaselineOCRRecord('', [], [], line) - preds = net.predict(line.unsqueeze(0))[0] + preds = net.predict(ts_box.unsqueeze(0))[0] # calculate recognized LSTM locations of characters # scale between network output and network input - self.net_scale = line.shape[2]/net.outputs.shape[2] + self.net_scale = ts_box.shape[2]/net.outputs.shape[2] # scale between network input and original line - self.in_scale = box.size[0]/(line.shape[2]-2*self.pad) + self.in_scale = box.size[0]/(ts_box.shape[2]-2*self.pad) # XXX: fix bounding box calculation ocr_record for multi-codepoint labels. pred = ''.join(x[0] for x in preds) @@ -266,7 +270,7 @@ def _recognize_baseline_line(self, line): pos.append((self._scale_val(start, 0, self.box.size[0]), self._scale_val(end, 0, self.box.size[0]))) conf.append(c) - rec = BaselineOCRRecord(pred, pos, conf, coords) + rec = BaselineOCRRecord(pred, pos, conf, line) if self.bidi_reordering: logger.debug('BiDi reordering record.') return rec.logical_order(base_dir=self.bidi_reordering if self.bidi_reordering in ('L', 'R') else None) @@ -276,8 +280,7 @@ def _recognize_baseline_line(self, line): def __next__(self): bound = self.bounds - setattr(bound, self.seg_key, [next(self.line_iter)]) - return self.next_iter(bound) + return self.next_iter(next(self.line_iter)) def __iter__(self): return self @@ -312,14 +315,6 @@ def rpred(network: TorchSeqRecognizer, An ocr_record containing the recognized text, absolute character positions, and confidence values for each character. """ - bounds = copy.deepcopy(bounds) - if bounds.type == 'bbox': - boxes = bounds.lines - rewrite_boxes = [] - for box in boxes: - rewrite_boxes.append([('default', box)]) - bounds.lines = rewrite_boxes - bounds.script_detection = True return mm_rpred(defaultdict(lambda: network), im, bounds, pad, bidi_reordering) From 11a3a199f98471692adbe5c1b9bfe8ada60973e0 Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Tue, 16 May 2023 16:39:21 +0200 Subject: [PATCH 44/68] Add to_container() method to XMLPage --- kraken/containers.py | 19 ++++++++++--------- kraken/lib/xml.py | 14 +++++++++++++- 2 files changed, 23 insertions(+), 10 deletions(-) diff --git a/kraken/containers.py b/kraken/containers.py index 3c12f2f0b..f09bd40ba 100644 --- a/kraken/containers.py +++ b/kraken/containers.py @@ -25,7 +25,7 @@ class BaselineLine: type: str = 'baselines' image: Optional[PIL.Image.Image] = None tags: Optional[Dict[str, str]] = None - split: Optional[Literal['train', 'validation', 'test'] = None + split: Optional[Literal['train', 'validation', 'test']] = None regions: Optional[List[str]] = None @dataclass @@ -42,9 +42,9 @@ class BBoxLine: type: str = 'bbox' image: Optional[PIL.Image.Image] = None tags: Optional[Dict[str, str]] = None - split: Optional[Literal['train', 'validation', 'test'] = None + split: Optional[Literal['train', 'validation', 'test']] = None regions: Optional[List[str]] = None - + text_direction: Literal['horizontal-lr', 'horizontal-rl', 'vertical-lr', 'vertical-rl'] = 'horizontal-lr' @dataclass class Region: @@ -68,7 +68,7 @@ class Segmentation: script_detection: bool lines: Sequence[Union[BaselineLine, BBoxLine]] regions: Dict[str, List[Region]] - line_orders: List[List[int]] + line_orders: Optional[List[List[int]]] = None class ocr_record(ABC): @@ -225,7 +225,7 @@ def __getitem__(self, key: Union[int, slice]): def cuts(self) -> Sequence[Tuple[int, int]]: return tuple([compute_polygon_section(self.baseline, self.line, cut[0], cut[1]) for cut in self._cuts]) - def logical_order(self, base_dir: Optional[str] = None) -> 'BaselineOCRRecord': + def logical_order(self, base_dir: Optional[Literal['L', 'R']] = None) -> 'BaselineOCRRecord': """ Returns the OCR record in Unicode logical order, i.e. in the order the characters in the line would be read by a human. @@ -241,7 +241,7 @@ def logical_order(self, base_dir: Optional[str] = None) -> 'BaselineOCRRecord': else: return self - def display_order(self, base_dir: Optional[str] = None) -> 'BaselineOCRRecord': + def display_order(self, base_dir: Optional[Literal['L', 'R']] = None) -> 'BaselineOCRRecord': """ Returns the OCR record in Unicode display order, i.e. ordered from left to right inside the line. @@ -257,7 +257,7 @@ def display_order(self, base_dir: Optional[str] = None) -> 'BaselineOCRRecord': else: return self._reorder(base_dir) - def _reorder(self, base_dir: Optional[str] = None) -> 'BaselineOCRRecord': + def _reorder(self, base_dir: Optional[Literal['L', 'R']] = None) -> 'BaselineOCRRecord': """ Reorder the record using the BiDi algorithm. """ @@ -312,14 +312,15 @@ class BBoxOCRRecord(ocr_record, BBoxLine): """ type = 'bbox' - def __init__(self, prediction: str, + def __init__(self, + prediction: str, cuts: Sequence[Tuple[Tuple[int, int], Tuple[int, int], Tuple[int, int], Tuple[int, int]]], confidences: Sequence[float], line: BBoxLine, - base_dir: Optional['L', 'R'], + base_dir: Optional[Literal['L', 'R']], display_order: bool = True) -> None: if line.type != 'bbox': raise TypeError('Invalid argument type (non-bbox line)') diff --git a/kraken/lib/xml.py b/kraken/lib/xml.py index ec55772e6..01870cd5c 100644 --- a/kraken/lib/xml.py +++ b/kraken/lib/xml.py @@ -27,7 +27,7 @@ from typing import Union, Dict, Any, Sequence, Tuple, Literal, Optional, List from collections import defaultdict -from kraken.containers import BaselineLine +from kraken.containers import Segmentation, BaselineLine from kraken.lib.segmentation import calculate_polygonal_environment from kraken.lib.exceptions import KrakenInputException @@ -584,3 +584,15 @@ def __str__(self): def __repr__(self): return f'XMLPage(filename={self.filename}, filetype={self.filetype})' + + def to_container(self) -> Segmentation: + """ + Returns a Segmentation object. + """ + return Segmentation(type='baselines', + imagename=self.imagename, + text_direction='horizontal_lr', + script_detection=True, + lines=self.get_sorted_lines(), + regions=self._regions, + line_orders=None) From 4210874e2643f107123d1b74ab9b7d93f19a1381 Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Mon, 22 May 2023 14:31:27 +0200 Subject: [PATCH 45/68] make compilation work with new container objects --- kraken/lib/arrow_dataset.py | 50 ++++++++++++++++++++++--------------- kraken/lib/segmentation.py | 6 ++--- kraken/lib/xml.py | 20 +++++++-------- 3 files changed, 43 insertions(+), 33 deletions(-) diff --git a/kraken/lib/arrow_dataset.py b/kraken/lib/arrow_dataset.py index 149c58590..4ca12d9bb 100755 --- a/kraken/lib/arrow_dataset.py +++ b/kraken/lib/arrow_dataset.py @@ -29,8 +29,8 @@ from typing import Optional, List, Union, Callable, Tuple, Dict from multiprocessing import Pool from kraken.lib import functional_im_transforms as F_t -from kraken.lib.segmentation import extract_polygons -from kraken.lib.xml import parse_xml, parse_alto, parse_page +from kraken.lib.segmentation import extract_polygons, Segmentation +from kraken.lib.xml import XMLPage from kraken.lib.util import is_bitonal, make_printable from kraken.lib.exceptions import KrakenInputException from os import extsep, PathLike @@ -43,27 +43,33 @@ def _extract_line(xml_record, skip_empty_lines: bool = True): lines = [] try: - im = Image.open(xml_record['image']) + im = Image.open(xml_record.imagename) except (FileNotFoundError, UnidentifiedImageError): return lines, None, None if is_bitonal(im): im = im.convert('1') - seg_key = 'lines' if 'lines' in xml_record else 'boxes' - recs = xml_record.pop(seg_key) + recs = xml_record.lines.values() for idx, rec in enumerate(recs): + seg = Segmentation(text_direction='horizontal-lr', + imagename=xml_record.imagename, + type=xml_record.type, + lines=[rec], + regions=None, + script_detection=False, + line_orders=None) try: - line_im, line = next(extract_polygons(im, {**xml_record, seg_key: [rec]})) + line_im, line = next(extract_polygons(im, seg)) except KrakenInputException: logger.warning(f'Invalid line {idx} in {im.filename}') continue except Exception as e: logger.warning(f'Unexpected exception {e} from line {idx} in {im.filename}') continue - if not line['text'] and skip_empty_lines: + if not line.text and skip_empty_lines: continue fp = io.BytesIO() line_im.save(fp, format='png') - lines.append({'text': line['text'], 'im': fp.getvalue()}) + lines.append({'text': line.text, 'im': fp.getvalue()}) return lines, im.mode @@ -90,6 +96,7 @@ def parse_path(path: Union[str, PathLike], gt = fp.read().strip('\n\r') if not gt and skip_empty_lines: raise KrakenInputException(f'No text for ground truth line {path}.') + return {'image': path, 'lines': [{'text': gt}]} @@ -135,12 +142,8 @@ def build_binary_dataset(files: Optional[List[Union[str, PathLike, Dict]]] = Non logger.info('Parsing XML files') extract_fn = partial(_extract_line, skip_empty_lines=skip_empty_lines) parse_fn = None - if format_type == 'xml': - parse_fn = parse_xml - elif format_type == 'alto': - parse_fn = parse_alto - elif format_type == 'page': - parse_fn = parse_page + if format_type in ['xml', 'alto', 'page']: + parse_fn = XMLPage elif format_type == 'path': if not ignore_splits: logger.warning('ignore_splits is False and format_type is path. Will not serialize splits.') @@ -163,10 +166,13 @@ def build_binary_dataset(files: Optional[List[Union[str, PathLike, Dict]]] = Non logger.warning(f'Invalid input file {doc}') continue try: - name_ext = str(data['image']).split(extsep, 1) - if name_ext[1] == 'gt.txt': - data['image'] = name_ext[0] + '.png' - with open(data['image'], 'rb') as fp: + if format_type in ['xml', 'alto', 'page']: + imagename = data.imagename + else: + name_ext = str(data['image']).split(extsep, 1) + imagename = name_ext[0] + '.png' + data['image'] = imagename + with open(imagename, 'rb') as fp: Image.open(fp) except (FileNotFoundError, UnidentifiedImageError) as e: logger.warning(f'Could not open file {e.filename} in {doc}') @@ -181,9 +187,13 @@ def build_binary_dataset(files: Optional[List[Union[str, PathLike, Dict]]] = Non alphabet = Counter() num_lines = 0 for doc in docs: - for line in doc['lines']: + if format_type in ['xml', 'alto', 'page']: + lines = doc.lines.values() + else: + lines = doc['lines'] + for line in lines: num_lines += 1 - alphabet.update(line['text']) + alphabet.update(line.text) callback(0, num_lines) diff --git a/kraken/lib/segmentation.py b/kraken/lib/segmentation.py index 8f6b9ccff..34e101cdd 100644 --- a/kraken/lib/segmentation.py +++ b/kraken/lib/segmentation.py @@ -1058,10 +1058,10 @@ def extract_polygons(im: Image.Image, bounds: Segmentation) -> Image.Image: im = np.array(im) for line in bounds.lines: - if line['boundary'] is None: + if line.boundary is None: raise KrakenInputException('No boundary given for line') - pl = np.array(line['boundary']) - baseline = np.array(line['baseline']) + pl = np.array(line.boundary) + baseline = np.array(line.baseline) c_min, c_max = int(pl[:, 0].min()), int(pl[:, 0].max()) r_min, r_max = int(pl[:, 1].min()), int(pl[:, 1].max()) diff --git a/kraken/lib/xml.py b/kraken/lib/xml.py index 01870cd5c..3f60fa7ee 100644 --- a/kraken/lib/xml.py +++ b/kraken/lib/xml.py @@ -27,7 +27,7 @@ from typing import Union, Dict, Any, Sequence, Tuple, Literal, Optional, List from collections import defaultdict -from kraken.containers import Segmentation, BaselineLine +from kraken.containers import Segmentation, BaselineLine, Region from kraken.lib.segmentation import calculate_polygonal_environment from kraken.lib.exceptions import KrakenInputException @@ -167,7 +167,7 @@ def _parse_alto(self): if rtype is None: rtype = alto_regions[region.tag.split('}')[-1]] region_id = region.get('ID') - region_data[rtype].append({'id': region_id, 'boundary': boundary}) + region_data[rtype].append(Region(id=region_id, boundary=coords, tags={'type': rtype})) # register implicit reading order self._orders['region_implicit']['order'].append(region_id) @@ -216,7 +216,7 @@ def _parse_alto(self): text=text, tags=tags, split=split_type, - region=region_id) + regions=[region_id]) # register implicit reading order self._orders['line_implicit']['order'].append(line.get('ID')) @@ -328,7 +328,7 @@ def _parse_page(self): # fall back to default region type if nothing is given if not rtype: rtype = page_regions[region.tag.split('}')[-1]] - region_data[rtype].append({'id': region.get('id'), 'boundary': coords}) + region_data[rtype].append(Region(id=region.get('id'), boundary=coords, tags={'type': rtype})) # register implicit reading order self._orders['region_implicit']['order'].append(region.get('id')) @@ -392,7 +392,7 @@ def _parse_page(self): text=text, tags=tags, split=split_type, - region=region.get('id')) + regions=[region.get('id')]) # register implicit reading order self._orders['line_implicit']['order'].append(line.get('id')) @@ -491,7 +491,7 @@ def get_sorted_regions(self, ro='region_implicit'): if ro not in self.reading_orders: raise ValueError(f'Unknown reading order {ro}') - regions = {reg['id']: key for key, regs in self.regions.items() for reg in regs} + regions = {reg.id: key for key, regs in self.regions.items() for reg in regs} def _traverse_ro(el): _ro = [] @@ -500,7 +500,7 @@ def _traverse_ro(el): else: # if region directly append to ro if el in regions.keys(): - return [reg for reg in self.regions[regions[el]] if reg['id'] == el][0] + return [reg for reg in self.regions[regions[el]] if reg.id == el][0] else: raise ValueError(f'Invalid reading order {ro}') return _ro @@ -516,17 +516,17 @@ def get_sorted_lines_by_region(self, region, ro='line_implicit'): raise ValueError(f'Unknown reading order {ro}') if self.reading_orders[ro]['is_total'] is False: raise ValueError('Fetching lines by region of a non-total order is not supported') - lines = [(id, line) for id, line in self._lines.items() if line['region'] == region] + lines = [(id, line) for id, line in self._lines.items() if line.regions[0] == region] for line in lines: if line[0] not in self.reading_orders[ro]['order']: raise ValueError('Fetching lines by region is only possible for flat orders') return sorted(lines, key=lambda k: self.reading_orders[ro]['order'].index(k[0])) def get_lines_by_tag(self, key, value): - return {k: v for k, v in self._lines.items() if v['tags'].get(key) == value} + return {k: v for k, v in self._lines.items() if v.tags.get(key) == value} def get_lines_by_split(self, split: Literal['train', 'validation', 'test']): - return {k: v for k, v in self._lines.items() if v['tags'].get('split') == split} + return {k: v for k, v in self._lines.items() if v.tags.get('split') == split} @property def tags(self): From 09ad27eee7a0d476d6b941b4461d6018b5199046 Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Mon, 22 May 2023 14:31:50 +0200 Subject: [PATCH 46/68] Fix imports in blla --- kraken/blla.py | 1 + 1 file changed, 1 insertion(+) diff --git a/kraken/blla.py b/kraken/blla.py index 84ce62086..d27ac071b 100644 --- a/kraken/blla.py +++ b/kraken/blla.py @@ -36,6 +36,7 @@ from skimage.filters import sobel from kraken.lib import vgsl, dataset +from kraken.containers import Region, Segmentation from kraken.lib.util import is_bitonal, get_im_str from kraken.lib.exceptions import KrakenInputException, KrakenInvalidModelException from kraken.lib.segmentation import (polygonal_reading_order, From b491122e0cb755a22a7b978d9ee0fad60f50010a Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Mon, 22 May 2023 14:32:09 +0200 Subject: [PATCH 47/68] typo in ketos utils --- kraken/ketos/util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kraken/ketos/util.py b/kraken/ketos/util.py index 009190027..b37b298b2 100644 --- a/kraken/ketos/util.py +++ b/kraken/ketos/util.py @@ -54,7 +54,7 @@ def message(msg, **styles): def to_ptl_device(device: str) -> Tuple[str, Optional[List[int]]]: - if device in ['cpu', 'mps']]): + if device in ['cpu', 'mps']: return device, 'auto' elif any([device.startswith(x) for x in ['tpu', 'cuda', 'hpu', 'ipu']]): dev, idx = device.split(':') From a9cfaf095d6a31020412574f703fce5befaba5d6 Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Tue, 23 May 2023 14:53:42 +0200 Subject: [PATCH 48/68] make path compilation work again --- kraken/lib/arrow_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kraken/lib/arrow_dataset.py b/kraken/lib/arrow_dataset.py index 4ca12d9bb..694fd4279 100755 --- a/kraken/lib/arrow_dataset.py +++ b/kraken/lib/arrow_dataset.py @@ -193,7 +193,7 @@ def build_binary_dataset(files: Optional[List[Union[str, PathLike, Dict]]] = Non lines = doc['lines'] for line in lines: num_lines += 1 - alphabet.update(line.text) + alphabet.update(line.text if format_type in ['xml', 'alto', 'page'] else line['text']) callback(0, num_lines) From 7e8fcf794b4e50040c266f12694b3aeb4228be23 Mon Sep 17 00:00:00 2001 From: Colin Brisson Date: Sat, 17 Jun 2023 14:17:27 +0000 Subject: [PATCH 49/68] add __getitem__ to the Baseline class --- kraken/containers.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/kraken/containers.py b/kraken/containers.py index f09bd40ba..1d48698a3 100644 --- a/kraken/containers.py +++ b/kraken/containers.py @@ -28,6 +28,9 @@ class BaselineLine: split: Optional[Literal['train', 'validation', 'test']] = None regions: Optional[List[str]] = None + def __getitem__(self, item): + return getattr(self, item) + @dataclass class BBoxLine: """ From 8f34a9ae6086c2a465591f86cfe648b42c41f4bb Mon Sep 17 00:00:00 2001 From: Colin Brisson Date: Sat, 17 Jun 2023 14:18:09 +0000 Subject: [PATCH 50/68] remove preparse_xml_data import statement --- kraken/lib/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kraken/lib/train.py b/kraken/lib/train.py index 253075253..81296eb27 100644 --- a/kraken/lib/train.py +++ b/kraken/lib/train.py @@ -33,7 +33,7 @@ from pytorch_lightning.callbacks import Callback, EarlyStopping, BaseFinetuning, LearningRateMonitor from kraken.lib import models, vgsl, default_specs, progress -from kraken.lib.xml import preparse_xml_data +# from kraken.lib.xml import preparse_xml_data from kraken.lib.util import make_printable from kraken.lib.codec import PytorchCodec from kraken.lib.dataset import (ArrowIPCRecognitionDataset, BaselineSet, From fe897c7da36142b1d9ddb9be976a03f43e2efb3c Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Sun, 2 Jul 2023 11:12:20 +0200 Subject: [PATCH 51/68] Container classes in segmentation --- kraken/blla.py | 46 +++++++++++++++++++------------------- kraken/kraken.py | 7 ++---- kraken/lib/segmentation.py | 18 +++++++-------- kraken/pageseg.py | 4 +++- 4 files changed, 36 insertions(+), 39 deletions(-) diff --git a/kraken/blla.py b/kraken/blla.py index d27ac071b..6d65de794 100644 --- a/kraken/blla.py +++ b/kraken/blla.py @@ -36,7 +36,7 @@ from skimage.filters import sobel from kraken.lib import vgsl, dataset -from kraken.containers import Region, Segmentation +from kraken.containers import Region, Segmentation, BaselineLine from kraken.lib.util import is_bitonal, get_im_str from kraken.lib.exceptions import KrakenInputException, KrakenInvalidModelException from kraken.lib.segmentation import (polygonal_reading_order, @@ -44,6 +44,7 @@ vectorize_lines, vectorize_regions, scale_polygonal_lines, calculate_polygonal_environment, + is_in_region, scale_regions) __all__ = ['segment'] @@ -156,7 +157,7 @@ def vec_regions(heatmap: torch.Tensor, cls_map: Dict, scale: float, **kwargs) -> logger.debug(f'Vectorizing regions of type {region_type}') regions[region_type] = vectorize_regions(heatmap[idx]) for reg_type, regs in regions.items(): - regions[reg_type] = [Region(id=uuid.uuid4(), boundary=x, tags={'type': reg_type}) for x in scale_regions(regs, scale)] + regions[reg_type] = [Region(id=str(uuid.uuid4()), boundary=x, tags={'type': reg_type}) for x in scale_regions(regs, scale)] return regions @@ -205,6 +206,7 @@ def vec_lines(heatmap: torch.Tensor, ... ] """ + st_sep = cls_map['aux']['_start_separator'] end_sep = cls_map['aux']['_end_separator'] @@ -219,28 +221,25 @@ def vec_lines(heatmap: torch.Tensor, lines = [] reg_pols = [geom.Polygon(x) for x in regions] - line_regs = [] for bl_idx in range(len(baselines)): bl = baselines[bl_idx] - mid_point = geom.LineString(bl[1]).interpolate(0.5, normalized=True) - + bl_ls = geom.LineString(bl[1]) suppl_obj = [x[1] for x in baselines[:bl_idx] + baselines[bl_idx+1:]] for reg_idx, reg_pol in enumerate(reg_pols): - if reg_pol.contains(mid_point): + if is_in_region(bl_ls, reg_pol): suppl_obj.append(regions[reg_idx]) - pol = calculate_polygonal_environment( - baselines=[bl[1]], - im_feats=im_feats, - suppl_obj=suppl_obj, - topline=topline, - raise_on_error=raise_on_error - ) + pol = calculate_polygonal_environment(baselines=[bl[1]], + im_feats=im_feats, + suppl_obj=suppl_obj, + topline=topline, + raise_on_error=raise_on_error) if pol[0] is not None: lines.append((bl[0], bl[1], pol[0])) logger.debug('Scaling vectorized lines') sc = scale_polygonal_lines([x[1:] for x in lines], scale) - lines = list(zip([x[0] for x in lines], [x[0] for x in sc], [x[1] for x in sc], line_regs)) + + lines = list(zip([x[0] for x in lines], [x[0] for x in sc], [x[1] for x in sc])) return [{'tags': {'type': bl_type}, 'baseline': bl, 'boundary': pl} for bl_type, bl, pl in lines] @@ -383,11 +382,6 @@ def segment(im: PIL.Image.Image, lines.extend(_lines) - # reorder lines - logger.debug(f'Reordering baselines with main RO function {reading_order_fn}.') - basic_lo = reading_order_fn(lines=lines, regions=regions, text_direction=text_direction[-2:]) - lines = [lines[idx] for idx in basic_lo] - if len(rets['cls_map']['baselines']) > 1: script_detection = True else: @@ -401,17 +395,23 @@ def segment(im: PIL.Image.Image, for reg in rgs: _shp_regs[reg.id] = geom.Polygon(reg.boundary) + # reorder lines + logger.debug(f'Reordering baselines with main RO function {reading_order_fn}.') + basic_lo = reading_order_fn(lines=lines, regions=_shp_regs.values(), text_direction=text_direction[-2:]) + lines = [lines[idx] for idx in basic_lo] + for line in lines: line_regs = [] for reg_id, reg in _shp_regs.items(): - mid_point = geom.LineString(line[1]).interpolate(0.5, normalized=True) - if reg.contains(mid_point): + line_ls = geom.LineString(line['baseline']) + if is_in_region(line_ls, reg): line_regs.append(reg_id) - blls.append(BaselineLine(id=uuid.uuid4(), baseline=line[1], boundary=line[2], tags={'type': line[0]}, regions=line_regs)) + blls.append(BaselineLine(id=str(uuid.uuid4()), baseline=line['baseline'], boundary=line['boundary'], tags=line['tags'], regions=line_regs)) return Segmentation(text_direction=text_direction, + imagename=getattr(im, 'filename', None), type='baselines', lines=blls, - regions=_regs, + regions=regions, script_detection=script_detection, line_orders=[order]) diff --git a/kraken/kraken.py b/kraken/kraken.py index a45e47f75..24a7a62da 100644 --- a/kraken/kraken.py +++ b/kraken/kraken.py @@ -46,7 +46,6 @@ APP_NAME = 'kraken' SEGMENTATION_DEFAULT_MODEL = pkg_resources.resource_filename(__name__, 'blla.mlmodel') DEFAULT_MODEL = ['en_best.mlmodel'] -LEGACY_MODEL_DIR = '/usr/local/share/ocropus' # raise default max image size to 20k * 20k pixels Image.MAX_IMAGE_PIXELS = 20000 ** 2 @@ -581,14 +580,12 @@ def ocr(ctx, model, pad, reorder, base_dir, no_segmentation, text_direction, thr if reorder and base_dir != 'auto': reorder = base_dir - # first we try to find the model in the absolue path, then ~/.kraken, then - # LEGACY_MODEL_DIR + # first we try to find the model in the absolue path, then ~/.kraken nm = {} # type: Dict[str, models.TorchSeqRecognizer] ign_tags = model.pop('ignore') for k, v in model.items(): search = [v, - os.path.join(click.get_app_dir(APP_NAME), v), - os.path.join(LEGACY_MODEL_DIR, v)] + os.path.join(click.get_app_dir(APP_NAME), v)] location = None for loc in search: if os.path.isfile(loc): diff --git a/kraken/lib/segmentation.py b/kraken/lib/segmentation.py index 34e101cdd..ef07142d3 100644 --- a/kraken/lib/segmentation.py +++ b/kraken/lib/segmentation.py @@ -740,7 +740,7 @@ def calculate_polygonal_environment(im: PIL.Image.Image = None, def polygonal_reading_order(lines: Sequence[Dict], text_direction: Literal['lr', 'rl'] = 'lr', - regions: Optional[Sequence[List[Tuple[int, int]]]] = None) -> Sequence[int]: + regions: Optional[Sequence[geom.Polygon]] = None) -> Sequence[int]: """ Given a list of baselines and regions, calculates the correct reading order and applies it to the input. @@ -758,16 +758,14 @@ def polygonal_reading_order(lines: Sequence[Dict], lines = [(line['tags']['type'], line['baseline'], line['boundary']) for line in lines] bounds = [] - if regions is not None: - r = [geom.Polygon(reg) for reg in regions] - else: - r = [] - region_lines = [[] for _ in range(len(r))] + if regions is None: + regions = [] + region_lines = [[] for _ in range(len(regions))] indizes = {} for line_idx, line in enumerate(lines): s_line = geom.LineString(line[1]) in_region = False - for idx, reg in enumerate(r): + for idx, reg in enumerate(regions): if is_in_region(s_line, reg): region_lines[idx].append((line_idx, (slice(s_line.bounds[1], s_line.bounds[3]), slice(s_line.bounds[0], s_line.bounds[2])))) @@ -778,8 +776,8 @@ def polygonal_reading_order(lines: Sequence[Dict], slice(s_line.bounds[0], s_line.bounds[2]))) indizes[line_idx] = ('line', line) # order everything in regions - intra_region_order = [[] for _ in range(len(r))] - for idx, reg in enumerate(r): + intra_region_order = [[] for _ in range(len(regions))] + for idx, reg in enumerate(regions): if len(region_lines[idx]) > 0: order = reading_order([x[1] for x in region_lines[idx]], text_direction) lsort = topsort(order) @@ -819,7 +817,7 @@ def is_in_region(line, region) -> bool: def neural_reading_order(lines: Sequence[Dict], text_direction: str = 'lr', - regions: Optional[Sequence[List[Tuple[int, int]]]] = None, + regions: Optional[Sequence[geom.Polygon]] = None, im_size: Tuple[int, int] = None, model: 'TorchVGSLModel' = None, class_mapping: Dict[str, int] = None) -> Sequence[int]: diff --git a/kraken/pageseg.py b/kraken/pageseg.py index 8bf96b64c..2fcd8f823 100644 --- a/kraken/pageseg.py +++ b/kraken/pageseg.py @@ -19,6 +19,7 @@ Layout analysis methods. """ +import uuid import logging import numpy as np @@ -423,9 +424,10 @@ def segment(im, if isinstance(pad, int): pad = (pad, pad) lines = [(max(x[0]-pad[0], 0), x[1], min(x[2]+pad[1], im.size[0]), x[3]) for x in lines] - lines = [BBoxLine(id=uuid.uuid4(), bbox=line) for line in rotate_lines(lines, 360-angle, offset).tolist()] + lines = [BBoxLine(id=str(uuid.uuid4()), bbox=line) for line in rotate_lines(lines, 360-angle, offset).tolist()] return Segmentation(text_direction=text_direction, + imagename=getattr(im, 'filename', None), type='bbox', regions=None, line_orders=None, From 2cffebd183f6ce72f4f7b00073c3a957d882a5cf Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Tue, 4 Jul 2023 01:11:09 +0200 Subject: [PATCH 52/68] autoinstantiate baselineline/bboxline when loading segmentation from json --- kraken/containers.py | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/kraken/containers.py b/kraken/containers.py index f09bd40ba..ee378d1e9 100644 --- a/kraken/containers.py +++ b/kraken/containers.py @@ -1,7 +1,7 @@ import PIL.Image -from typing import Literal, List, Dict, Sequence, Union, Optional, Tuple +from typing import Literal, List, Dict, Union, Optional, Tuple from dataclasses import dataclass, asdict from abc import ABC, abstractmethod @@ -66,10 +66,15 @@ class Segmentation: imagename: str text_direction: Literal['horizontal-lr', 'horizontal-rl', 'vertical-lr', 'vertical-rl'] script_detection: bool - lines: Sequence[Union[BaselineLine, BBoxLine]] + lines: List[Union[BaselineLine, BBoxLine]] regions: Dict[str, List[Region]] line_orders: Optional[List[List[int]]] = None + def __post_init__(self): + if len(self.lines) and not isinstance(self.lines[0], BBoxLine) and not isinstance(self.lines[0], BaselineLine): + line_cls = BBoxLine if self.type == 'bbox' else BaselineLine + self.lines = [line_cls(**line) for line in self.lines] + class ocr_record(ABC): """ @@ -79,11 +84,11 @@ class ocr_record(ABC): def __init__(self, prediction: str, - cuts: Sequence[Union[Tuple[int, int], Tuple[Tuple[int, int], + cuts: List[Union[Tuple[int, int], Tuple[Tuple[int, int], Tuple[int, int], Tuple[int, int], Tuple[int, int]]]], - confidences: Sequence[float], + confidences: List[float], display_order: bool = True) -> None: self._prediction = prediction self._cuts = cuts @@ -106,7 +111,7 @@ def prediction(self) -> str: return self._prediction @property - def cuts(self) -> Sequence: + def cuts(self) -> List: return self._cuts @property @@ -119,7 +124,7 @@ def __iter__(self): @abstractmethod def __next__(self) -> Tuple[str, - Union[Sequence[Tuple[int, int]], + Union[List[Tuple[int, int]], Tuple[Tuple[int, int], Tuple[int, int], Tuple[int, int], @@ -166,8 +171,8 @@ class BaselineOCRRecord(ocr_record, BaselineLine): type = 'baselines' def __init__(self, prediction: str, - cuts: Sequence[Tuple[int, int]], - confidences: Sequence[float], + cuts: List[Tuple[int, int]], + confidences: List[float], line: BaselineLine, base_dir: Optional[Literal['L', 'R']] = None, display_order: bool = True) -> None: @@ -222,7 +227,7 @@ def __getitem__(self, key: Union[int, slice]): raise TypeError('Invalid argument type') @property - def cuts(self) -> Sequence[Tuple[int, int]]: + def cuts(self) -> List[Tuple[int, int]]: return tuple([compute_polygon_section(self.baseline, self.line, cut[0], cut[1]) for cut in self._cuts]) def logical_order(self, base_dir: Optional[Literal['L', 'R']] = None) -> 'BaselineOCRRecord': @@ -314,11 +319,11 @@ class BBoxOCRRecord(ocr_record, BBoxLine): def __init__(self, prediction: str, - cuts: Sequence[Tuple[Tuple[int, int], + cuts: List[Tuple[Tuple[int, int], Tuple[int, int], Tuple[int, int], Tuple[int, int]]], - confidences: Sequence[float], + confidences: List[float], line: BBoxLine, base_dir: Optional[Literal['L', 'R']], display_order: bool = True) -> None: From 20a869ae89b6f73eab8e9072ede87b63078c4a8b Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Tue, 4 Jul 2023 13:48:25 +0200 Subject: [PATCH 53/68] Use new containers in rpred --- kraken/containers.py | 3 ++- kraken/kraken.py | 31 ++++++++++++++----------------- kraken/rpred.py | 23 +++++++++++++---------- 3 files changed, 29 insertions(+), 28 deletions(-) diff --git a/kraken/containers.py b/kraken/containers.py index ee378d1e9..0a2e13af0 100644 --- a/kraken/containers.py +++ b/kraken/containers.py @@ -1,5 +1,6 @@ import PIL.Image +import bidi.algorithm as bd from typing import Literal, List, Dict, Union, Optional, Tuple from dataclasses import dataclass, asdict @@ -300,7 +301,7 @@ def _reorder(self, base_dir: Optional[Literal['L', 'R']] = None) -> 'BaselineOCR image=self.image, tags=self.tags, split=self.split, - region=self.region) + regions=self.regions) rec = BaselineOCRRecord(prediction=prediction, cuts=cuts, confidences=confidences, diff --git a/kraken/kraken.py b/kraken/kraken.py index 24a7a62da..cb76614d5 100644 --- a/kraken/kraken.py +++ b/kraken/kraken.py @@ -180,7 +180,9 @@ def segmenter(legacy, model, text_direction, scale, maxcolseps, black_colseps, def recognizer(model, pad, no_segmentation, bidi_reordering, tags_ignore, input, output) -> None: import json + import uuid + from kraken.containers import Segmentation, BBoxLine from kraken import rpred ctx = click.get_current_context() @@ -192,17 +194,11 @@ def recognizer(model, pad, no_segmentation, bidi_reordering, tags_ignore, input, if ctx.meta['first_process']: if ctx.meta['input_format_type'] != 'image': doc = get_input_parser(ctx.meta['input_format_type'])(input) - ctx.meta['base_image'] = doc['image'] - doc['text_direction'] = 'horizontal-lr' - if doc['base_dir'] and bidi_reordering is True: - message(f'Setting base text direction for BiDi reordering to {doc["base_dir"]} (from XML input file)') - bidi_reordering = doc['base_dir'] - bounds = {'text_direction': 'horizontal-lr', - 'tags': True, - 'lines': doc.get_sorted_lines(), - 'regions': doc.get_sorted_regions(), - 'type': 'baselines', - 'image': doc.imagename} + ctx.meta['base_image'] = doc.imagename + if doc.base_dir and bidi_reordering is True: + message(f'Setting base text direction for BiDi reordering to {doc.base_dir} (from XML input file)') + bidi_reordering = doc.base_dir + bounds = doc.to_container() try: im = Image.open(ctx.meta['base_image']) except IOError as e: @@ -212,14 +208,15 @@ def recognizer(model, pad, no_segmentation, bidi_reordering, tags_ignore, input, with click.open_file(input, 'r') as fp: try: fp = cast(IO[Any], fp) - bounds = json.load(fp) + bounds = Segmentation(**json.load(fp)) except ValueError as e: raise click.UsageError(f'{input} invalid segmentation: {str(e)}') elif not bounds: if no_segmentation: - bounds = {'script_detection': False, - 'text_direction': 'horizontal-lr', - 'boxes': [(0, 0) + im.size]} + bounds = Segmentation(type='bbox', + text_direction='horizontal-lr', + lines=[BBoxLine(id=uuid.uuid4(), + bbox=((0, 0), (0, im.size[1]), im.size, (im.size[0], 0)))]) else: raise click.UsageError('No line segmentation given. Add one with the input or run `segment` first.') elif no_segmentation: @@ -227,7 +224,7 @@ def recognizer(model, pad, no_segmentation, bidi_reordering, tags_ignore, input, tags = set() # script detection - if 'script_detection' in bounds and bounds['script_detection']: + if bounds.script_detection: it = rpred.mm_rpred(model, im, bounds, pad, bidi_reordering=bidi_reordering, tags_ignore=tags_ignore) @@ -255,7 +252,7 @@ def recognizer(model, pad, no_segmentation, bidi_reordering, tags_ignore, input, image_size=Image.open(ctx.meta['base_image']).size, writing_mode=ctx.meta['text_direction'], scripts=tags, - regions=bounds['regions'] if 'regions' in bounds else None, + regions=bounds.regions, template=ctx.meta['output_template'], template_source='custom' if ctx.meta['output_mode'] == 'template' else 'native', processing_steps=ctx.meta['steps'])) diff --git a/kraken/rpred.py b/kraken/rpred.py index 3354b56b6..412b84c6e 100644 --- a/kraken/rpred.py +++ b/kraken/rpred.py @@ -19,6 +19,7 @@ Generators for recognition on lines images. """ import logging +import dataclasses import numpy as np import bidi.algorithm as bd @@ -47,12 +48,12 @@ class mm_rpred(object): Multi-model version of kraken.rpred.rpred """ def __init__(self, - nets: Dict[str, TorchSeqRecognizer], + nets: Dict[Tuple[str, str], TorchSeqRecognizer], im: Image.Image, bounds: Segmentation, pad: int = 16, bidi_reordering: Union[bool, str] = True, - tags_ignore: Optional[List[str]] = None) -> Generator[ocr_record, None, None]: + tags_ignore: Optional[List[Tuple[str, str]]] = None) -> Generator[ocr_record, None, None]: """ Multi-model version of kraken.rpred.rpred. @@ -61,8 +62,8 @@ def __init__(self, these lines. Args: - nets: A dict mapping tag values to TorchSegRecognizer objects. - Recommended to be an defaultdict. + nets: A dict mapping tag key-value pairs to TorchSegRecognizer + objects. Recommended to be an defaultdict. im: Image to extract text from bounds: A Segmentation data class containing either bounding box or baseline type segmentation. @@ -71,7 +72,8 @@ def __init__(self, Unicode bidirectional algorithm for correct display. Set to L|R to override default text direction. - tags_ignore: List of tag values to ignore during recognition + tags_ignore: List of tag key-value pairs to ignore during + recognition Yields: An ocr_record containing the recognized text, absolute character @@ -115,7 +117,7 @@ def __init__(self, tags = set() for x in bounds.lines: - tags.update(x.tags.values()) + tags.update(x.tags.items()) im_str = get_im_str(im) logger.info(f'Running {len(nets)} multi-script recognizers on {im_str} with {self.len} lines') @@ -231,8 +233,10 @@ def _recognize_baseline_line(self, line): logger.info(f'Ignoring line segment with tags {line.tags} based on {tag}.') return BaselineOCRRecord('', [], [], line) + seg = dataclasses.replace(self.bounds, lines=[line]) + try: - box, coords = next(extract_polygons(self.im, line)) + box, coords = next(extract_polygons(self.im, seg)) except KrakenInputException as e: logger.warning(f'Extracting line failed: {e}') return BaselineOCRRecord('', [], [], line) @@ -279,7 +283,6 @@ def _recognize_baseline_line(self, line): return rec.display_order(None) def __next__(self): - bound = self.bounds return self.next_iter(next(self.line_iter)) def __iter__(self): @@ -319,12 +322,12 @@ def rpred(network: TorchSeqRecognizer, def _resolve_tags_to_model(tags: Sequence[Dict[str, str]], - model_map: Dict[str, TorchSeqRecognizer], + model_map: Dict[Tuple[str, str], TorchSeqRecognizer], default: Optional[TorchSeqRecognizer] = None) -> TorchSeqRecognizer: """ Resolves a sequence of tags """ - for tag in tags.values(): + for tag in tags.items(): if tag in model_map: return tag, model_map[tag] if default: From 3aa34a893dba202325349fab2960fdacf018cad6 Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Tue, 18 Jul 2023 13:02:09 +0200 Subject: [PATCH 54/68] serialization with new container classes --- kraken/align.py | 25 +++--- kraken/containers.py | 40 +++++++-- kraken/kraken.py | 6 +- kraken/lib/segmentation.py | 3 +- kraken/pageseg.py | 5 +- kraken/rpred.py | 6 +- kraken/serialization.py | 167 ++++++++++++++----------------------- kraken/templates/alto | 9 +- 8 files changed, 121 insertions(+), 140 deletions(-) diff --git a/kraken/align.py b/kraken/align.py index 961585221..2c20cc790 100644 --- a/kraken/align.py +++ b/kraken/align.py @@ -23,6 +23,7 @@ """ import torch import logging +import dataclasses import numpy as np from PIL import Image @@ -31,8 +32,9 @@ from dataclasses import dataclass from typing import List, Dict, Any, Optional, Literal -from kraken import rpred +from kraken import rpred, containers from kraken.lib.codec import PytorchCodec +from kraken.lib.xml import XMLPage from kraken.lib.models import TorchSeqRecognizer from kraken.lib.exceptions import KrakenInputException, KrakenEncodeException from kraken.lib.segmentation import compute_polygon_section @@ -40,7 +42,7 @@ logger = logging.getLogger('kraken') -def forced_align(doc: Dict[str, Any], model: TorchSeqRecognizer, base_dir: Optional[Literal['L', 'R']] = None) -> List[rpred.ocr_record]: +def forced_align(doc: Segmentation, model: TorchSeqRecognizer, base_dir: Optional[Literal['L', 'R']] = None) -> containers.Segmentation: """ Performs a forced character alignment of text with recognition model output activations. @@ -50,28 +52,26 @@ def forced_align(doc: Dict[str, Any], model: TorchSeqRecognizer, base_dir: Optio model: Recognition model to use for alignment. Returns: - A list of kraken.rpred.ocr_record. + A Segmentation object where the record's contain the aligned text. """ - im = Image.open(doc['image']) + im = Image.open(doc.imagename) predictor = rpred.rpred(model, im, doc) - if 'type' in predictor.bounds and predictor.bounds['type'] == 'baselines': - rec_class = rpred.BaselineOCRRecord records = [] # enable training mode in last layer to get log_softmax output model.nn.nn[-1].training = True - for idx, line in enumerate(doc['lines']): + for idx, line in enumerate(doc.lines): # convert text to display order - do_text = get_display(line['text'], base_dir=base_dir) + do_text = get_display(line.text, base_dir=base_dir) # encode into labels, ignoring unencodable sequences labels = model.codec.encode(do_text).long() next(predictor) if model.outputs.shape[2] < 2*len(labels): logger.warning(f'Could not align line {idx}. Output sequence length {model.outputs.shape[2]} < ' - f'{2*len(labels)} (length of "{line["text"]}" after encoding).') - records.append(rpred.BaselineOCRRecord('', [], [], line)) + f'{2*len(labels)} (length of "{line.text}" after encoding).') + records.append(containers.BaselineOCRRecord('', [], [], line)) continue emission = torch.tensor(model.outputs).squeeze().T trellis = get_trellis(emission, labels) @@ -85,8 +85,9 @@ def forced_align(doc: Dict[str, Any], model: TorchSeqRecognizer, base_dir: Optio pos.append((predictor._scale_val(seg.start, 0, predictor.box.size[0]), predictor._scale_val(seg.end, 0, predictor.box.size[0]))) conf.append(seg.score) - records.append(rpred.BaselineOCRRecord(pred, pos, conf, line, display_order=True)) - return records + records.append(containers.BaselineOCRRecord(pred, pos, conf, line, display_order=True)) + return dataclasses.replace(doc, lines=records) + """ Copied from the forced alignment with Wav2Vec2 tutorial of pytorch available diff --git a/kraken/containers.py b/kraken/containers.py index 5218c39c3..29a75c522 100644 --- a/kraken/containers.py +++ b/kraken/containers.py @@ -1,17 +1,33 @@ import PIL.Image +import numpy as np import bidi.algorithm as bd +from os import PathLike from typing import Literal, List, Dict, Union, Optional, Tuple from dataclasses import dataclass, asdict from abc import ABC, abstractmethod +from kraken.lib.segmentation import compute_polygon_section + __all__ = ['BaselineLine', 'BBoxLine', 'Segmentation', 'ocr_record', 'BaselineOCRRecord', - 'BBoxOCRRecord'] + 'BBoxOCRRecord', + 'ProcessingStep'] + + +@dataclass +class ProcessingStep: + """ + A processing step in the recognition pipeline. + """ + id: str + category: Literal['preprocessing', 'processing', 'postprocessing'] + description: str + settings: Dict[str, Union[Dict, str, float, int, bool]] @dataclass @@ -29,8 +45,6 @@ class BaselineLine: split: Optional[Literal['train', 'validation', 'test']] = None regions: Optional[List[str]] = None - def __getitem__(self, item): - return getattr(self, item) @dataclass class BBoxLine: @@ -50,6 +64,7 @@ class BBoxLine: regions: Optional[List[str]] = None text_direction: Literal['horizontal-lr', 'horizontal-rl', 'vertical-lr', 'vertical-rl'] = 'horizontal-lr' + @dataclass class Region: """ @@ -64,10 +79,14 @@ class Region: @dataclass class Segmentation: """ + A container class for segmentation or recognition results. + In order to allow easy JSON de-/serialization, nested classes for lines + (BaselineLine/BBoxLine) and regions (Region) are reinstantiated from their + dictionaries. """ type: Literal['baselines', 'bbox'] - imagename: str + imagename: Union[str, PathLike] text_direction: Literal['horizontal-lr', 'horizontal-rl', 'vertical-lr', 'vertical-rl'] script_detection: bool lines: List[Union[BaselineLine, BBoxLine]] @@ -78,6 +97,11 @@ def __post_init__(self): if len(self.lines) and not isinstance(self.lines[0], BBoxLine) and not isinstance(self.lines[0], BaselineLine): line_cls = BBoxLine if self.type == 'bbox' else BaselineLine self.lines = [line_cls(**line) for line in self.lines] + if len(self.regions) and not isinstance(next(iter(self.regions.values()))[0], Region): + regs = {} + for k, v in self.regions.items(): + regs[k] = [Region(**reg) for reg in v] + self.regions = regs class ocr_record(ABC): @@ -195,7 +219,7 @@ def __next__(self) -> Tuple[str, int, float]: self.idx += 1 return (self.prediction[self.idx], compute_polygon_section(self.baseline, - self.line, + self.boundary, self.cuts[self.idx][0], self.cuts[self.idx][1]), self.confidences[self.idx]) @@ -217,7 +241,7 @@ def __getitem__(self, key: Union[int, slice]): prediction = ''.join([x[0] for x in recs]) flat_offsets = sum((tuple(x[1]) for x in recs), ()) cut = compute_polygon_section(self.baseline, - self.line, + self.boundary, min(flat_offsets), max(flat_offsets)) confidence = np.mean([x[2] for x in recs]) @@ -225,14 +249,14 @@ def __getitem__(self, key: Union[int, slice]): elif isinstance(key, int): pred, cut, confidence = self._get_raw_item(key) return (pred, - compute_polygon_section(self.baseline, self.line, cut[0], cut[1]), + compute_polygon_section(self.baseline, self.boundary, cut[0], cut[1]), confidence) else: raise TypeError('Invalid argument type') @property def cuts(self) -> List[Tuple[int, int]]: - return tuple([compute_polygon_section(self.baseline, self.line, cut[0], cut[1]) for cut in self._cuts]) + return tuple([compute_polygon_section(self.baseline, self.boundary, cut[0], cut[1]) for cut in self._cuts]) def logical_order(self, base_dir: Optional[Literal['L', 'R']] = None) -> 'BaselineOCRRecord': """ diff --git a/kraken/kraken.py b/kraken/kraken.py index cb76614d5..0f8fea361 100644 --- a/kraken/kraken.py +++ b/kraken/kraken.py @@ -181,6 +181,7 @@ def recognizer(model, pad, no_segmentation, bidi_reordering, tags_ignore, input, import json import uuid + import dataclasses from kraken.containers import Segmentation, BBoxLine from kraken import rpred @@ -239,6 +240,7 @@ def recognizer(model, pad, no_segmentation, bidi_reordering, tags_ignore, input, for pred in it: preds.append(pred) progress.update(pred_task, advance=1) + results = dataclasses.replace(it.bounds, lines=preds, imagename=ctx.meta['base_image']) ctx = click.get_current_context() with click.open_file(output, 'w', encoding='utf-8') as fp: @@ -247,12 +249,10 @@ def recognizer(model, pad, no_segmentation, bidi_reordering, tags_ignore, input, logger.info('Serializing as {} into {}'.format(ctx.meta['output_mode'], output)) if ctx.meta['output_mode'] != 'native': from kraken import serialization - fp.write(serialization.serialize(records=preds, - image_name=ctx.meta['base_image'], + fp.write(serialization.serialize(results=results, image_size=Image.open(ctx.meta['base_image']).size, writing_mode=ctx.meta['text_direction'], scripts=tags, - regions=bounds.regions, template=ctx.meta['output_template'], template_source='custom' if ctx.meta['output_mode'] == 'template' else 'native', processing_steps=ctx.meta['steps'])) diff --git a/kraken/lib/segmentation.py b/kraken/lib/segmentation.py index ef07142d3..215c4e5d4 100644 --- a/kraken/lib/segmentation.py +++ b/kraken/lib/segmentation.py @@ -42,7 +42,6 @@ from typing import List, Tuple, Union, Dict, Any, Sequence, Optional, Literal -from kraken.containers import Segmentation, BaselineLine, BBoxLine from kraken.lib import default_specs from kraken.lib.exceptions import KrakenInputException @@ -1033,7 +1032,7 @@ def compute_polygon_section(baseline: Sequence[Tuple[int, int]], return tuple(o) -def extract_polygons(im: Image.Image, bounds: Segmentation) -> Image.Image: +def extract_polygons(im: Image.Image, bounds: 'kraken.containers.Segmentation') -> Image.Image: """ Yields the subimages of image im defined in the list of bounding polygons with baselines preserving order. diff --git a/kraken/pageseg.py b/kraken/pageseg.py index 2fcd8f823..88c6a793c 100644 --- a/kraken/pageseg.py +++ b/kraken/pageseg.py @@ -27,11 +27,12 @@ from scipy.ndimage.filters import (gaussian_filter, uniform_filter, maximum_filter) +from kraken.containers import Segmentation, BBoxLine + from kraken.lib import morph, sl from kraken.lib.util import pil2array, is_bitonal, get_im_str from kraken.lib.exceptions import KrakenInputException -from kraken.lib.segmentation import reading_order, topsort, Segmentation, BBoxLine - +from kraken.lib.segmentation import reading_order, topsort __all__ = ['segment'] diff --git a/kraken/rpred.py b/kraken/rpred.py index 412b84c6e..960f4a370 100644 --- a/kraken/rpred.py +++ b/kraken/rpred.py @@ -21,18 +21,16 @@ import logging import dataclasses import numpy as np -import bidi.algorithm as bd -from abc import ABC, abstractmethod from PIL import Image from functools import partial from collections import defaultdict from typing import List, Tuple, Optional, Generator, Union, Dict, Sequence -from kraken.containers import BaselineOCRRecord, BBoxOCRRecord, ocr_record +from kraken.containers import BaselineOCRRecord, BBoxOCRRecord, ocr_record, Segmentation from kraken.lib.util import get_im_str, is_bitonal from kraken.lib.models import TorchSeqRecognizer -from kraken.lib.segmentation import extract_polygons, compute_polygon_section, Segmentation +from kraken.lib.segmentation import extract_polygons from kraken.lib.exceptions import KrakenInputException from kraken.lib.dataset import ImageInputTransforms diff --git a/kraken/serialization.py b/kraken/serialization.py index da644c948..79a25084b 100644 --- a/kraken/serialization.py +++ b/kraken/serialization.py @@ -23,9 +23,9 @@ from pkg_resources import get_distribution from collections import Counter -from kraken.rpred import BaselineOCRRecord, BBoxOCRRecord, ocr_record +from kraken.containers import Segmentation, ProcessingStep from kraken.lib.util import make_printable -from kraken.lib.segmentation import is_in_region, Segmentation +from kraken.lib.segmentation import is_in_region from typing import Union, List, Tuple, Iterable, Optional, Sequence, Dict, Any, Literal @@ -70,95 +70,87 @@ def max_bbox(boxes: Iterable[Sequence[int]]) -> Tuple[int, int, int, int]: return o -def serialize(records: Sequence[ocr_record], - image_name: Union[PathLike, str] = None, +def serialize(results: Segmentation, image_size: Tuple[int, int] = (0, 0), writing_mode: Literal['horizontal-tb', 'vertical-lr', 'vertical-rl'] = 'horizontal-tb', scripts: Optional[Iterable[str]] = None, - regions: Optional[Dict[str, List[List[Tuple[int, int]]]]] = None, template: [PathLike, str] = 'alto', template_source: Literal['native', 'custom'] = 'native', - processing_steps: Optional[List[Dict[str, Union[Dict, str, float, int, bool]]]] = None) -> str: + processing_steps: Optional[List[ProcessingStep]] = None) -> str: """ - Serializes a list of ocr_records into an output document. + Serializes recognition and segmentation results into an output document. - Serializes a list of predictions and their corresponding positions by doing - some hOCR-specific preprocessing and then renders them through one of - several jinja2 templates. + Serializes a Segmentation container object containing either segmentation + or recognition results into an output document. The rendering is performed + with jinja2 templates that can either be shipped with kraken + (`template_source` == 'native') or custom (`template_source` == 'custom'). Note: Empty records are ignored for serialization purposes. Args: - records: List of kraken.rpred.ocr_record - image_name: Name of the source image + segmentation: Segmentation container object image_size: Dimensions of the source image writing_mode: Sets the principal layout of lines and the direction in which blocks progress. Valid values are horizontal-tb, vertical-rl, and vertical-lr. scripts: List of scripts contained in the OCR records - regions: Dictionary mapping region types to a list of region polygons. template: Selector for the serialization format. May be 'hocr', 'alto', 'page' or any template found in the template directory. If template_source is set to `custom` a path to a template is expected. template_source: Switch to enable loading of custom templates from outside the kraken package. - processing_steps: A list of dictionaries describing the processing kraken performed on the inputs:: - - {'category': 'preprocessing', - 'description': 'natural language description of process', - 'settings': {'arg0': 'foo', 'argX': 'bar'} - } + processing_steps: A list of ProcessingStep container classes describing + the processing kraken performed on the inputs. Returns: The rendered template """ - logger.info(f'Serialize {len(records)} records from {image_name} with template {template}.') + logger.info(f'Serialize {len(results.lines)} records from {results.imagename} with template {template}.') page = {'entities': [], 'size': image_size, - 'name': image_name, + 'name': results.imagename, 'writing_mode': writing_mode, 'scripts': scripts, 'date': datetime.datetime.now(datetime.timezone.utc).isoformat(), - 'base_dir': [rec.base_dir for rec in records][0] if len(records) else None} # type: dict + 'base_dir': [rec.base_dir for rec in results.lines][0] if len(results.lines) else None, + 'line_orders': results.line_orders, + 'seg_type': results.type} # type: dict metadata = {'processing_steps': processing_steps, 'version': get_distribution('kraken').version} seg_idx = 0 char_idx = 0 - region_map = {} - idx = 0 - if regions is not None: - for id, regs in regions.items(): - for reg in regs: - region_map[idx] = (id, geom.Polygon(reg), reg) - idx += 1 # build region and line type dict types = [] - for line in records: - if hasattr(line, 'tags') and line.tags is not None: - types.extend(line.tags.values()) - page['types'] = list(set(types)) - if regions is not None: - page['types'].extend(list(regions.keys())) - - is_in_reg = -1 - for idx, record in enumerate(records): - if record.type == 'baselines': - l_obj = geom.LineString(record.baseline) - else: - l_obj = geom.LineString(record.line) - reg = list(filter(lambda x: is_in_region(l_obj, x[1][1]), region_map.items())) - if len(reg) == 0: - cur_ent = page['entities'] - elif reg[0][0] != is_in_reg: - reg = reg[0] - is_in_reg = reg[0] - region = {'index': reg[0], - 'bbox': [int(x) for x in reg[1][1].bounds], - 'boundary': [list(x) for x in reg[1][2]], - 'region_type': reg[1][0], + for line in results.lines: + if line.tags is not None: + types.extend((k, v) for k, v in line.tags.items()) + page['line_types'] = list(set(types)) + page['region_types'] =[list(results.regions.keys())] + + # build region ID to region dict + reg_dict = {} + for key, regs in results.regions.items(): + for reg in regs: + reg_dict[reg.id] = reg + + regs_with_lines = set() + prev_reg = None + for idx, record in enumerate(results.lines): + # line not in region + if len(record.regions) == 0: + cur_ent = page['entitites'] + # line not in same region as previous line + elif prev_reg != record.regions[0]: + prev_reg = record.regions[0] + reg = reg_dict[record.regions[0]] + regs_with_lines.add(reg.id) + region = {'id': reg.id, + 'bbox': max_bbox([reg.boundary]), + 'boundary': [list(x) for x in reg.boundary], + 'tags': reg.tags, 'lines': [], 'type': 'region' } @@ -167,20 +159,19 @@ def serialize(records: Sequence[ocr_record], # set field to indicate the availability of baseline segmentation in # addition to bounding boxes - if record.type == 'baselines': - page['seg_type'] = 'baselines' line = {'index': idx, - 'bbox': max_bbox([record.line]), + 'bbox': max_bbox([record.boundary] if record.type == 'baselines' else record.bbox), 'cuts': record.cuts, 'confidences': record.confidences, 'recognition': [], - 'boundary': [list(x) for x in record.line], + 'boundary': [list(x) for x in record.boundary], 'type': 'line' } - if hasattr(record, 'tags') and record.tags is not None: + if record.tags is not None: line['tags'] = record.tags if record.type == 'baselines': line['baseline'] = [list(x) for x in record.baseline] + splits = regex.split(r'(\s+)', record.prediction) line_offset = 0 logger.debug(f'Record contains {len(splits)} segments') @@ -213,18 +204,19 @@ def serialize(records: Sequence[ocr_record], line_offset += len(segment) cur_ent.append(line) - # No records but there are regions -> serialize all regions - if not records and regions: - logger.debug(f'No lines given but {len(region_map)}. Serialize all regions.') - for reg in region_map.items(): - region = {'index': reg[0], - 'bbox': [int(x) for x in reg[1][1].bounds], - 'boundary': [list(x) for x in reg[1][2]], - 'region_type': reg[1][0], - 'lines': [], - 'type': 'region' - } - page['entities'].append(region) + # serialize all remaining (line-less) regions + for reg_id in regs_with_lines: + reg_dict.pop(reg_id) + logger.debug(f'No lines given but {len(results.regions)}. Serialize all regions.') + for reg in reg_dict.values(): + region = {'id': reg.id, + 'bbox': max_bbox([reg.boundary]), + 'boundary': [list(x) for x in reg.boundary], + 'tags': reg.tags, + 'lines': [], + 'type': 'region' + } + page['entities'].append(region) if template_source == 'native': logger.debug('Initializing native jinja environment.') @@ -246,43 +238,6 @@ def _load_template(name): return tmpl.render(page=page, metadata=metadata) -def serialize_segmentation(segresult: Segmentation, - image_name: Union[PathLike, str] = None, - image_size: Tuple[int, int] = (0, 0), - template: Union[PathLike, str] = 'alto', - template_source: Literal['native', 'custom'] = 'native', - processing_steps: Optional[List[Dict[str, Union[Dict, str, float, int, bool]]]] = None) -> str: - """ - Serializes a segmentation result into an output document. - - Args: - segresult: Result of blla.segment - image_name: Name of the source image - image_size: Dimensions of the source image - template: Selector for the serialization format. Any value accepted by - `serialize` is valid. - template_source: Enables/disables loading of external templates. - - Returns: - (str) rendered template. - """ - if segresult.type == 'baselines': - records = [BaselineOCRRecord('', (), (), bl) for bl in segresult.lines] - else: - records = [] - for line in segresult.lines: - xmin, xmax = min(line[::2]), max(line[::2]) - ymin, ymax = min(line[1::2]), max(line[1::2]) - records.append(BBoxOCRRecord('', (), (), ((xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin)))) - return serialize(records, - image_name=image_name, - image_size=image_size, - regions=segresult.regions, - template=template, - template_source=template_source, - processing_steps=processing_steps) - - def render_report(model: str, chars: int, errors: int, diff --git a/kraken/templates/alto b/kraken/templates/alto index 05d7ab192..3c5f34e99 100644 --- a/kraken/templates/alto +++ b/kraken/templates/alto @@ -49,7 +49,7 @@ {% if metadata.processing_steps %} {% for step in metadata.processing_steps %} - + {{ proc_type_table[step.category] }} {{ step.description }} {% for k, v in step.settings.items() %}{{k}}: {{v}}; {% endfor %} @@ -71,8 +71,11 @@ {% endif %} - {% for reg_type in page.types %} - + {% for type, label in page.line_types %} + + {% endfor %} + {% for label in page.region_types %} + {% endfor %} From 4eef3c57b864cb2ab83c33760a7db6c6cc7b7d19 Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Wed, 19 Jul 2023 01:20:05 +0200 Subject: [PATCH 55/68] Add alternative reading orders to ALTO output --- kraken/serialization.py | 7 ++++++- kraken/templates/alto | 21 +++++++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/kraken/serialization.py b/kraken/serialization.py index 79a25084b..ae6679762 100644 --- a/kraken/serialization.py +++ b/kraken/serialization.py @@ -114,7 +114,6 @@ def serialize(results: Segmentation, 'scripts': scripts, 'date': datetime.datetime.now(datetime.timezone.utc).isoformat(), 'base_dir': [rec.base_dir for rec in results.lines][0] if len(results.lines) else None, - 'line_orders': results.line_orders, 'seg_type': results.type} # type: dict metadata = {'processing_steps': processing_steps, 'version': get_distribution('kraken').version} @@ -130,6 +129,12 @@ def serialize(results: Segmentation, page['line_types'] = list(set(types)) page['region_types'] =[list(results.regions.keys())] + # map reading orders indices to line IDs + ros = [] + for ro in results.line_orders: + ros.append([results.lines[idx].id for idx in ro]) + page['line_orders'] = ros + # build region ID to region dict reg_dict = {} for key, regs in results.regions.items(): diff --git a/kraken/templates/alto b/kraken/templates/alto index 3c5f34e99..ccddf4182 100644 --- a/kraken/templates/alto +++ b/kraken/templates/alto @@ -78,6 +78,27 @@ {% endfor %} + {% if len(page.line_orders) > 0 %} + + {% if len(page.line_orders) == 1 %} + + {% for id in page.line_orders[0] %} + + {% endfor %} + + {% else %} + + {% for ro in page.line_orders %} + + {% for id in ro %} + + {% endfor %} + + {% endfor %} + + {% endif %} + + {% endif %} From c08edcb76effefb4d96002e5ada6092321bc7b62 Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Wed, 19 Jul 2023 12:31:55 +0200 Subject: [PATCH 56/68] docstrings --- kraken/align.py | 9 ++-- kraken/blla.py | 7 +-- kraken/containers.py | 120 ++++++++++++++++++++++++++++++++++++++++--- 3 files changed, 123 insertions(+), 13 deletions(-) diff --git a/kraken/align.py b/kraken/align.py index 2c20cc790..ff76a1d43 100644 --- a/kraken/align.py +++ b/kraken/align.py @@ -32,7 +32,8 @@ from dataclasses import dataclass from typing import List, Dict, Any, Optional, Literal -from kraken import rpred, containers +from kraken import rpred +from kraken.containers import Segmentation, BaselineOCRRecord from kraken.lib.codec import PytorchCodec from kraken.lib.xml import XMLPage from kraken.lib.models import TorchSeqRecognizer @@ -42,7 +43,7 @@ logger = logging.getLogger('kraken') -def forced_align(doc: Segmentation, model: TorchSeqRecognizer, base_dir: Optional[Literal['L', 'R']] = None) -> containers.Segmentation: +def forced_align(doc: Segmentation, model: TorchSeqRecognizer, base_dir: Optional[Literal['L', 'R']] = None) -> Segmentation: """ Performs a forced character alignment of text with recognition model output activations. @@ -71,7 +72,7 @@ def forced_align(doc: Segmentation, model: TorchSeqRecognizer, base_dir: Optiona if model.outputs.shape[2] < 2*len(labels): logger.warning(f'Could not align line {idx}. Output sequence length {model.outputs.shape[2]} < ' f'{2*len(labels)} (length of "{line.text}" after encoding).') - records.append(containers.BaselineOCRRecord('', [], [], line)) + records.append(BaselineOCRRecord('', [], [], line)) continue emission = torch.tensor(model.outputs).squeeze().T trellis = get_trellis(emission, labels) @@ -85,7 +86,7 @@ def forced_align(doc: Segmentation, model: TorchSeqRecognizer, base_dir: Optiona pos.append((predictor._scale_val(seg.start, 0, predictor.box.size[0]), predictor._scale_val(seg.end, 0, predictor.box.size[0]))) conf.append(seg.score) - records.append(containers.BaselineOCRRecord(pred, pos, conf, line, display_order=True)) + records.append(BaselineOCRRecord(pred, pos, conf, line, display_order=True)) return dataclasses.replace(doc, lines=records) diff --git a/kraken/blla.py b/kraken/blla.py index 6d65de794..050a2abb7 100644 --- a/kraken/blla.py +++ b/kraken/blla.py @@ -276,9 +276,10 @@ def segment(im: PIL.Image.Image, autocast: Runs the model with automatic mixed precision Returns: - A :class:`kraken.lib.blla.Segmentation` class containing reading order - sorted baselines (polylines) and their respective polygonal boundaries. - The format of the line and region records is shown below. The last and + A :class:`kraken.containers.Segmentation` class containing reading order + sorted baselines (polylines) and their respective polygonal boundaries + as :class:`kraken.containers.BaselineLine` records. + The format of the line and region records is shown below. The last and first point of each boundary polygon are connected. .. code-block:: diff --git a/kraken/containers.py b/kraken/containers.py index 29a75c522..892d00109 100644 --- a/kraken/containers.py +++ b/kraken/containers.py @@ -23,6 +23,12 @@ class ProcessingStep: """ A processing step in the recognition pipeline. + + Attributes: + id: Unique identifier + category: Category of processing step that has been performat. + description: Natural-language description of the process. + settings: Dict describing the parameters of the processing step. """ id: str category: Literal['preprocessing', 'processing', 'postprocessing'] @@ -33,6 +39,26 @@ class ProcessingStep: @dataclass class BaselineLine: """ + Baseline-type line record. + + A container class for a single line in baseline + bounding polygon format, + optionally containing a transcription, tags, or associated regions. + + Attributes: + id: Unique identifier + baseline: List of tuples `(x_n, y_n)` defining the baseline. + boundary: List of tuples `(x_n, y_n)` defining the bounding polygon of + the line. The first and last points should be identical. + text: Transcription of this line. + base_dir: An optional string defining the base direction (also called + paragraph direction) for the BiDi algorithm. Valid values are + 'L' or 'R'. If None is given the default auto-resolution will + be used. + image: Image object of this line. + tags: A dict mapping types to values. + split: Defines whether this line is in the `train`, `validation`, or + `test` set during training. + regions: A list of identifiers of regions the line is associated with. """ id: str baseline: List[Tuple[int, int]] @@ -49,6 +75,27 @@ class BaselineLine: @dataclass class BBoxLine: """ + Bounding box-type line record. + + A container class for a single line in axis-aligned bounding box format, + optionally containing a transcription, tags, or associated regions. + + Attributes: + id: Unique identifier + bbox: Tuple in form `((x0, y0), (x1, y0), (x1, y1), (x0, y1))` defining + the bounding box. + text: Transcription of this line. + base_dir: An optional string defining the base direction (also called + paragraph direction) for the BiDi algorithm. Valid values are + 'L' or 'R'. If None is given the default auto-resolution will + be used. + image: Image object of this line. + tags: A dict mapping types to values. + split: Defines whether this line is in the `train`, `validation`, or + `test` set during training. + regions: A list of identifiers of regions the line is associated with. + text_direction: Sets the principal orientation (of the line) and + reading direction (of the document). """ id: str bbox: Tuple[Tuple[int, int], @@ -68,7 +115,14 @@ class BBoxLine: @dataclass class Region: """ + Container class of a single polygonal region. + Attributes: + id: Unique identifier + boundary: List of tuples `(x_n, y_n)` defining the bounding polygon of + the region. The first and last points should be identical. + image: Image object containing the region. + tags: A dict mapping types to values. """ id: str boundary: List[Tuple[int, int]] @@ -84,6 +138,22 @@ class Segmentation: In order to allow easy JSON de-/serialization, nested classes for lines (BaselineLine/BBoxLine) and regions (Region) are reinstantiated from their dictionaries. + + Attributes: + type: Field indicating if baselines + (:class:`kraken.containers.BaselineLine`) or bbox + (:class:`kraken.containers.BBoxLine`) line records are in the + segmentation. + imagename: Path to the image associated with the segmentation. + text_direction: Sets the principal orientation (of the line), i.e. + horizontal/vertical, and reading direction (of the + document), i.e. lr/rl. + script_detection: Flag indicating if the line records have tags. + lines: List of line records. Records are expected to be in a valid + reading order. + regions: Dict mapping types to lists of regions. + line_orders: List of alternative reading orders for the segmentation. + Each reading order is a list of line indices. """ type: Literal['baselines', 'bbox'] imagename: Union[str, PathLike] @@ -185,6 +255,16 @@ class BaselineOCRRecord(ocr_record, BaselineLine): as a list of tuples [(x0, y0), (x1, y2), ...]. confidences: A list of floats indicating the confidence value of each code point. + base_dir: An optional string defining the base direction (also called + paragraph direction) for the BiDi algorithm. Valid values are + 'L' or 'R'. If None is given the default auto-resolution will + be used. + display_order: Flag indicating the order of the code points in the + prediction. In display order (`True`) the n-th code + point in the string corresponds to the n-th leftmost + code point, in logical order (`False`) the n-th code + point corresponds to the n-th read code point. See [UAX + #9](https://unicode.org/reports/tr9) for more details. Notes: When slicing the record the behavior of the cuts is changed from @@ -342,6 +422,34 @@ class BBoxOCRRecord(ocr_record, BBoxLine): """ A record object containing the recognition result of a single line in bbox format. + + Attributes: + type: 'bbox' to indicate a bounding box record + prediction: The text predicted by the network as one continuous string. + cuts: The absolute bounding polygons for each code point in prediction + as a list of 4-tuples `((x0, y0), (x1, y0), (x1, y1), (x0, y1))`. + confidences: A list of floats indicating the confidence value of each + code point. + base_dir: An optional string defining the base direction (also called + paragraph direction) for the BiDi algorithm. Valid values are + 'L' or 'R'. If None is given the default auto-resolution will + be used. + display_order: Flag indicating the order of the code points in the + prediction. In display order (`True`) the n-th code + point in the string corresponds to the n-th leftmost + code point, in logical order (`False`) the n-th code + point corresponds to the n-th read code point. See [UAX + #9](https://unicode.org/reports/tr9) for more details. + + Notes: + When slicing the record the behavior of the cuts is changed from + earlier versions of kraken. Instead of returning per-character bounding + polygons a single polygons section of the line bounding polygon + starting at the first and extending to the last code point emitted by + the network is returned. This aids numerical stability when computing + aggregated bounding polygons such as for words. Individual code point + bounding polygons are still accessible through the `cuts` attribute or + by iterating over the record code point by code point. """ type = 'bbox' @@ -468,13 +576,13 @@ def _reorder(self, base_dir: Optional[Literal['L', 'R']] = None) -> 'BBoxOCRReco image=self.image, tags=self.tags, split=self.split, - region=self.region) + regions=self.regions) rec = BBoxOCRRecord(prediction=prediction, - cuts=cuts, - confidences=confidences, - line=line, - base_dir=base_dir, - display_order=not self._display_order) + cuts=cuts, + confidences=confidences, + line=line, + base_dir=base_dir, + display_order=not self._display_order) return rec From 6dcc004ed4551fece62c40f43916696f07efca19 Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Wed, 19 Jul 2023 13:01:58 +0200 Subject: [PATCH 57/68] more docstrings --- kraken/linegen.py | 4 +--- kraken/pageseg.py | 14 ++++++-------- kraken/rpred.py | 3 ++- 3 files changed, 9 insertions(+), 12 deletions(-) diff --git a/kraken/linegen.py b/kraken/linegen.py index d53daf117..6ea15aa8d 100644 --- a/kraken/linegen.py +++ b/kraken/linegen.py @@ -38,8 +38,6 @@ from scipy.ndimage.interpolation import affine_transform, geometric_transform from PIL import Image, ImageOps -from typing import AnyStr - import logging import ctypes import ctypes.util @@ -104,7 +102,7 @@ class ensureBytes(object): bytes. """ @classmethod - def from_param(cls, value: AnyStr) -> bytes: + def from_param(cls, value: str) -> bytes: if isinstance(value, bytes): return value else: diff --git a/kraken/pageseg.py b/kraken/pageseg.py index 88c6a793c..b98b175b7 100644 --- a/kraken/pageseg.py +++ b/kraken/pageseg.py @@ -19,6 +19,7 @@ Layout analysis methods. """ +import PIL import uuid import logging import numpy as np @@ -303,7 +304,7 @@ def rotate_lines(lines: np.ndarray, angle: float, offset: int) -> np.ndarray: return np.column_stack((x.flatten(), y.flatten())).reshape(-1, 4) -def segment(im, +def segment(im: PIL.Image.Image, text_direction: str = 'horizontal-lr', scale: Optional[float] = None, maxcolseps: float = 2, @@ -311,7 +312,7 @@ def segment(im, no_hlines: bool = True, pad: Union[int, Tuple[int, int]] = 0, mask: Optional[np.ndarray] = None, - reading_order_fn: Callable = reading_order) -> Dict[str, Any]: + reading_order_fn: Callable = reading_order) -> Segmentation: """ Segments a page into text lines. @@ -338,12 +339,9 @@ def segment(im, direction in (`rl`, `lr`). Returns: - A dictionary containing the text direction and a list of reading order - sorted bounding boxes under the key 'boxes': - - .. code-block:: - - {'text_direction': '$dir', 'boxes': [(x1, y1, x2, y2),...]} + A :class:`kraken.containers.Segmentation` class containing reading + order sorted bounding box-type lines as + :class:`kraken.containers.BBoxLine` records. Raises: KrakenInputException: if the input image is not binarized or the text diff --git a/kraken/rpred.py b/kraken/rpred.py index 960f4a370..823cad5ea 100644 --- a/kraken/rpred.py +++ b/kraken/rpred.py @@ -304,7 +304,8 @@ def rpred(network: TorchSeqRecognizer, Args: network: A TorchSegRecognizer object im: Image to extract text from - bounds: A Segmentation class instance containing either a baseline or bbox segmentation. + bounds: A Segmentation class instance containing either a baseline or + bbox segmentation. pad: Extra blank padding to the left and right of text line. Auto-disabled when expected network inputs are incompatible with padding. From dd1b9ed7637ef79b272649227d135d4d1d8f9d0a Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Thu, 20 Jul 2023 15:09:14 +0200 Subject: [PATCH 58/68] better default output name in ketos compile --- kraken/ketos/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kraken/ketos/dataset.py b/kraken/ketos/dataset.py index 8d83938c3..3996a2a52 100644 --- a/kraken/ketos/dataset.py +++ b/kraken/ketos/dataset.py @@ -23,7 +23,7 @@ @click.command('compile') @click.pass_context -@click.option('-o', '--output', show_default=True, type=click.Path(), default='model', help='Output model file') +@click.option('-o', '--output', show_default=True, type=click.Path(), default='dataset.arrow', help='Output dataset file') @click.option('--workers', show_default=True, default=1, help='Number of parallel workers for text line extraction.') @click.option('-f', '--format-type', type=click.Choice(['path', 'xml', 'alto', 'page']), default='xml', show_default=True, help='Sets the training data format. In ALTO and PageXML mode all ' From 10c83d33804063852eb745ed5b320bed2d111282 Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Thu, 20 Jul 2023 15:09:33 +0200 Subject: [PATCH 59/68] fix import in arrow_dataset --- kraken/lib/arrow_dataset.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/kraken/lib/arrow_dataset.py b/kraken/lib/arrow_dataset.py index 694fd4279..bac1c1a0a 100755 --- a/kraken/lib/arrow_dataset.py +++ b/kraken/lib/arrow_dataset.py @@ -28,8 +28,9 @@ from collections import Counter from typing import Optional, List, Union, Callable, Tuple, Dict from multiprocessing import Pool +from kraken.containers import Segmentation from kraken.lib import functional_im_transforms as F_t -from kraken.lib.segmentation import extract_polygons, Segmentation +from kraken.lib.segmentation import extract_polygons from kraken.lib.xml import XMLPage from kraken.lib.util import is_bitonal, make_printable from kraken.lib.exceptions import KrakenInputException From e72f5ae7c7124458ac83cc96945e6cc96bddc588 Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Fri, 21 Jul 2023 15:29:26 +0200 Subject: [PATCH 60/68] arrow dataset builder test skeleton --- tests/test_arrow_dataset.py | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 tests/test_arrow_dataset.py diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py new file mode 100644 index 000000000..fff086c12 --- /dev/null +++ b/tests/test_arrow_dataset.py @@ -0,0 +1,36 @@ +# -*- coding: utf-8 -*- + +import unittest +import json + +import kraken + +from pytest import raises +from pathlib import Path + +from kraken.lib import xml +from kraken.lib.arrow_dataset import build_binary_dataset + +thisfile = Path(__file__).resolve().parent +resources = thisfile / 'resources' + +class TestKrakenArrowCompilation(unittest.TestCase): + """ + Tests for binary datasets + """ + def setUp(self): + self.xml = resources / '170025120000003,0074.xml' + self.bls = xml.XMLPage(self.xml) + self.box_lines = [resources / '000236.png'] + + def test_build_path_dataset(self): + pass + + def test_build_xml_dataset(self): + pass + + def test_build_obj_dataset(self): + pass + + def test_build_empty_dataset(self): + pass From 260d04766e59e446b548beac510cb4c33ef6ea5a Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Mon, 24 Jul 2023 12:55:16 +0200 Subject: [PATCH 61/68] Small fixes to RO dataset class --- kraken/lib/dataset/ro.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/kraken/lib/dataset/ro.py b/kraken/lib/dataset/ro.py index e6c6bf678..3edb82b88 100644 --- a/kraken/lib/dataset/ro.py +++ b/kraken/lib/dataset/ro.py @@ -108,12 +108,12 @@ def __init__(self, files: Sequence[Union[PathLike, str]] = None, w, h = Image.open(doc.imagename).size sorted_lines = [] for line in order: - line_coords = np.array(line['baseline']) / (w, h) + line_coords = np.array(line.baseline) / (w, h) line_center = np.mean(line_coords, axis=0) cl = torch.zeros(self.num_classes, dtype=torch.float) # if class is not in class mapping default to None class (idx 0) - cl[self.class_mapping.get(line['tags']['type'], 0)] = 1 - line_data = {'type': line['tags']['type'], + cl[self.class_mapping.get(line.tags['type'], 0)] = 1 + line_data = {'type': line.tags['type'], 'features': torch.cat((cl, # one hot encoded line type torch.tensor(line_center, dtype=torch.float), # line center torch.tensor(line_coords[0, :], dtype=torch.float), # start_point coord @@ -121,9 +121,11 @@ def __init__(self, files: Sequence[Union[PathLike, str]] = None, )) } sorted_lines.append(line_data) - self.data.append(sorted_lines) - self._num_pairs += int(factorial(len(sorted_lines))/factorial(len(sorted_lines)-2)) - + if len(sorted_lines) > 1: + self.data.append(sorted_lines) + self._num_pairs += int(factorial(len(sorted_lines))/factorial(len(sorted_lines)-2)) + else: + logger.info(f'Page {doc} has less than 2 lines. Skipping') except KrakenInputException as e: logger.warning(e) continue @@ -212,12 +214,12 @@ def __init__(self, files: Sequence[Union[PathLike, str]] = None, w, h = Image.open(doc.imagename).size sorted_lines = [] for line in order: - line_coords = np.array(line['baseline']) / (w, h) + line_coords = np.array(line.baseline) / (w, h) line_center = np.mean(line_coords, axis=0) cl = torch.zeros(self.num_classes, dtype=torch.float) # if class is not in class mapping default to None class (idx 0) - cl[self.class_mapping.get(line['tags']['type'], 0)] = 1 - line_data = {'type': line['tags']['type'], + cl[self.class_mapping.get(line.tags['type'], 0)] = 1 + line_data = {'type': line.tags['type'], 'features': torch.cat((cl, # one hot encoded line type torch.tensor(line_center, dtype=torch.float), # line center torch.tensor(line_coords[0, :], dtype=torch.float), # start_point coord @@ -225,7 +227,10 @@ def __init__(self, files: Sequence[Union[PathLike, str]] = None, )) } sorted_lines.append(line_data) - self.data.append(sorted_lines) + if len(sorted_lines) > 1: + self.data.append(sorted_lines) + else: + logger.info(f'Page {doc} has less than 2 lines. Skipping') except KrakenInputException as e: logger.warning(e) continue From 7118f1eea34356a0be01bb654c04798774d5552e Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Mon, 24 Jul 2023 13:09:07 +0200 Subject: [PATCH 62/68] forced alignment contrib script with container classes --- kraken/contrib/forced_alignment_overlay.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/kraken/contrib/forced_alignment_overlay.py b/kraken/contrib/forced_alignment_overlay.py index 5e9f19eff..1dd700c1b 100755 --- a/kraken/contrib/forced_alignment_overlay.py +++ b/kraken/contrib/forced_alignment_overlay.py @@ -35,7 +35,7 @@ def _repl_alto(fname, cuts): doc = etree.parse(fp) lines = doc.findall('.//{*}TextLine') char_idx = 0 - for line, line_cuts in zip(lines, cuts): + for line, line_cuts in zip(lines, cuts.lines): idx = 0 for el in line: if el.tag.endswith('Shape'): @@ -65,7 +65,7 @@ def _repl_page(fname, cuts): with open(fname, 'rb') as fp: doc = etree.parse(fp) lines = doc.findall('.//{*}TextLine') - for line, line_cuts in zip(lines, cuts): + for line, line_cuts in zip(lines, cuts.lines): glyphs = line.findall('../{*}Glyph/{*}Coords') for glyph, cut in zip(glyphs, line_cuts): glyph.attrib['points'] = ' '.join([','.join([str(x) for x in pt]) for pt in cut]) @@ -96,34 +96,33 @@ def cli(format_type, model, output, files): from PIL import Image, ImageDraw - from kraken.lib import models, xml + from kraken.lib.xml import XMLPage + from kraken.lib import models from kraken import align if format_type == 'alto': - fn = xml.parse_alto repl_fn = _repl_alto else: - fn = xml.parse_page repl_fn = _repl_page click.echo(f'Loading model {model}') net = models.load_any(model) for doc in files: click.echo(f'Processing {doc} ', nl=False) - data = fn(doc) - im = Image.open(data['image']).convert('RGBA') - records = align.forced_align(data, net) + data = XMLPage(doc) + im = Image.open(data.imagename).convert('RGBA') + result = align.forced_align(data.to_container, net) if output == 'overlay': tmp = Image.new('RGBA', im.size, (0, 0, 0, 0)) draw = ImageDraw.Draw(tmp) - for record in records: + for record in result.lines: for pol in record.cuts: c = next(cmap) draw.polygon([tuple(x) for x in pol], fill=c, outline=c[:3]) base_image = Image.alpha_composite(im, tmp) base_image.save(f'high_{os.path.basename(doc)}_algn.png') else: - repl_fn(doc, records) + repl_fn(doc, result) click.secho('\u2713', fg='green') From da6c4f6f5c5b49d82c991bed9fb63cdfe5b236ac Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Mon, 24 Jul 2023 13:13:15 +0200 Subject: [PATCH 63/68] extract_lines.py with container classes --- kraken/contrib/extract_lines.py | 27 ++++++++------------------- 1 file changed, 8 insertions(+), 19 deletions(-) diff --git a/kraken/contrib/extract_lines.py b/kraken/contrib/extract_lines.py index b00db3a43..e2a41b8ba 100755 --- a/kraken/contrib/extract_lines.py +++ b/kraken/contrib/extract_lines.py @@ -1,8 +1,6 @@ #! /usr/bin/env python - import click - @click.command() @click.option('-f', '--format-type', type=click.Choice(['xml', 'alto', 'page', 'binary']), default='xml', help='Sets the input document format. In ALTO and PageXML mode all ' @@ -10,17 +8,8 @@ 'link to source images.') @click.option('-i', '--model', default=None, show_default=True, type=click.Path(exists=True), help='Baseline detection model to use. Overrides format type and expects image files as input.') -@click.option('--repolygonize/--no-repolygonize', show_default=True, - default=False, help='Repolygonizes line data in ALTO/PageXML ' - 'files. This ensures that the trained model is compatible with the ' - 'segmenter in kraken even if the original image files either do ' - 'not contain anything but transcriptions and baseline information ' - 'or the polygon data was created using a different method. Will ' - 'be ignored in `path` mode. Note, that this option will be slow ' - 'and will not scale input images to the same size as the segmenter ' - 'does.') @click.argument('files', nargs=-1) -def cli(format_type, model, repolygonize, files): +def cli(format_type, model, files): """ A small script extracting rectified line polygons as defined in either ALTO or PageXML files or run a model to do the same. @@ -42,14 +31,14 @@ def cli(format_type, model, repolygonize, files): for doc in files: click.echo(f'Processing {doc} ', nl=False) if format_type != 'binary': - data = xml.preparse_xml_data([doc], format_type, repolygonize=repolygonize) - if len(data) > 0: - bounds = {'type': 'baselines', 'lines': [{'boundary': t['boundary'], 'baseline': t['baseline'], 'text': t['text']} for t in data]} - for idx, (im, box) in enumerate(segmentation.extract_polygons(Image.open(data[0]['image']), bounds)): + data = xml.XMLPage(doc, format_type) + if len(data.lines) > 0: + bounds = data.to_container() + for idx, (im, box) in enumerate(segmentation.extract_polygons(Image.open(bounds.imagename), bounds)): click.echo('.', nl=False) - im.save('{}.{}.jpg'.format(splitext(data[0]['image'])[0], idx)) - with open('{}.{}.gt.txt'.format(splitext(data[0]['image'])[0], idx), 'w') as fp: - fp.write(box['text']) + im.save('{}.{}.jpg'.format(splitext(bounds.imagename)[0], idx)) + with open('{}.{}.gt.txt'.format(splitext(bounds.imagename)[0], idx), 'w') as fp: + fp.write(box.text) else: with pa.memory_map(doc, 'rb') as source: ds_table = pa.ipc.open_file(source).read_all() From 7cf2dc404be2b8bebf915b7bab30ac193000732b Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Mon, 24 Jul 2023 13:53:04 +0200 Subject: [PATCH 64/68] More contrib scripts with containers --- kraken/contrib/heatmap_overlay.py | 1 - kraken/contrib/repolygonize.py | 10 ++--- kraken/contrib/segmentation_overlay.py | 60 +++++++++++--------------- 3 files changed, 28 insertions(+), 43 deletions(-) diff --git a/kraken/contrib/heatmap_overlay.py b/kraken/contrib/heatmap_overlay.py index 8e2c7236b..0b1c22197 100755 --- a/kraken/contrib/heatmap_overlay.py +++ b/kraken/contrib/heatmap_overlay.py @@ -4,7 +4,6 @@ """ import click - @click.command() @click.option('-i', '--model', default=None, show_default=True, type=click.Path(exists=True), help='Baseline detection model to use.') diff --git a/kraken/contrib/repolygonize.py b/kraken/contrib/repolygonize.py index db40b8987..d3f498c57 100755 --- a/kraken/contrib/repolygonize.py +++ b/kraken/contrib/repolygonize.py @@ -85,10 +85,8 @@ def _repl_page(fname, polygons): doc.write(fp, encoding='UTF-8', xml_declaration=True) if format_type == 'page': - parse_fn = xml.parse_page repl_fn = _repl_page else: - parse_fn = xml.parse_alto repl_fn = _repl_alto topline = {'topline': True, @@ -97,11 +95,11 @@ def _repl_page(fname, polygons): for doc in files: click.echo(f'Processing {doc} ') - seg = parse_fn(doc) - im = Image.open(seg['image']).convert('L') + seg = xml.XMLPage(doc).to_container() + im = Image.open(seg.imagename).convert('L') baselines = [] - for x in seg['lines']: - bl = x['baseline'] if x['baseline'] is not None else [0, 0] + for x in seg.lines: + bl = x.baseline if x.baseline is not None else [0, 0] baselines.append(bl) o = calculate_polygonal_environment(im, baselines, scale=(1800, 0), topline=topline) repl_fn(doc, o) diff --git a/kraken/contrib/segmentation_overlay.py b/kraken/contrib/segmentation_overlay.py index 229862fae..a9ff235ff 100755 --- a/kraken/contrib/segmentation_overlay.py +++ b/kraken/contrib/segmentation_overlay.py @@ -7,6 +7,7 @@ import os import click import unicodedata +import dataclasses from itertools import cycle from collections import defaultdict @@ -27,10 +28,6 @@ def slugify(value): return value @click.command() -@click.option('-f', '--format-type', type=click.Choice(['xml', 'alto', 'page']), default='xml', - help='Sets the input document format. In ALTO and PageXML mode all ' - 'data is extracted from xml files containing both baselines, polygons, and a ' - 'link to source images.') @click.option('-i', '--model', default=None, show_default=True, type=click.Path(exists=True), help='Baseline detection model to use. Overrides format type and expects image files as input.') @click.option('-d', '--text-direction', default='horizontal-lr', @@ -48,7 +45,7 @@ def slugify(value): 'and will not scale input images to the same size as the segmenter ' 'does.') @click.argument('files', nargs=-1) -def cli(format_type, model, text_direction, repolygonize, files): +def cli(model, text_direction, repolygonize, files): """ A script producing overlays of lines and regions from either ALTO or PageXML files or run a model to do the same. @@ -64,47 +61,38 @@ def cli(format_type, model, text_direction, repolygonize, files): from kraken import blla if model is None: - if format_type == 'xml': - fn = xml.parse_xml - elif format_type == 'alto': - fn = xml.parse_alto - else: - fn = xml.parse_page for doc in files: click.echo(f'Processing {doc} ', nl=False) - data = fn(doc) + data = xml.XMLPage(doc) if repolygonize: - im = Image.open(data['image']).convert('L') - lines = data['lines'] - polygons = segmentation.calculate_polygonal_environment(im, [x['baseline'] for x in lines], scale=(1200, 0)) - data['lines'] = [{'boundary': polygon, - 'baseline': orig['baseline'], - 'text': orig['text'], - 'tags': orig['tags']} for orig, polygon in zip(lines, polygons)] + im = Image.open(data.imagename).convert('L') + lines = data.lines + polygons = segmentation.calculate_polygonal_environment(im, [x.baseline for x in lines], scale=(1200, 0)) + data.lines = [dataclasses.replace(orig, boundary=polygon) for orig, polygon in zip(lines, polygons)] # reorder lines by type lines = defaultdict(list) - for line in data['lines']: - lines[line['tags']['type']].append(line) - im = Image.open(data['image']).convert('RGBA') + for line in data.lines: + lines[line.tags['type']].append(line) + im = Image.open(data.imagename).convert('RGBA') for t, ls in lines.items(): tmp = Image.new('RGBA', im.size, (0, 0, 0, 0)) draw = ImageDraw.Draw(tmp) for idx, line in enumerate(ls): c = next(cmap) - if line['boundary']: - draw.polygon([tuple(x) for x in line['boundary']], fill=c, outline=c[:3]) - if line['baseline']: - draw.line([tuple(x) for x in line['baseline']], fill=bmap, width=2, joint='curve') - draw.text(line['baseline'][0], str(idx), fill=(0, 0, 0, 255)) + if line.boundary: + draw.polygon([tuple(x) for x in line.boundary], fill=c, outline=c[:3]) + if line.baseline: + draw.line([tuple(x) for x in line.baseline], fill=bmap, width=2, joint='curve') + draw.text(line.baseline[0], str(idx), fill=(0, 0, 0, 255)) base_image = Image.alpha_composite(im, tmp) base_image.save(f'high_{os.path.basename(doc)}_lines_{slugify(t)}.png') - for t, regs in data['regions'].items(): + for t, regs in data.regions.items(): tmp = Image.new('RGBA', im.size, (0, 0, 0, 0)) draw = ImageDraw.Draw(tmp) for reg in regs: c = next(cmap) try: - draw.polygon(reg, fill=c, outline=c[:3]) + draw.polygon(reg.boundary, fill=c, outline=c[:3]) except Exception: pass base_image = Image.alpha_composite(im, tmp) @@ -118,26 +106,26 @@ def cli(format_type, model, text_direction, repolygonize, files): res = blla.segment(im, model=net, text_direction=text_direction) # reorder lines by type lines = defaultdict(list) - for line in res['lines']: - lines[line['tags']['type']].append(line) + for line in res.lines: + lines[line.tags['type']].append(line) im = im.convert('RGBA') for t, ls in lines.items(): tmp = Image.new('RGBA', im.size, (0, 0, 0, 0)) draw = ImageDraw.Draw(tmp) for idx, line in enumerate(ls): c = next(cmap) - draw.polygon([tuple(x) for x in line['boundary']], fill=c, outline=c[:3]) - draw.line([tuple(x) for x in line['baseline']], fill=bmap, width=2, joint='curve') - draw.text(line['baseline'][0], str(idx), fill=(0, 0, 0, 255)) + draw.polygon([tuple(x) for x in line.boundary], fill=c, outline=c[:3]) + draw.line([tuple(x) for x in line.baseline], fill=bmap, width=2, joint='curve') + draw.text(line.baseline[0], str(idx), fill=(0, 0, 0, 255)) base_image = Image.alpha_composite(im, tmp) base_image.save(f'high_{os.path.basename(doc)}_lines_{slugify(t)}.png') - for t, regs in res['regions'].items(): + for t, regs in res.regions.items(): tmp = Image.new('RGBA', im.size, (0, 0, 0, 0)) draw = ImageDraw.Draw(tmp) for reg in regs: c = next(cmap) try: - draw.polygon([tuple(x) for x in reg], fill=c, outline=c[:3]) + draw.polygon([tuple(x) for x in reg.boundary], fill=c, outline=c[:3]) except Exception: pass From 46d6c063b8a4bfa7a667eeedb4aec03d284a6fb3 Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Wed, 26 Jul 2023 11:26:38 +0200 Subject: [PATCH 65/68] Switch recognition datasets to container classes --- kraken/containers.py | 12 +-- kraken/lib/dataset/recognition.py | 159 +++++++++++++++++++----------- 2 files changed, 108 insertions(+), 63 deletions(-) diff --git a/kraken/containers.py b/kraken/containers.py index 892d00109..c58ef9c10 100644 --- a/kraken/containers.py +++ b/kraken/containers.py @@ -54,7 +54,7 @@ class BaselineLine: paragraph direction) for the BiDi algorithm. Valid values are 'L' or 'R'. If None is given the default auto-resolution will be used. - image: Image object of this line. + imagename: Path to the image associated with the line. tags: A dict mapping types to values. split: Defines whether this line is in the `train`, `validation`, or `test` set during training. @@ -66,7 +66,7 @@ class BaselineLine: text: Optional[str] = None base_dir: Optional[Literal['L', 'R']] = None type: str = 'baselines' - image: Optional[PIL.Image.Image] = None + imagename: Optional[Union[str, PathLike]] = None tags: Optional[Dict[str, str]] = None split: Optional[Literal['train', 'validation', 'test']] = None regions: Optional[List[str]] = None @@ -89,7 +89,7 @@ class BBoxLine: paragraph direction) for the BiDi algorithm. Valid values are 'L' or 'R'. If None is given the default auto-resolution will be used. - image: Image object of this line. + imagename: Path to the image associated with the line.. tags: A dict mapping types to values. split: Defines whether this line is in the `train`, `validation`, or `test` set during training. @@ -105,7 +105,7 @@ class BBoxLine: text: Optional[str] = None base_dir: Optional[Literal['L', 'R']] = None type: str = 'bbox' - image: Optional[PIL.Image.Image] = None + imagename: Optional[Union[str, PathLike]] = None tags: Optional[Dict[str, str]] = None split: Optional[Literal['train', 'validation', 'test']] = None regions: Optional[List[str]] = None @@ -121,12 +121,12 @@ class Region: id: Unique identifier boundary: List of tuples `(x_n, y_n)` defining the bounding polygon of the region. The first and last points should be identical. - image: Image object containing the region. + imagename: Path to the image associated with the region. tags: A dict mapping types to values. """ id: str boundary: List[Tuple[int, int]] - image: Optional[PIL.Image.Image] = None + imagename: Optional[Union[str, PathLike]] = None tags: Optional[Dict[str, str]] = None diff --git a/kraken/lib/dataset/recognition.py b/kraken/lib/dataset/recognition.py index d381b9d97..95bdef25e 100644 --- a/kraken/lib/dataset/recognition.py +++ b/kraken/lib/dataset/recognition.py @@ -56,7 +56,7 @@ def __init__(self): ShiftScaleRotate, OpticalDistortion, ElasticTransform, PixelDropout ) - + self._transforms = Compose([ ToFloat(), PixelDropout(p=0.2), @@ -71,7 +71,7 @@ def __init__(self): ElasticTransform(alpha=64, sigma=25, alpha_affine=0.25, p=0.1), ], p=0.2), ], p=0.5) - + def __call__(self, image): return self._transforms(image=image) @@ -319,54 +319,67 @@ def __init__(self, self.im_mode = '1' - def add(self, *args, **kwargs): + def add(self, + line: Optional[BaselineLine] = None, + page: Optional[Segmentation] = None): """ - Adds a line to the dataset. + Adds an indiviual line or all lines on a page to the dataset. Args: - im (path): Path to the whole page image - text (str): Transcription of the line. - baseline (list): A list of coordinates [[x0, y0], ..., [xn, yn]]. - boundary (list): A polygon mask for the line. + line: BaselineLine container object of a line. + page: Segmentation container object for a page. """ - if 'preparse' not in kwargs or not kwargs['preparse']: - kwargs = self.parse(*args, **kwargs) - self._images.append((kwargs['image'], kwargs['baseline'], kwargs['boundary'])) - self._gt.append(kwargs['text']) - self.alphabet.update(kwargs['text']) + if line: + self.add_line(line) + if page: + self.add_page(page) + if not (line and page): + raise ValueError('Neither line nor page data provided in dataset builder') - def parse(self, - image: Union[PathLike, str, Image.Image], - text: str, - baseline: List[Tuple[int, int]], - boundary: List[Tuple[int, int]], - *args, - **kwargs): + def add_page(self, page: Segmentation): """ - Parses a sample for the dataset and returns it. + Adds all lines on a page to the dataset. - This function is mainly uses for parallelized loading of training data. + Invalid lines will be skipped and a warning will be printed. Args: - im (path): Path to the whole page image - text (str): Transcription of the line. - baseline (list): A list of coordinates [[x0, y0], ..., [xn, yn]]. - boundary (list): A polygon mask for the line. + page: Segmentation container object for a page. """ - orig_text = text + if page.type != 'baselines': + raise ValueError(f'Invalid segmentation of type {page.type} (expected "baselines")') + for line in page.lines: + try: + self.add_line(dataclasses.replace(line, imagename=page.imagename)) + except ValueError as e: + logger.warning(e) + + def add_line(self, line: BaselineLine) + """ + Adds a line to the dataset. + + Args: + line: BaselineLine container object for a line. + + Raises: + ValueError if the transcription of the line is empty after + transformation or either baseline or bounding polygon are missing. + """ + if line.type != 'baselines': + raise ValueError(f'Invalid line of type {line.type} (expected "baselines")') + + text = line.text for func in self.text_transforms: text = func(text) if not text and self.skip_empty_lines: - raise KrakenInputException(f'Text line "{orig_text}" is empty after transformations') - if not baseline: - raise KrakenInputException('No baseline given for line') - if not boundary: - raise KrakenInputException('No boundary given for line') - return {'text': text, - 'image': image, - 'baseline': baseline, - 'boundary': boundary, - 'preparse': True} + raise ValueError(f'Text line "{line.text}" is empty after transformations') + if not line.baseline: + raise ValueError('No baseline given for line') + if not line.boundary: + raise ValueError('No boundary given for line') + + self._images.append((line.image, line.baseline, line.boundary)) + self._gt.append(text) + self.alphabet.update(text) def encode(self, codec: Optional[PytorchCodec] = None) -> None: """ @@ -493,35 +506,67 @@ def __init__(self, split: Callable[[Union[PathLike, str]], str] = F_t.default_sp self.im_mode = '1' - def add(self, *args, **kwargs) -> None: + def add(self, + line: Optional[BBoxLine] = None, + page: Optional[Segmentation] = None): """ - Adds a line-image-text pair to the dataset. + Adds an indiviual line or all lines on a page to the dataset. Args: - image (str): Input image path + line: BBoxLine container object of a line. + page: Segmentation container object for a page. """ - if 'preparse' not in kwargs or not kwargs['preparse']: - kwargs = self.parse(*args, **kwargs) - self._images.append(kwargs['image']) - self._gt.append(kwargs['text']) - self.alphabet.update(kwargs['text']) + if line: + self.add_line(line) + if page: + self.add_page(page) + if not (line and page): + raise ValueError('Neither line nor page data provided in dataset builder') - def parse(self, image: Union[PathLike, str, Image.Image], *args, **kwargs) -> Dict: + def add_page(self, page: Segmentation): """ - Parses a sample for this dataset. + Adds all lines on a page to the dataset. - This is mostly used to parallelize populating the dataset. + Invalid lines will be skipped and a warning will be printed. Args: - image (str): Input image path - """ - with open(self.split(image), 'r', encoding='utf-8') as fp: - text = fp.read().strip('\n\r') - for func in self.text_transforms: - text = func(text) - if not text and self.skip_empty_lines: - raise KrakenInputException(f'Text line is empty ({fp.name})') - return {'image': image, 'text': text, 'preparse': True} + page: Segmentation container object for a page. + """ + if page.type != 'bbox': + raise ValueError(f'Invalid segmentation of type {page.type} (expected "bbox")') + for line in page.lines: + try: + self.add_line(dataclasses.replace(line, imagename=page.imagename)) + except ValueError as e: + logger.warning(e) + + def add_line(self, line: BBoxLine) + """ + Adds a line to the dataset. + + Args: + line: BBoxLine container object for a line. + + Raises: + ValueError if the transcription of the line is empty after + transformation or either baseline or bounding polygon are missing. + """ + if line.type != 'bbox': + raise ValueError(f'Invalid line of type {line.type} (expected "bbox")') + + text = line.text + for func in self.text_transforms: + text = func(text) + if not text and self.skip_empty_lines: + raise ValueError(f'Text line "{line.text}" is empty after transformations') + if not line.baseline: + raise ValueError('No baseline given for line') + if not line.boundary: + raise ValueError('No boundary given for line') + + self._images.append(line.image) + self._gt.append(text) + self.alphabet.update(text) def encode(self, codec: Optional[PytorchCodec] = None) -> None: """ From a5a8a20932ea2cc4a1f8777c281d7e193542d682 Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Wed, 30 Aug 2023 23:01:57 +0200 Subject: [PATCH 66/68] s/preparse_xml_data/XMLPage/g --- docs/api_docs.rst | 227 +++++++++++++++++------------ kraken/ketos/recognition.py | 4 +- kraken/lib/dataset/recognition.py | 14 +- kraken/lib/dataset/segmentation.py | 98 +++---------- kraken/lib/pretrain/model.py | 42 ++++-- kraken/lib/train.py | 48 +++--- 6 files changed, 207 insertions(+), 226 deletions(-) diff --git a/docs/api_docs.rst b/docs/api_docs.rst index 46379f2b8..cb85ff91f 100644 --- a/docs/api_docs.rst +++ b/docs/api_docs.rst @@ -2,8 +2,11 @@ API Reference ************* +Segmentation +============ + kraken.blla module -================== +------------------ .. note:: @@ -14,7 +17,7 @@ kraken.blla module .. autoapifunction:: kraken.blla.segment kraken.pageseg module -===================== +--------------------- .. note:: @@ -24,22 +27,22 @@ kraken.pageseg module .. autoapifunction:: kraken.pageseg.segment -kraken.rpred module -=================== +Recognition +=========== -.. autoapifunction:: kraken.rpred.bidi_record +kraken.rpred module +------------------- .. autoapiclass:: kraken.rpred.mm_rpred :members: -.. autoapiclass:: kraken.rpred.ocr_record - :members: - .. autoapifunction:: kraken.rpred.rpred +Serialization +============= kraken.serialization module -=========================== +--------------------------- .. autoapifunction:: kraken.serialization.render_report @@ -47,127 +50,118 @@ kraken.serialization module .. autoapifunction:: kraken.serialization.serialize_segmentation -kraken.lib.models module -======================== +Default templates +----------------- -.. autoapiclass:: kraken.lib.models.TorchSeqRecognizer - :members: +ALTO 4.4 +^^^^^^^^ -.. autoapifunction:: kraken.lib.models.load_any +.. literalinclude:: ../../templates/alto + :language: xml+jinja -kraken.lib.vgsl module -====================== +PageXML +^^^^^^^ -.. autoapiclass:: kraken.lib.vgsl.TorchVGSLModel - :members: +.. literalinclude:: ../../templates/alto + :language: xml+jinja -kraken.lib.xml module -===================== +hOCR +^^^^ -.. autoapifunction:: kraken.lib.xml.parse_xml +.. literalinclude:: ../../templates/alto + :language: xml+jinja -.. autoapifunction:: kraken.lib.xml.parse_page +ABBYY XML +^^^^^^^^^ -.. autoapifunction:: kraken.lib.xml.parse_alto +.. literalinclude:: ../../templates/abbyyxml + :language: xml+jinja + +Containers and Helpers +====================== kraken.lib.codec module -======================= +----------------------- .. autoapiclass:: kraken.lib.codec.PytorchCodec :members: -kraken.lib.train module -======================= +kraken.containers module +------------------------ -Training Schedulers -------------------- +.. autoapiclass:: kraken.containers.Segmentation + :members: -.. autoapiclass:: kraken.lib.train.TrainScheduler - :members: +.. autoapiclass:: kraken.containers.BaselineLine + :members: -.. autoapiclass:: kraken.lib.train.annealing_step - :members: +.. autoapiclass:: kraken.containers.BBoxLine + :members: -.. autoapiclass:: kraken.lib.train.annealing_const - :members: +.. autoapiclass:: kraken.containers.ocr_record + :members: -.. autoapiclass:: kraken.lib.train.annealing_exponential - :members: +.. autoapiclass:: kraken.containers.BaselineOCRRecord + :members: -.. autoapiclass:: kraken.lib.train.annealing_reduceonplateau - :members: +.. autoapiclass:: kraken.containers.BBoxOCRRecord + :members: -.. autoapiclass:: kraken.lib.train.annealing_cosine - :members: +.. autoapiclass:: kraken.containers.ProcessingStep + :members: -.. autoapiclass:: kraken.lib.train.annealing_onecycle - :members: +kraken.lib.ctc_decoder +---------------------- -Training Stoppers ------------------ +.. autoapifunction:: kraken.lib.ctc_decoder.beam_decoder -.. autoapiclass:: kraken.lib.train.TrainStopper - :members: +.. autoapifunction:: kraken.lib.ctc_decoder.greedy_decoder -.. autoapiclass:: kraken.lib.train.EarlyStopping - :members: +.. autoapifunction:: kraken.lib.ctc_decoder.blank_threshold_decoder -.. autoapiclass:: kraken.lib.train.EpochStopping - :members: +kraken.lib.exceptions +--------------------- -.. autoapiclass:: kraken.lib.train.NoStopping +.. autoapiclass:: kraken.lib.exceptions.KrakenCodecException :members: -Loss and Evaluation Functions ------------------------------ - -.. autoapifunction:: kraken.lib.train.recognition_loss_fn - -.. autoapifunction:: kraken.lib.train.baseline_label_loss_fn - -.. autoapifunction:: kraken.lib.train.recognition_evaluator_fn - -.. autoapifunction:: kraken.lib.train.baseline_label_evaluator_fn - -Trainer -------- - -.. autoapiclass:: kraken.lib.train.KrakenTrainer +.. autoapiclass:: kraken.lib.exceptions.KrakenStopTrainingException :members: +.. autoapiclass:: kraken.lib.exceptions.KrakenEncodeException + :members: -kraken.lib.dataset module -========================= - -Datasets --------- +.. autoapiclass:: kraken.lib.exceptions.KrakenRecordException + :members: -.. autoapiclass:: kraken.lib.dataset.BaselineSet +.. autoapiclass:: kraken.lib.exceptions.KrakenInvalidModelException :members: -.. autoapiclass:: kraken.lib.dataset.PolygonGTDataset +.. autoapiclass:: kraken.lib.exceptions.KrakenInputException :members: -.. autoapiclass:: kraken.lib.dataset.GroundTruthDataset +.. autoapiclass:: kraken.lib.exceptions.KrakenRepoException :members: -Helpers -------- +.. autoapiclass:: kraken.lib.exceptions.KrakenCairoSurfaceException + :members: -.. autoapifunction:: kraken.lib.dataset.compute_error +kraken.lib.models module +------------------------ -.. autoapifunction:: kraken.lib.dataset.preparse_xml_data +.. autoapiclass:: kraken.lib.models.TorchSeqRecognizer + :members: -.. autoapifunction:: kraken.lib.dataset.generate_input_transforms +.. autoapifunction:: kraken.lib.models.load_any kraken.lib.segmentation module ------------------------------ .. autoapifunction:: kraken.lib.segmentation.reading_order -.. autoapifunction:: kraken.lib.segmentation.polygonal_reading_order +.. autoapifunction:: kraken.lib.segmentation.neural_reading_order -.. autoapifunction:: kraken.lib.segmentation.denoising_hysteresis_thresh +.. autoapifunction:: kraken.lib.segmentation.polygonal_reading_order .. autoapifunction:: kraken.lib.segmentation.vectorize_lines @@ -181,43 +175,82 @@ kraken.lib.segmentation module .. autoapifunction:: kraken.lib.segmentation.extract_polygons +kraken.lib.vgsl module +---------------------- -kraken.lib.ctc_decoder -====================== +.. autoapiclass:: kraken.lib.vgsl.TorchVGSLModel + :members: -.. autoapifunction:: kraken.lib.ctc_decoder.beam_decoder +kraken.lib.xml module +--------------------- -.. autoapifunction:: kraken.lib.ctc_decoder.greedy_decoder +.. autoapiclass:: kraken.lib.xml.XMLPage -.. autoapifunction:: kraken.lib.ctc_decoder.blank_threshold_decoder +Training +======== -kraken.lib.exceptions -===================== +kraken.lib.train module +----------------------- -.. autoapiclass:: kraken.lib.exceptions.KrakenCodecException - :members: +Loss and Evaluation Functions +----------------------------- -.. autoapiclass:: kraken.lib.exceptions.KrakenStopTrainingException +.. autoapifunction:: kraken.lib.train.recognition_loss_fn + +.. autoapifunction:: kraken.lib.train.baseline_label_loss_fn + +.. autoapifunction:: kraken.lib.train.recognition_evaluator_fn + +.. autoapifunction:: kraken.lib.train.baseline_label_evaluator_fn + +Trainer +------- + +.. autoapiclass:: kraken.lib.train.KrakenTrainer :members: -.. autoapiclass:: kraken.lib.exceptions.KrakenEncodeException + +kraken.lib.dataset module +------------------------- + +Recognition datasets +^^^^^^^^^^^^^^^^^^^^ + +.. autoapiclass:: kraken.lib.dataset.ArrowIPCRecognitionDataset :members: -.. autoapiclass:: kraken.lib.exceptions.KrakenRecordException +.. autoapiclass:: kraken.lib.dataset.BaselineSet :members: -.. autoapiclass:: kraken.lib.exceptions.KrakenInvalidModelException +.. autoapiclass:: kraken.lib.dataset.GroundTruthDataset :members: -.. autoapiclass:: kraken.lib.exceptions.KrakenInputException +Segmentation datasets +^^^^^^^^^^^^^^^^^^^^^ + +.. autoapiclass:: kraken.lib.dataset.PolygonGTDataset :members: -.. autoapiclass:: kraken.lib.exceptions.KrakenRepoException +Reading order datasets +^^^^^^^^^^^^^^^^^^^^^^ + +.. autoapiclass:: kraken.lib.dataset.PairWiseROSet :members: -.. autoapiclass:: kraken.lib.exceptions.KrakenCairoSurfaceException +.. autoapiclass:: kraken.lib.dataset.PageWiseROSet :members: +Helpers +^^^^^^^ + +.. autoapiclass:: kraken.lib.dataset.ImageInputTransforms + :members: + +.. autoapifunction:: kraken.lib.dataset.collate_sequences + +.. autoapifunction:: kraken.lib.dataset.global_align + +.. autoapifunction:: kraken.lib.dataset.compute_confusions Legacy modules ============== diff --git a/kraken/ketos/recognition.py b/kraken/ketos/recognition.py index 2d8eaf86b..ba324e810 100644 --- a/kraken/ketos/recognition.py +++ b/kraken/ketos/recognition.py @@ -384,7 +384,7 @@ def test(ctx, batch_size, model, evaluation_files, device, pad, workers, from kraken.serialization import render_report from kraken.lib import models - from kraken.lib.xml import preparse_xml_data + from kraken.lib.xml import XMLPage from kraken.lib.dataset import (global_align, compute_confusions, PolygonGTDataset, GroundTruthDataset, ImageInputTransforms, @@ -413,7 +413,7 @@ def test(ctx, batch_size, model, evaluation_files, device, pad, workers, if format_type in ['xml', 'page', 'alto']: if repolygonize: message('Repolygonizing data') - test_set = preparse_xml_data(test_set, format_type, repolygonize) + test_set = [{'page': XMLPage(file, filetype=format_type).to_container()} for file in test_set] valid_norm = False DatasetClass = PolygonGTDataset elif format_type == 'binary': diff --git a/kraken/lib/dataset/recognition.py b/kraken/lib/dataset/recognition.py index 95bdef25e..0ddb5585e 100644 --- a/kraken/lib/dataset/recognition.py +++ b/kraken/lib/dataset/recognition.py @@ -30,6 +30,7 @@ from torch.utils.data import Dataset from typing import Dict, List, Tuple, Callable, Optional, Any, Union, Literal +from kraken.containers import BaselineLine, BBoxLine, Segmentation from kraken.lib.util import is_bitonal from kraken.lib.codec import PytorchCodec from kraken.lib.segmentation import extract_polygons @@ -353,7 +354,7 @@ def add_page(self, page: Segmentation): except ValueError as e: logger.warning(e) - def add_line(self, line: BaselineLine) + def add_line(self, line: BaselineLine): """ Adds a line to the dataset. @@ -445,8 +446,7 @@ class GroundTruthDataset(Dataset): All data is cached in memory. """ - def __init__(self, split: Callable[[Union[PathLike, str]], str] = F_t.default_split, - suffix: str = '.gt.txt', + def __init__(self, normalization: Optional[str] = None, whitespace_normalization: bool = True, skip_empty_lines: bool = True, @@ -457,10 +457,6 @@ def __init__(self, split: Callable[[Union[PathLike, str]], str] = F_t.default_sp Reads a list of image-text pairs and creates a ground truth set. Args: - split: Function for generating the base name without - extensions from paths - suffix: Suffix to attach to image base name for text - retrieval mode: Image color space. Either RGB (color) or L (grayscale/bw). Only L is compatible with vertical scaling/dewarping. @@ -479,8 +475,6 @@ def __init__(self, split: Callable[[Union[PathLike, str]], str] = F_t.default_sp tensor suitable for forward passes. augmentation: Enables augmentation. """ - self.suffix = suffix - self.split = partial(F_t.suffix_split, split=split, suffix=suffix) self._images = [] # type: Union[List[Image], List[torch.Tensor]] self._gt = [] # type: List[str] self.alphabet = Counter() # type: Counter @@ -540,7 +534,7 @@ def add_page(self, page: Segmentation): except ValueError as e: logger.warning(e) - def add_line(self, line: BBoxLine) + def add_line(self, line: BBoxLine): """ Adds a line to the dataset. diff --git a/kraken/lib/dataset/segmentation.py b/kraken/lib/dataset/segmentation.py index faa0536ce..0d248800e 100644 --- a/kraken/lib/dataset/segmentation.py +++ b/kraken/lib/dataset/segmentation.py @@ -33,6 +33,7 @@ from skimage.draw import polygon +from kraken.containers import Segmentation from kraken.lib.xml import XMLPage from kraken.lib.exceptions import KrakenInputException @@ -48,24 +49,20 @@ class BaselineSet(Dataset): """ Dataset for training a baseline/region segmentation model. """ - def __init__(self, imgs: Sequence[Union[PathLike, str]] = None, - suffix: str = '.path', + def __init__(self, line_width: int = 4, padding: Tuple[int, int, int, int] = (0, 0, 0, 0), im_transforms: Callable[[Any], torch.Tensor] = transforms.Compose([]), - mode: Optional[Literal['path', 'alto', 'page', 'xml']] = 'path', + mode: Optional[Literal['alto', 'page', 'xml']] = 'xml', augmentation: bool = False, valid_baselines: Sequence[str] = None, merge_baselines: Dict[str, Sequence[str]] = None, valid_regions: Sequence[str] = None, merge_regions: Dict[str, Sequence[str]] = None): """ - Reads a list of image-json pairs and creates a data set. + Creates a dataset for a text-line and region segmentation model. Args: - imgs: - suffix: Suffix to attach to image base name to load JSON files - from. line_width: Height of the baseline in the scaled input. padding: Tuple of ints containing the left/right, top/bottom padding of the input images. @@ -90,7 +87,6 @@ def __init__(self, imgs: Sequence[Union[PathLike, str]] = None, self.mode = mode self.im_mode = '1' self.pad = padding - self.aug = None self.targets = [] # n-th entry contains semantic of n-th class self.class_mapping = {'aux': {'_start_separator': 0, '_end_separator': 1}, 'baselines': {}, 'regions': {}} @@ -102,53 +98,8 @@ def __init__(self, imgs: Sequence[Union[PathLike, str]] = None, self.mreg_dict = merge_regions if merge_regions is not None else {} self.valid_baselines = valid_baselines self.valid_regions = valid_regions - if mode in ['alto', 'page', 'xml']: - im_paths = [] - self.targets = [] - for img in imgs: - try: - data = XMLPage(img) - im_paths.append(data.imagename) - lines = defaultdict(list) - for line in data.get_sorted_lines(): - if valid_baselines is None or set(line['tags'].values()).intersection(valid_baselines): - tags = set(line['tags'].values()).intersection(valid_baselines) if valid_baselines else line['tags'].values() - for tag in tags: - lines[self.mbl_dict.get(tag, tag)].append(line['baseline']) - self.class_stats['baselines'][self.mbl_dict.get(tag, tag)] += 1 - regions = defaultdict(list) - for k, v in data.regions.items(): - if valid_regions is None or k in valid_regions: - regions[self.mreg_dict.get(k, k)].extend(v) - self.class_stats['regions'][self.mreg_dict.get(k, k)] += len(v) - self.targets.append({'baselines': lines, 'regions': regions}) - except KrakenInputException as e: - logger.warning(e) - continue - # get line types - imgs = im_paths - # calculate class mapping - line_types = set() - region_types = set() - for page in self.targets: - for line_type in page['baselines'].keys(): - line_types.add(line_type) - for reg_type in page['regions'].keys(): - region_types.add(reg_type) - idx = -1 - for idx, line_type in enumerate(line_types): - self.class_mapping['baselines'][line_type] = idx + self.num_classes - self.num_classes += idx + 1 - idx = -1 - for idx, reg_type in enumerate(region_types): - self.class_mapping['regions'][reg_type] = idx + self.num_classes - self.num_classes += idx + 1 - elif mode == 'path': - pass - elif mode is None: - imgs = [] - else: - raise Exception('invalid dataset mode') + + self.aug = None if augmentation: import cv2 cv2.setNumThreads(0) @@ -172,37 +123,26 @@ def __init__(self, imgs: Sequence[Union[PathLike, str]] = None, ], p=0.2), HueSaturationValue(hue_shift_limit=20, sat_shift_limit=0.1, val_shift_limit=0.1, p=0.3), ], p=0.5) - self.imgs = imgs self.line_width = line_width self.transforms = im_transforms self.seg_type = None - def add(self, - image: Union[PathLike, str, Image.Image], - baselines: List[List[List[Tuple[int, int]]]] = None, - regions: Dict[str, List[List[Tuple[int, int]]]] = None, - *args, - **kwargs): + def add(self, doc: Union[Segmentation, XMLPage]): """ Adds a page to the dataset. Args: - im: Path to the whole page image - baseline: A list containing dicts with a list of coordinates - and tags [{'baseline': [[x0, y0], ..., - [xn, yn]], 'tags': ('script_type',)}, ...] - regions: A dict containing list of lists of coordinates - {'region_type_0': [[x0, y0], ..., [xn, yn]]], - 'region_type_1': ...}. + doc: Either a Segmentation container class or an XMLPage. """ - if self.mode: - raise Exception(f'The `add` method is incompatible with dataset mode {self.mode}') + if doc.type != 'baselines': + raise ValueError(f'{doc} is of type {doc.type}. Expected "baselines".') + baselines_ = defaultdict(list) - for line in baselines: - if self.valid_baselines is None or set(line['tags'].values()).intersection(self.valid_baselines): - tags = set(line['tags'].values()).intersection(self.valid_baselines) if self.valid_baselines else line['tags'].values() + for line in doc.lines: + if self.valid_baselines is None or set(line.tags.values()).intersection(self.valid_baselines): + tags = set(line.tags.values()).intersection(self.valid_baselines) if self.valid_baselines else line.tags.values() for tag in tags: - baselines_[tag].append(line['baseline']) + baselines_[tag].append(line.baseline) self.class_stats['baselines'][tag] += 1 if tag not in self.class_mapping['baselines']: @@ -210,7 +150,7 @@ def add(self, self.class_mapping['baselines'][tag] = self.num_classes - 1 regions_ = defaultdict(list) - for k, v in regions.items(): + for k, v in doc.regions.items(): reg_type = self.mreg_dict.get(k, k) if self.valid_regions is None or reg_type in self.valid_regions: regions_[reg_type].extend(v) @@ -224,11 +164,7 @@ def add(self, def __getitem__(self, idx): im = self.imgs[idx] - if self.mode != 'path': - target = self.targets[idx] - else: - with open('{}.path'.format(path.splitext(im)[0]), 'r') as fp: - target = json.load(fp) + target = self.targets[idx] if not isinstance(im, Image.Image): try: logger.debug(f'Attempting to load {im}') diff --git a/kraken/lib/pretrain/model.py b/kraken/lib/pretrain/model.py index 3b5087e96..61dc04ab4 100644 --- a/kraken/lib/pretrain/model.py +++ b/kraken/lib/pretrain/model.py @@ -45,7 +45,7 @@ from pytorch_lightning.utilities.memory import is_oom_error, garbage_collection_cuda from kraken.lib import vgsl, default_specs, layers -from kraken.lib.xml import preparse_xml_data +from kraken.lib.xml import XMLPage from kraken.lib.codec import PytorchCodec from kraken.lib.dataset import (ArrowIPCRecognitionDataset, GroundTruthDataset, PolygonGTDataset, @@ -108,10 +108,10 @@ def __init__(self, valid_norm = True if format_type in ['xml', 'page', 'alto']: logger.info(f'Parsing {len(training_data)} XML files for training data') - training_data = preparse_xml_data(training_data, format_type, repolygonize) + training_data = [{'page': XMLPage(file, format_type).to_container()} for file in training_data] if evaluation_data: logger.info(f'Parsing {len(evaluation_data)} XML files for validation data') - evaluation_data = preparse_xml_data(evaluation_data, format_type, repolygonize) + evaluation_data = [{'page': XMLPage(file, format_type).to_container()} for file in evaluation_data] if binary_dataset_split: logger.warning('Internal binary dataset splits are enabled but using non-binary dataset files. Will be ignored.') binary_dataset_split = False @@ -144,7 +144,7 @@ def __init__(self, valid_norm = True # format_type is None. Determine training type from length of training data entry elif not format_type: - if len(training_data[0]) >= 4: + if training_data[0].type == 'baselines': DatasetClass = PolygonGTDataset valid_norm = False else: @@ -156,6 +156,22 @@ def __init__(self, if binary_dataset_split: logger.warning('Internal binary dataset splits are enabled but using non-binary dataset files. Will be ignored.') binary_dataset_split = False + samples = [] + for sample in training_data: + if isinstance(sample, Segmentation): + samples.append({'page': sample}) + else: + samples.append({'line': sample}) + training_data = samples + if evaluation_data: + samples = [] + for sample in evaluation_data: + if isinstance(sample, Segmentation): + samples.append({'page': sample}) + else: + samples.append({'line': sample}) + evaluation_data = samples + else: raise ValueError(f'format_type {format_type} not in [alto, page, xml, path, binary].') @@ -203,18 +219,12 @@ def _build_dataset(self, skip_empty_lines=False, **kwargs) - if (self.hparams.num_workers and self.hparams.num_workers > 1) and self.hparams.format_type != 'binary': - with Pool(processes=self.hparams.num_workers) as pool: - for im in pool.imap_unordered(partial(_star_fun, dataset.parse), training_data, 5): - logger.debug(f'Adding sample {im} to training set') - if im: - dataset.add(**im) - else: - for im in training_data: - try: - dataset.add(**im) - except KrakenInputException as e: - logger.warning(str(e)) + for sample in training_data: + try: + dataset.add(**sample) + except KrakenInputException as e: + logger.warning(str(e)) + return dataset def train_dataloader(self): diff --git a/kraken/lib/train.py b/kraken/lib/train.py index 81296eb27..f94568f78 100644 --- a/kraken/lib/train.py +++ b/kraken/lib/train.py @@ -33,7 +33,6 @@ from pytorch_lightning.callbacks import Callback, EarlyStopping, BaseFinetuning, LearningRateMonitor from kraken.lib import models, vgsl, default_specs, progress -# from kraken.lib.xml import preparse_xml_data from kraken.lib.util import make_printable from kraken.lib.codec import PytorchCodec from kraken.lib.dataset import (ArrowIPCRecognitionDataset, BaselineSet, @@ -271,10 +270,10 @@ def __init__(self, valid_norm = True if format_type in ['xml', 'page', 'alto']: logger.info(f'Parsing {len(training_data)} XML files for training data') - training_data = preparse_xml_data(training_data, format_type, repolygonize) + training_data = [{'page': XMLPage(file, format_type).to_container()} for file in training_data] if evaluation_data: logger.info(f'Parsing {len(evaluation_data)} XML files for validation data') - evaluation_data = preparse_xml_data(evaluation_data, format_type, repolygonize) + evaluation_data = [{'page': XMLPage(file, format_type).to_container()} for file in evaluation_data] if binary_dataset_split: logger.warning('Internal binary dataset splits are enabled but using non-binary dataset files. Will be ignored.') binary_dataset_split = False @@ -305,9 +304,9 @@ def __init__(self, logger.info(f'Got {len(evaluation_data)} line strip images for validation data') evaluation_data = [{'image': im} for im in evaluation_data] valid_norm = True - # format_type is None. Determine training type from length of training data entry + # format_type is None. Determine training type from container class types elif not format_type: - if len(training_data[0]) >= 4: + if training_data[0].type == 'baselines': DatasetClass = PolygonGTDataset valid_norm = False else: @@ -319,6 +318,21 @@ def __init__(self, if binary_dataset_split: logger.warning('Internal binary dataset splits are enabled but using non-binary dataset files. Will be ignored.') binary_dataset_split = False + samples = [] + for sample in training_data: + if isinstance(sample, Segmentation): + samples.append({'page': sample}) + else: + samples.append({'line': sample}) + training_data = samples + if evaluation_data: + samples = [] + for sample in evaluation_data: + if isinstance(sample, Segmentation): + samples.append({'page': sample}) + else: + samples.append({'line': sample}) + evaluation_data = samples else: raise ValueError(f'format_type {format_type} not in [alto, page, xml, path, binary].') @@ -423,21 +437,15 @@ def _build_dataset(self, augmentation=self.hparams.hyper_params['augment'], **kwargs) - if (self.num_workers and self.num_workers > 1) and self.format_type != 'binary': - with Pool(processes=self.num_workers) as pool: - for im in pool.imap_unordered(partial(_star_fun, dataset.parse), training_data, 5): - logger.debug(f'Adding sample {im} to training set') - if im: - dataset.add(**im) - else: - for im in training_data: - try: - dataset.add(**im) - except KrakenInputException as e: - logger.warning(str(e)) - if self.format_type == 'binary' and self.hparams.hyper_params['normalization']: - logger.debug('Rebuilding dataset using unicode normalization') - dataset.rebuild_alphabet() + for sample in training_data: + try: + dataset.add(**sample) + except KrakenInputException as e: + logger.warning(str(e)) + if self.format_type == 'binary' and self.hparams.hyper_params['normalization']: + logger.debug('Rebuilding dataset using unicode normalization') + dataset.rebuild_alphabet() + return dataset def forward(self, x, seq_lens=None): From 0fddaace613bb6ee3e8578fa353b810100931a9b Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Mon, 18 Sep 2023 11:52:48 +0200 Subject: [PATCH 67/68] Move progress bar imports around to prevent torch import Would drastically slow down display of help message in CLI drivers --- kraken/ketos/recognition.py | 2 +- kraken/ketos/repo.py | 3 +-- kraken/ketos/ro.py | 4 ++-- kraken/ketos/segmentation.py | 2 +- kraken/ketos/transcription.py | 3 ++- kraken/kraken.py | 29 ++++++++++++++++------------- 6 files changed, 23 insertions(+), 20 deletions(-) diff --git a/kraken/ketos/recognition.py b/kraken/ketos/recognition.py index ba324e810..04b6a30b0 100644 --- a/kraken/ketos/recognition.py +++ b/kraken/ketos/recognition.py @@ -24,7 +24,6 @@ from typing import List -from kraken.lib.progress import KrakenProgressBar from kraken.lib.exceptions import KrakenInputException from kraken.lib.default_specs import RECOGNITION_HYPER_PARAMS, RECOGNITION_SPEC from .util import _validate_manifests, _expand_gt, message, to_ptl_device @@ -390,6 +389,7 @@ def test(ctx, batch_size, model, evaluation_files, device, pad, workers, ImageInputTransforms, ArrowIPCRecognitionDataset, collate_sequences) + from kraken.lib.progress import KrakenProgressBar logger.info('Building test set from {} line images'.format(len(test_set) + len(evaluation_files))) diff --git a/kraken/ketos/repo.py b/kraken/ketos/repo.py index 32a49ac5e..52c9db8e3 100644 --- a/kraken/ketos/repo.py +++ b/kraken/ketos/repo.py @@ -22,8 +22,6 @@ import click import logging -from kraken.lib.progress import KrakenDownloadProgressBar - from .util import message logging.captureWarnings(True) @@ -52,6 +50,7 @@ def publish(ctx, metadata, access_token, private, model): from kraken import repo from kraken.lib import models + from kraken.lib.progress import KrakenDownloadProgressBar with pkg_resources.resource_stream('kraken', 'metadata.schema.json') as fp: schema = json.load(fp) diff --git a/kraken/ketos/ro.py b/kraken/ketos/ro.py index 1dcc0856f..17fc8cf09 100644 --- a/kraken/ketos/ro.py +++ b/kraken/ketos/ro.py @@ -25,7 +25,6 @@ from PIL import Image from typing import Dict -from kraken.lib.progress import KrakenProgressBar from kraken.lib.exceptions import KrakenInputException from kraken.lib.default_specs import READING_ORDER_HYPER_PARAMS @@ -152,8 +151,9 @@ def rotrain(ctx, batch_size, output, load, freq, quit, epochs, min_epochs, lag, """ import shutil - from kraken.lib.train import KrakenTrainer from kraken.lib.ro import ROModel + from kraken.lib.train import KrakenTrainer + from kraken.lib.progress import KrakenProgressBar if not (0 <= freq <= 1) and freq % 1.0 != 0: raise click.BadOptionUsage('freq', 'freq needs to be either in the interval [0,1.0] or a positive integer.') diff --git a/kraken/ketos/segmentation.py b/kraken/ketos/segmentation.py index dfcb7c739..e356d08df 100644 --- a/kraken/ketos/segmentation.py +++ b/kraken/ketos/segmentation.py @@ -24,7 +24,6 @@ from PIL import Image -from kraken.lib.progress import KrakenProgressBar from kraken.lib.exceptions import KrakenInputException from kraken.lib.default_specs import SEGMENTATION_HYPER_PARAMS, SEGMENTATION_SPEC @@ -230,6 +229,7 @@ def segtrain(ctx, output, spec, line_width, pad, load, freq, quit, epochs, import shutil from kraken.lib.train import SegmentationModel, KrakenTrainer + from kraken.lib.progress import KrakenProgressBar if resize != 'fail' and not load: raise click.BadOptionUsage('resize', 'resize option requires loading an existing model') diff --git a/kraken/ketos/transcription.py b/kraken/ketos/transcription.py index 490c0ac4e..dd4402c66 100644 --- a/kraken/ketos/transcription.py +++ b/kraken/ketos/transcription.py @@ -27,7 +27,6 @@ from typing import IO, Any, cast from bidi.algorithm import get_display -from kraken.lib.progress import KrakenProgressBar from .util import message logging.captureWarnings(True) @@ -68,6 +67,7 @@ def extract(ctx, binarize, normalization, normalize_whitespace, reorder, from lxml import html, etree from kraken import binarization + from kraken.lib.progress import KrakenProgressBar try: os.mkdir(output) @@ -172,6 +172,7 @@ def transcription(ctx, text_direction, scale, bw, maxcolseps, from kraken import binarization from kraken.lib import models + from kraken.lib.progress import KrakenProgressBar ti = transcribe.TranscriptionInterface(font, font_style) diff --git a/kraken/kraken.py b/kraken/kraken.py index 0f8fea361..38647eee8 100644 --- a/kraken/kraken.py +++ b/kraken/kraken.py @@ -24,16 +24,16 @@ import dataclasses import pkg_resources -from typing import Dict, Union, List, cast, Any, IO, Callable +from PIL import Image from pathlib import Path -from rich.traceback import install from functools import partial -from PIL import Image +from rich.traceback import install +from threadpoolctl import threadpool_limits +from typing import Dict, Union, List, cast, Any, IO, Callable import click from kraken.lib import log -from kraken.lib.progress import KrakenProgressBar, KrakenDownloadProgressBar warnings.simplefilter('ignore', UserWarning) @@ -118,8 +118,7 @@ def segmenter(legacy, model, text_direction, scale, maxcolseps, black_colseps, remove_hlines, pad, mask, device, input, output) -> None: import json - from kraken import pageseg - from kraken import blla + from kraken import blla, pageseg ctx = click.get_current_context() @@ -183,8 +182,10 @@ def recognizer(model, pad, no_segmentation, bidi_reordering, tags_ignore, input, import uuid import dataclasses - from kraken.containers import Segmentation, BBoxLine from kraken import rpred + from kraken.containers import Segmentation, BBoxLine + + from kraken.lib.progress import KrakenProgressBar ctx = click.get_current_context() @@ -301,8 +302,10 @@ def recognizer(model, pad, no_segmentation, bidi_reordering, tags_ignore, input, help='Raises the exception that caused processing to fail in the case of an error') @click.option('-2', '--autocast', default=False, show_default=True, flag_value=True, help='On compatible devices, uses autocast for `segment` which lower the memory usage.') +@click.option('--threads', default=1, show_default=True, type=click.IntRange(1), + help='Size of thread pools for intra-op parallelization') def cli(input, batch_input, suffix, verbose, format_type, pdf_format, - serializer, template, device, raise_on_error, autocast): + serializer, template, device, raise_on_error, autocast, threads): """ Base command for recognition functionality. @@ -345,6 +348,8 @@ def process_pipeline(subcommands, input, batch_input, suffix, verbose, format_ty import uuid import tempfile + from kraken.lib.progress import KrakenProgressBar + ctx = click.get_current_context() input = list(input) @@ -563,9 +568,7 @@ def _validate_mm(ctx, param, value): show_default=True, type=click.Choice(['horizontal-tb', 'vertical-lr', 'vertical-rl']), help='Sets principal text direction in serialization output') -@click.option('--threads', default=1, show_default=True, type=click.IntRange(1), - help='Number of threads to use for OpenMP parallelization.') -def ocr(ctx, model, pad, reorder, base_dir, no_segmentation, text_direction, threads): +def ocr(ctx, model, pad, reorder, base_dir, no_segmentation, text_direction): """ Recognizes text in line images. """ @@ -607,8 +610,6 @@ def ocr(ctx, model, pad, reorder, base_dir, no_segmentation, text_direction, thr nn = defaultdict(lambda: nm['default']) # type: Dict[str, models.TorchSeqRecognizer] nn.update(nm) nm = nn - # thread count is global so setting it once is sufficient - nm[k].nn.set_num_threads(threads) ctx.meta['steps'].append({'category': 'processing', 'description': 'Text line recognition', @@ -661,6 +662,7 @@ def list_models(ctx): Lists models in the repository. """ from kraken import repo + from kraken.lib.progress import KrakenProgressBar with KrakenProgressBar() as progress: download_task = progress.add_task('Retrieving model list', total=0, visible=True if not ctx.meta['verbose'] else False) @@ -678,6 +680,7 @@ def get(ctx, model_id): Retrieves a model from the repository. """ from kraken import repo + from kraken.lib.progress import KrakenDownloadProgressBar try: os.makedirs(click.get_app_dir(APP_NAME)) From 003568bf6c5c91095b6efbcec845145f3371fa75 Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Mon, 18 Sep 2023 12:30:55 +0200 Subject: [PATCH 68/68] add threadpool limits to CLI drivers --- conda/meta.yaml | 1 + environment.yml | 1 + environment_cuda.yml | 1 + kraken/ketos/pretrain.py | 11 +-- kraken/ketos/recognition.py | 132 +++++++++++++++++++---------------- kraken/ketos/ro.py | 12 ++-- kraken/ketos/segmentation.py | 99 ++++++++++++++------------ kraken/kraken.py | 6 +- setup.cfg | 1 + 9 files changed, 147 insertions(+), 117 deletions(-) diff --git a/conda/meta.yaml b/conda/meta.yaml index f68acffae..7380314bd 100644 --- a/conda/meta.yaml +++ b/conda/meta.yaml @@ -32,6 +32,7 @@ requirements: - pyarrow - pytorch-lightning~=2.0 - torchmetrics>=0.10.0 + - conda-forge::threadpoolctl~=3.2.0 - albumentations - rich about: diff --git a/environment.yml b/environment.yml index 9344af7a6..05112f876 100644 --- a/environment.yml +++ b/environment.yml @@ -24,6 +24,7 @@ dependencies: - pyarrow - conda-forge::pytorch-lightning~=2.0.0 - conda-forge::torchmetrics>=0.10.0 + - conda-forge::threadpoolctl~=3.2 - pip - albumentations - rich diff --git a/environment_cuda.yml b/environment_cuda.yml index 49c1faa70..8464004b6 100644 --- a/environment_cuda.yml +++ b/environment_cuda.yml @@ -25,6 +25,7 @@ dependencies: - pyarrow - conda-forge::pytorch-lightning~=2.0.0 - conda-forge::torchmetrics>=0.10.0 + - conda-forge::threadpoolctl~=3.2 - pip - albumentations - rich diff --git a/kraken/ketos/pretrain.py b/kraken/ketos/pretrain.py index ea5a14e19..512de415d 100644 --- a/kraken/ketos/pretrain.py +++ b/kraken/ketos/pretrain.py @@ -133,7 +133,8 @@ @click.option('-e', '--evaluation-files', show_default=True, default=None, multiple=True, callback=_validate_manifests, type=click.File(mode='r', lazy=True), help='File(s) with paths to evaluation data. Overrides the `-p` parameter') -@click.option('--workers', show_default=True, default=1, help='Number of OpenMP threads and workers when running on CPU.') +@click.option('--workers', show_default=True, default=1, type=click.IntRange(1), help='Number of worker processes.') +@click.option('--threads', show_default=True, default=1, type=click.IntRange(1), help='Maximum size of OpenMP/BLAS thread pool.') @click.option('--load-hyper-parameters/--no-load-hyper-parameters', show_default=True, default=False, help='When loading an existing model, retrieve hyperparameters from the model') @click.option('--repolygonize/--no-repolygonize', show_default=True, @@ -182,8 +183,8 @@ def pretrain(ctx, batch_size, pad, output, spec, load, freq, quit, epochs, min_epochs, lag, min_delta, device, precision, optimizer, lrate, momentum, weight_decay, warmup, schedule, gamma, step_size, sched_patience, cos_max, partition, fixed_splits, training_files, - evaluation_files, workers, load_hyper_parameters, repolygonize, - force_binarization, format_type, augment, + evaluation_files, workers, threads, load_hyper_parameters, repolygonize, + force_binarization, format_type, augment, mask_probability, mask_width, num_negatives, logit_temp, ground_truth): """ @@ -199,6 +200,7 @@ def pretrain(ctx, batch_size, pad, output, spec, load, freq, quit, epochs, raise click.BadOptionUsage('augment', 'augmentation needs the `albumentations` package installed.') import shutil + from threadpoolctl import threadpool_limits from kraken.lib.train import KrakenTrainer from kraken.lib.pretrain import PretrainDataModule, RecognitionPretrainModel @@ -280,7 +282,8 @@ def pretrain(ctx, batch_size, pad, output, spec, load, freq, quit, epochs, enable_progress_bar=True if not ctx.meta['verbose'] else False, deterministic=ctx.meta['deterministic'], **val_check_interval) - trainer.fit(model, datamodule=data_module) + with threadpool_limits(limits=threads): + trainer.fit(model, datamodule=data_module) if quit == 'early': message('Moving best model {0}_{1}.mlmodel ({2}) to {0}_best.mlmodel'.format( diff --git a/kraken/ketos/recognition.py b/kraken/ketos/recognition.py index 04b6a30b0..050316cc2 100644 --- a/kraken/ketos/recognition.py +++ b/kraken/ketos/recognition.py @@ -23,6 +23,7 @@ import pathlib from typing import List +from threadpoolctl import threadpool_limits from kraken.lib.exceptions import KrakenInputException from kraken.lib.default_specs import RECOGNITION_HYPER_PARAMS, RECOGNITION_SPEC @@ -155,7 +156,8 @@ @click.option('-e', '--evaluation-files', show_default=True, default=None, multiple=True, callback=_validate_manifests, type=click.File(mode='r', lazy=True), help='File(s) with paths to evaluation data. Overrides the `-p` parameter') -@click.option('--workers', show_default=True, default=1, help='Number of OpenMP threads and workers when running on CPU.') +@click.option('--workers', show_default=True, default=1, type=click.IntRange(1), help='Number of worker processes.') +@click.option('--threads', show_default=True, default=1, type=click.IntRange(1), help='Maximum size of OpenMP/BLAS thread pool.') @click.option('--load-hyper-parameters/--no-load-hyper-parameters', show_default=True, default=False, help='When loading an existing model, retrieve hyperparameters from the model') @click.option('--repolygonize/--no-repolygonize', show_default=True, @@ -310,7 +312,8 @@ def train(ctx, batch_size, pad, output, spec, append, load, freq, quit, epochs, log_dir=log_dir, **val_check_interval) try: - trainer.fit(model) + with threadpool_limits(limits=threads): + trainer.fit(model) except KrakenInputException as e: if e.args[0].startswith('Training data and model codec alphabets mismatch') and resize == 'fail': raise click.BadOptionUsage('resize', 'Mismatched training data for loaded model. Set option `--resize` to `new` or `add`') @@ -337,7 +340,12 @@ def train(ctx, batch_size, pad, output, spec, append, load, freq, quit, epochs, @click.option('-d', '--device', show_default=True, default='cpu', help='Select device to use (cpu, cuda:0, cuda:1, ...)') @click.option('--pad', show_default=True, type=click.INT, default=16, help='Left and right ' 'padding around lines') -@click.option('--workers', show_default=True, default=1, help='Number of OpenMP threads when running on CPU.') +@click.option('--workers', show_default=True, default=1, + type=click.IntRange(1), + help='Number of worker processes when running on CPU.') +@click.option('--threads', show_default=True, default=1, + type=click.IntRange(1), + help='Max size of thread pools for OpenMP/BLAS operations.') @click.option('--reorder/--no-reorder', show_default=True, default=True, help='Reordering of code points to display order') @click.option('--base-dir', show_default=True, default='auto', type=click.Choice(['L', 'R', 'auto']), help='Set base text ' @@ -370,8 +378,8 @@ def train(ctx, batch_size, pad, output, spec, append, load, freq, quit, epochs, 'collections of pre-extracted text line images.') @click.argument('test_set', nargs=-1, callback=_expand_gt, type=click.Path(exists=False, dir_okay=False)) def test(ctx, batch_size, model, evaluation_files, device, pad, workers, - reorder, base_dir, normalization, normalize_whitespace, repolygonize, - force_binarization, format_type, test_set): + threads, reorder, base_dir, normalization, normalize_whitespace, + repolygonize, force_binarization, format_type, test_set): """ Evaluate on a test set. """ @@ -401,8 +409,6 @@ def test(ctx, batch_size, model, evaluation_files, device, pad, workers, test_set = list(test_set) - # set number of OpenMP threads - next(iter(nn.values())).nn.set_num_threads(1) if evaluation_files: test_set.extend(evaluation_files) @@ -439,62 +445,64 @@ def test(ctx, batch_size, model, evaluation_files, device, pad, workers, reorder = base_dir acc_list = [] - for p, net in nn.items(): - algn_gt: List[str] = [] - algn_pred: List[str] = [] - chars = 0 - error = 0 - message('Evaluating {}'.format(p)) - logger.info('Evaluating {}'.format(p)) - batch, channels, height, width = net.nn.input - ts = ImageInputTransforms(batch, height, width, channels, (pad, 0), valid_norm, force_binarization) - ds = DatasetClass(normalization=normalization, - whitespace_normalization=normalize_whitespace, - reorder=reorder, - im_transforms=ts) - for line in test_set: - try: - ds.add(**line) - except KrakenInputException as e: - logger.info(e) - # don't encode validation set as the alphabets may not match causing encoding failures - ds.no_encode() - ds_loader = DataLoader(ds, - batch_size=batch_size, - num_workers=workers, - pin_memory=True, - collate_fn=collate_sequences) - - with KrakenProgressBar() as progress: - batches = len(ds_loader) - pred_task = progress.add_task('Evaluating', total=batches, visible=True if not ctx.meta['verbose'] else False) - - for batch in ds_loader: - im = batch['image'] - text = batch['target'] - lens = batch['seq_lens'] + + with threadpool_limits(limits=threads): + for p, net in nn.items(): + algn_gt: List[str] = [] + algn_pred: List[str] = [] + chars = 0 + error = 0 + message('Evaluating {}'.format(p)) + logger.info('Evaluating {}'.format(p)) + batch, channels, height, width = net.nn.input + ts = ImageInputTransforms(batch, height, width, channels, (pad, 0), valid_norm, force_binarization) + ds = DatasetClass(normalization=normalization, + whitespace_normalization=normalize_whitespace, + reorder=reorder, + im_transforms=ts) + for line in test_set: try: - pred = net.predict_string(im, lens) - for x, y in zip(pred, text): - chars += len(y) - c, algn1, algn2 = global_align(y, x) - algn_gt.extend(algn1) - algn_pred.extend(algn2) - error += c - except FileNotFoundError as e: - batches -= 1 - progress.update(pred_task, total=batches) - logger.warning('{} {}. Skipping.'.format(e.strerror, e.filename)) + ds.add(**line) except KrakenInputException as e: - batches -= 1 - progress.update(pred_task, total=batches) - logger.warning(str(e)) - progress.update(pred_task, advance=1) - - acc_list.append((chars - error) / chars) - confusions, scripts, ins, dels, subs = compute_confusions(algn_gt, algn_pred) - rep = render_report(p, chars, error, confusions, scripts, ins, dels, subs) - logger.info(rep) - message(rep) + logger.info(e) + # don't encode validation set as the alphabets may not match causing encoding failures + ds.no_encode() + ds_loader = DataLoader(ds, + batch_size=batch_size, + num_workers=workers, + pin_memory=True, + collate_fn=collate_sequences) + + with KrakenProgressBar() as progress: + batches = len(ds_loader) + pred_task = progress.add_task('Evaluating', total=batches, visible=True if not ctx.meta['verbose'] else False) + + for batch in ds_loader: + im = batch['image'] + text = batch['target'] + lens = batch['seq_lens'] + try: + pred = net.predict_string(im, lens) + for x, y in zip(pred, text): + chars += len(y) + c, algn1, algn2 = global_align(y, x) + algn_gt.extend(algn1) + algn_pred.extend(algn2) + error += c + except FileNotFoundError as e: + batches -= 1 + progress.update(pred_task, total=batches) + logger.warning('{} {}. Skipping.'.format(e.strerror, e.filename)) + except KrakenInputException as e: + batches -= 1 + progress.update(pred_task, total=batches) + logger.warning(str(e)) + progress.update(pred_task, advance=1) + + acc_list.append((chars - error) / chars) + confusions, scripts, ins, dels, subs = compute_confusions(algn_gt, algn_pred) + rep = render_report(p, chars, error, confusions, scripts, ins, dels, subs) + logger.info(rep) + message(rep) logger.info('Average accuracy: {:0.2f}%, (stddev: {:0.2f})'.format(np.mean(acc_list) * 100, np.std(acc_list) * 100)) message('Average accuracy: {:0.2f}%, (stddev: {:0.2f})'.format(np.mean(acc_list) * 100, np.std(acc_list) * 100)) diff --git a/kraken/ketos/ro.py b/kraken/ketos/ro.py index 17fc8cf09..006c1b8bd 100644 --- a/kraken/ketos/ro.py +++ b/kraken/ketos/ro.py @@ -124,7 +124,8 @@ @click.option('-e', '--evaluation-files', show_default=True, default=None, multiple=True, callback=_validate_manifests, type=click.File(mode='r', lazy=True), help='File(s) with paths to evaluation data. Overrides the `-p` parameter') -@click.option('--workers', show_default=True, default=1, help='Number of OpenMP threads and workers when running on CPU.') +@click.option('--workers', show_default=True, default=1, type=click.IntRange(1), help='Number of worker proesses.') +@click.option('--threads', show_default=True, default=1, type=click.IntRange(1), help='Maximum size of OpenMP/BLAS thread pool.') @click.option('--load-hyper-parameters/--no-load-hyper-parameters', show_default=True, default=False, help='When loading an existing model, retrieve hyper-parameters from the model') @click.option('-f', '--format-type', type=click.Choice(['xml', 'alto', 'page']), default='xml', @@ -144,13 +145,15 @@ def rotrain(ctx, batch_size, output, load, freq, quit, epochs, min_epochs, lag, min_delta, device, precision, optimizer, lrate, momentum, weight_decay, warmup, schedule, gamma, step_size, sched_patience, cos_max, partition, training_files, evaluation_files, workers, - load_hyper_parameters, format_type, pl_logger, log_dir, level, - reading_order, ground_truth): + threads, load_hyper_parameters, format_type, pl_logger, log_dir, + level, reading_order, ground_truth): """ Trains a baseline labeling model for layout analysis """ import shutil + from threadpoolctl import threadpool_limits + from kraken.lib.ro import ROModel from kraken.lib.train import KrakenTrainer from kraken.lib.progress import KrakenProgressBar @@ -250,7 +253,8 @@ def rotrain(ctx, batch_size, output, load, freq, quit, epochs, min_epochs, lag, log_dir=log_dir, **val_check_interval) - trainer.fit(model) + with threadpool_limits(limits=threads): + trainer.fit(model) if quit == 'early': message('Moving best model {0}_{1}.mlmodel ({2}) to {0}_best.mlmodel'.format( diff --git a/kraken/ketos/segmentation.py b/kraken/ketos/segmentation.py index e356d08df..9db1060e6 100644 --- a/kraken/ketos/segmentation.py +++ b/kraken/ketos/segmentation.py @@ -151,7 +151,8 @@ def _validate_merging(ctx, param, value): @click.option('-e', '--evaluation-files', show_default=True, default=None, multiple=True, callback=_validate_manifests, type=click.File(mode='r', lazy=True), help='File(s) with paths to evaluation data. Overrides the `-p` parameter') -@click.option('--workers', show_default=True, default=1, help='Number of OpenMP threads and workers when running on CPU.') +@click.option('--workers', show_default=True, default=1, type=click.IntRange(1), help='Number of worker proesses.') +@click.option('--threads', show_default=True, default=1, type=click.IntRange(1), help='Maximum size of OpenMP/BLAS thread pool.') @click.option('--load-hyper-parameters/--no-load-hyper-parameters', show_default=True, default=False, help='When loading an existing model, retrieve hyper-parameters from the model') @click.option('--force-binarization/--no-binarization', show_default=True, @@ -218,7 +219,7 @@ def segtrain(ctx, output, spec, line_width, pad, load, freq, quit, epochs, min_epochs, lag, min_delta, device, precision, optimizer, lrate, momentum, weight_decay, warmup, schedule, gamma, step_size, sched_patience, cos_max, partition, training_files, - evaluation_files, workers, load_hyper_parameters, + evaluation_files, workers, threads, load_hyper_parameters, force_binarization, format_type, suppress_regions, suppress_baselines, valid_regions, valid_baselines, merge_regions, merge_baselines, bounding_regions, @@ -228,6 +229,8 @@ def segtrain(ctx, output, spec, line_width, pad, load, freq, quit, epochs, """ import shutil + from threadpoolctl import threadpool_limits + from kraken.lib.train import SegmentationModel, KrakenTrainer from kraken.lib.progress import KrakenProgressBar @@ -347,7 +350,8 @@ def segtrain(ctx, output, spec, line_width, pad, load, freq, quit, epochs, log_dir=log_dir, **val_check_interval) - trainer.fit(model) + with threadpool_limits(limits=threads): + trainer.fit(model) if quit == 'early': message('Moving best model {0}_{1}.mlmodel ({2}) to {0}_best.mlmodel'.format( @@ -365,7 +369,10 @@ def segtrain(ctx, output, spec, line_width, pad, load, freq, quit, epochs, callback=_validate_manifests, type=click.File(mode='r', lazy=True), help='File(s) with paths to evaluation data.') @click.option('-d', '--device', show_default=True, default='cpu', help='Select device to use (cpu, cuda:0, cuda:1, ...)') -@click.option('--workers', show_default=True, default=1, help='Number of OpenMP threads when running on CPU.') +@click.option('--workers', default=1, show_default=True, type=click.IntRange(1), + help='Number of worker processes for data loading.') +@click.option('--threads', default=1, show_default=True, type=click.IntRange(1), + help='Size of thread pools for intra-op parallelization') @click.option('--force-binarization/--no-binarization', show_default=True, default=False, help='Forces input images to be binary, otherwise ' 'the appropriate color format will be auto-determined through the ' @@ -403,7 +410,7 @@ def segtrain(ctx, output, spec, line_width, pad, load, freq, quit, epochs, @click.option("--threshold", type=click.FloatRange(.01, .99), default=.3, show_default=True, help="Threshold for heatmap binarization. Training threshold is .3, prediction is .5") @click.argument('test_set', nargs=-1, callback=_expand_gt, type=click.Path(exists=False, dir_okay=False)) -def segtest(ctx, model, evaluation_files, device, workers, threshold, +def segtest(ctx, model, evaluation_files, device, workers, threads, threshold, force_binarization, format_type, test_set, suppress_regions, suppress_baselines, valid_regions, valid_baselines, merge_regions, merge_baselines, bounding_regions): @@ -413,6 +420,7 @@ def segtest(ctx, model, evaluation_files, device, workers, threshold, if not model: raise click.UsageError('No model to evaluate given.') + from threadpoolctl import threadpool_limits from torch.utils.data import DataLoader import torch import torch.nn.functional as F @@ -502,46 +510,47 @@ def segtest(ctx, model, evaluation_files, device, workers, threshold, with KrakenProgressBar() as progress: batches = len(ds_loader) pred_task = progress.add_task('Evaluating', total=batches, visible=True if not ctx.meta['verbose'] else False) - for batch in ds_loader: - x, y = batch['image'], batch['target'] - try: - pred, _ = nn.nn(x) - # scale target to output size - y = F.interpolate(y, size=(pred.size(2), pred.size(3))).squeeze(0).bool() - pred = pred.squeeze() > threshold - pred = pred.view(pred.size(0), -1) - y = y.view(y.size(0), -1) - pages.append({ - 'intersections': (y & pred).sum(dim=1, dtype=torch.double), - 'unions': (y | pred).sum(dim=1, dtype=torch.double), - 'corrects': torch.eq(y, pred).sum(dim=1, dtype=torch.double), - 'cls_cnt': y.sum(dim=1, dtype=torch.double), - 'all_n': torch.tensor(y.size(1), dtype=torch.double, device=device) - }) - if lines_idx: - y_baselines = y[lines_idx].sum(dim=0, dtype=torch.bool) - pred_baselines = pred[lines_idx].sum(dim=0, dtype=torch.bool) - pages[-1]["baselines"] = { - 'intersections': (y_baselines & pred_baselines).sum(dim=0, dtype=torch.double), - 'unions': (y_baselines | pred_baselines).sum(dim=0, dtype=torch.double), - } - if regions_idx: - y_regions_idx = y[regions_idx].sum(dim=0, dtype=torch.bool) - pred_regions_idx = pred[regions_idx].sum(dim=0, dtype=torch.bool) - pages[-1]["regions"] = { - 'intersections': (y_regions_idx & pred_regions_idx).sum(dim=0, dtype=torch.double), - 'unions': (y_regions_idx | pred_regions_idx).sum(dim=0, dtype=torch.double), - } - - except FileNotFoundError as e: - batches -= 1 - progress.update(pred_task, total=batches) - logger.warning('{} {}. Skipping.'.format(e.strerror, e.filename)) - except KrakenInputException as e: - batches -= 1 - progress.update(pred_task, total=batches) - logger.warning(str(e)) - progress.update(pred_task, advance=1) + with threadpool_limits(limits=threads): + for batch in ds_loader: + x, y = batch['image'], batch['target'] + try: + pred, _ = nn.nn(x) + # scale target to output size + y = F.interpolate(y, size=(pred.size(2), pred.size(3))).squeeze(0).bool() + pred = pred.squeeze() > threshold + pred = pred.view(pred.size(0), -1) + y = y.view(y.size(0), -1) + pages.append({ + 'intersections': (y & pred).sum(dim=1, dtype=torch.double), + 'unions': (y | pred).sum(dim=1, dtype=torch.double), + 'corrects': torch.eq(y, pred).sum(dim=1, dtype=torch.double), + 'cls_cnt': y.sum(dim=1, dtype=torch.double), + 'all_n': torch.tensor(y.size(1), dtype=torch.double, device=device) + }) + if lines_idx: + y_baselines = y[lines_idx].sum(dim=0, dtype=torch.bool) + pred_baselines = pred[lines_idx].sum(dim=0, dtype=torch.bool) + pages[-1]["baselines"] = { + 'intersections': (y_baselines & pred_baselines).sum(dim=0, dtype=torch.double), + 'unions': (y_baselines | pred_baselines).sum(dim=0, dtype=torch.double), + } + if regions_idx: + y_regions_idx = y[regions_idx].sum(dim=0, dtype=torch.bool) + pred_regions_idx = pred[regions_idx].sum(dim=0, dtype=torch.bool) + pages[-1]["regions"] = { + 'intersections': (y_regions_idx & pred_regions_idx).sum(dim=0, dtype=torch.double), + 'unions': (y_regions_idx | pred_regions_idx).sum(dim=0, dtype=torch.double), + } + + except FileNotFoundError as e: + batches -= 1 + progress.update(pred_task, total=batches) + logger.warning('{} {}. Skipping.'.format(e.strerror, e.filename)) + except KrakenInputException as e: + batches -= 1 + progress.update(pred_task, total=batches) + logger.warning(str(e)) + progress.update(pred_task, advance=1) # Accuracy / pixel corrects = torch.stack([x['corrects'] for x in pages], -1).sum(dim=-1) diff --git a/kraken/kraken.py b/kraken/kraken.py index 38647eee8..27a0f5cef 100644 --- a/kraken/kraken.py +++ b/kraken/kraken.py @@ -28,7 +28,6 @@ from pathlib import Path from functools import partial from rich.traceback import install -from threadpoolctl import threadpool_limits from typing import Dict, Union, List, cast, Any, IO, Callable import click @@ -335,6 +334,7 @@ def cli(input, batch_input, suffix, verbose, format_type, pdf_format, ctx.meta['verbose'] = verbose ctx.meta['steps'] = [] ctx.meta["autocast"] = autocast + ctx.meta['threads'] = threads log.set_logger(logger, level=30 - min(10 * verbose, 20)) @@ -348,6 +348,7 @@ def process_pipeline(subcommands, input, batch_input, suffix, verbose, format_ty import uuid import tempfile + from threadpoolctl import threadpool_limits from kraken.lib.progress import KrakenProgressBar ctx = click.get_current_context() @@ -414,7 +415,8 @@ def process_pipeline(subcommands, input, batch_input, suffix, verbose, format_ty for idx, (task, input, output) in enumerate(zip(subcommands, fc, fc[1:])): if len(fc) - 2 == idx: ctx.meta['last_process'] = True - task(input=input, output=output) + with threadpool_limits(limits=ctx.meta['threads']): + task(input=input, output=output) except Exception as e: logger.error(f'Failed processing {io_pair[0]}: {str(e)}') if ctx.meta['raise_failed']: diff --git a/setup.cfg b/setup.cfg index 60c3994d6..f26ab7957 100644 --- a/setup.cfg +++ b/setup.cfg @@ -60,6 +60,7 @@ install_requires = pyarrow pytorch-lightning~=2.0.0 torchmetrics>=0.10.0 + threadpoolctl~=3.2.0 rich [options.extras_require]