Skip to content

Commit

Permalink
Fix segtrain regression
Browse files Browse the repository at this point in the history
  • Loading branch information
mittagessen committed Jan 11, 2024
1 parent e80308b commit 6df5d31
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 15 deletions.
3 changes: 1 addition & 2 deletions kraken/lib/dataset/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def __init__(self,
self.transforms = im_transforms
self.seg_type = None

def add(self, doc: Union['Segmentation']):
def add(self, doc: 'Segmentation'):
"""
Adds a page to the dataset.
Expand Down Expand Up @@ -154,7 +154,6 @@ def add(self, doc: Union['Segmentation']):
if reg_type not in self.class_mapping['regions']:
self.num_classes += 1
self.class_mapping['regions'][reg_type] = self.num_classes - 1

self.targets.append({'baselines': baselines_, 'regions': regions_})
self.imgs.append(doc.imagename)

Expand Down
27 changes: 16 additions & 11 deletions kraken/lib/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,7 +742,6 @@ def __init__(self,
warnings.warn("'both' value for resize has been deprecated. Use 'new' instead.", DeprecationWarning)
self.resize = resize

self.format_type = format_type
self.output = output
self.bounding_regions = bounding_regions
self.topline = topline
Expand Down Expand Up @@ -782,6 +781,17 @@ def __init__(self,
self.hyper_params = hyper_params_
self.save_hyperparameters()

if format_type in ['xml', 'page', 'alto']:
logger.info(f'Parsing {len(training_data)} XML files for training data')
training_data = [XMLPage(file, format_type).to_container() for file in training_data]
if evaluation_data:
logger.info(f'Parsing {len(evaluation_data)} XML files for validation data')
evaluation_data = [XMLPage(file, format_type).to_container() for file in evaluation_data]
elif not format_type:
pass
else:
raise ValueError(f'format_type {format_type} not in [alto, page, xml, None].')

if not training_data:
raise ValueError('No training data provided. Please add some.')

Expand Down Expand Up @@ -815,34 +825,29 @@ def __init__(self,
valid_baselines = []
merge_baselines = None

train_set = BaselineSet(training_data,
line_width=self.hparams.hyper_params['line_width'],
train_set = BaselineSet(line_width=self.hparams.hyper_params['line_width'],
im_transforms=transforms,
mode=format_type,
augmentation=self.hparams.hyper_params['augment'],
valid_baselines=valid_baselines,
merge_baselines=merge_baselines,
valid_regions=valid_regions,
merge_regions=merge_regions)

if format_type is None:
for page in training_data:
train_set.add(**page)
for page in training_data:
train_set.add(page)

if evaluation_data:
val_set = BaselineSet(evaluation_data,
line_width=self.hparams.hyper_params['line_width'],
im_transforms=transforms,
mode=format_type,
augmentation=False,
valid_baselines=valid_baselines,
merge_baselines=merge_baselines,
valid_regions=valid_regions,
merge_regions=merge_regions)

if format_type is None:
for page in evaluation_data:
val_set.add(**page)
for page in evaluation_data:
val_set.add(page)

train_set = Subset(train_set, range(len(train_set)))
val_set = Subset(val_set, range(len(val_set)))
Expand Down
5 changes: 3 additions & 2 deletions kraken/lib/xml.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,10 @@ def _parse_alto(self):
# parse region type and coords
region_data = defaultdict(list)
for region in regions:
region_id = region.get('ID')
# try to find shape object
coords = region.find('./{*}Shape/{*}Polygon')
boundary = None
if coords is not None:
boundary = self._parse_alto_pointstype(coords.get('POINTS'))
elif (region.get('HPOS') is not None and region.get('VPOS') is not None and
Expand All @@ -169,8 +171,7 @@ def _parse_alto(self):
break
if rtype is None:
rtype = alto_regions[region.tag.split('}')[-1]]
region_id = region.get('ID')
region_data[rtype].append(Region(id=region_id, boundary=coords, tags={'type': rtype}))
region_data[rtype].append(Region(id=region_id, boundary=boundary, tags={'type': rtype}))
# register implicit reading order
self._orders['region_implicit']['order'].append(region_id)

Expand Down

0 comments on commit 6df5d31

Please sign in to comment.