Skip to content

Commit

Permalink
More test fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
mittagessen committed Jan 7, 2024
1 parent ca2e77f commit 8ff27d1
Show file tree
Hide file tree
Showing 6 changed files with 175 additions and 124 deletions.
7 changes: 2 additions & 5 deletions kraken/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ class BBoxLine:
Attributes:
id: Unique identifier
bbox: Tuple in form `((x0, y0), (x1, y0), (x1, y1), (x0, y1))` defining
bbox: Tuple in form `(xmin, ymin, xmax, ymax)` defining
the bounding box.
text: Transcription of this line.
base_dir: An optional string defining the base direction (also called
Expand All @@ -120,10 +120,7 @@ class BBoxLine:
reading direction (of the document).
"""
id: str
bbox: Tuple[Tuple[int, int],
Tuple[int, int],
Tuple[int, int],
Tuple[int, int]]
bbox: Tuple[int, int, int, int]
text: Optional[str] = None
base_dir: Optional[Literal['L', 'R']] = None
type: str = 'bbox'
Expand Down
3 changes: 2 additions & 1 deletion kraken/lib/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1139,7 +1139,8 @@ def extract_polygons(im: Image.Image, bounds: 'Segmentation') -> Image.Image:
angle = 90
else:
angle = 0
for box in bounds.lines:
for line in bounds.lines:
box = line.bbox
if isinstance(box, tuple):
box = list(box)
if (box < [0, 0, 0, 0] or box[::2] >= [im.size[0], im.size[0]] or
Expand Down
96 changes: 54 additions & 42 deletions kraken/rpred.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,18 @@ def __init__(self,
if not tags_ignore:
tags_ignore = []

if bounds.script_detection:
self.have_tags = True
else:
self.have_tags = False

if bounds.type not in seg_types or len(seg_types) > 1:
logger.warning(f'Recognizers with segmentation types {seg_types} will be '
f'applied to segmentation of type {bounds.type}. '
f'This will likely result in severely degraded performace')
one_channel_modes = set(recognizer.nn.one_channel_mode for recognizer in nets.values())
if '1' in one_channel_modes and len(one_channel_modes) > 1:
raise KrakenInputException('Mixing binary and non-binary recognition models is not supported.')
raise ValueError('Mixing binary and non-binary recognition models is not supported.')
elif '1' in one_channel_modes and not is_bitonal(im):
logger.warning('Running binary models on non-binary input image '
f'(mode {im.mode}). This will result in severely degraded '
Expand All @@ -113,32 +118,39 @@ def __init__(self,
valid_norm = True
self.next_iter = self._recognize_box_line

tags = set()
for x in bounds.lines:
tags.update(x.tags.items())

im_str = get_im_str(im)
logger.info(f'Running {len(nets)} multi-script recognizers on {im_str} with {self.len} lines')

filtered_tags = []
miss = []
for tag in tags:
if not isinstance(nets, defaultdict) and (not nets.get(tag) and tag not in tags_ignore):
miss.append(tag)
elif tag not in tags_ignore:
filtered_tags.append(tag)
tags = filtered_tags

if miss:
raise KrakenInputException(f'Missing models for tags {set(miss)}')

# build dictionary for line preprocessing
self.ts = {}
for tag in tags:
logger.debug(f'Loading line transforms for {tag}')
network = nets[tag]
if self.have_tags:
tags = set()
for x in bounds.lines:
tags.update(x.tags.items())

im_str = get_im_str(im)
logger.info(f'Running {len(nets)} multi-script recognizers on {im_str} with {self.len} lines')

filtered_tags = []
miss = []
for tag in tags:
if not isinstance(nets, defaultdict) and (not nets.get(tag) and tag not in tags_ignore):
miss.append(tag)
elif tag not in tags_ignore:
filtered_tags.append(tag)
tags = filtered_tags

if miss:
raise KrakenInputException(f'Missing models for tags {set(miss)}')

# build dictionary for line preprocessing
self.ts = {}
for tag in tags:
logger.debug(f'Loading line transforms for {tag}')
network = nets[tag]
batch, channels, height, width = network.nn.input
self.ts[tag] = ImageInputTransforms(batch, height, width, channels, (pad, 0), valid_norm)
elif isinstance(nets, defaultdict) and nets.default_factory:
network = nets.default_factory()
batch, channels, height, width = network.nn.input
self.ts[tag] = ImageInputTransforms(batch, height, width, channels, (pad, 0), valid_norm)
self.ts = {('type', 'default'): ImageInputTransforms(batch, height, width, channels, (pad, 0), valid_norm)}
else:
raise ValueError('No tags in input data and no default model in mapping given.')

self.im = im
self.nets = nets
Expand All @@ -148,23 +160,22 @@ def __init__(self,
self.tags_ignore = tags_ignore

def _recognize_box_line(self, line):
flat_box = [point for box in line.bbox for point in box]
xmin, xmax = min(flat_box[::2]), max(flat_box[::2])
ymin, ymax = min(flat_box[1::2]), max(flat_box[1::2])
xmin, ymin, xmax, ymax = line.bbox
prediction = ''
cuts = []
confidences = []
line.text_direction = self.bounds.text_direction

if self.tags_ignore is not None:
for tag in line.tags.values():
if self.have_tags and self.tags_ignore:
for tag in line.tags.items():
if tag in self.tags_ignore:
logger.info(f'Ignoring line segment with tags {line.tags} based on {tag}.')
return BaselineOCRRecord('', [], [], line)
return BBoxOCRRecord('', (), (), line)

tag, net = self._resolve_tags_to_model(line.tags, self.nets)

box, coords = next(extract_polygons(self.im, line))
seg = dataclasses.replace(self.bounds, lines=[line])
box, coords = next(extract_polygons(self.im, seg))
self.box = box

# check if boxes are non-zero in any dimension
Expand All @@ -184,8 +195,6 @@ def _recognize_box_line(self, line):
logger.warning('Empty run. Emitting empty record.')
return BBoxOCRRecord('', (), (), line)

_, net = self._resolve_tags_to_model({'type': tag}, self.nets)

logger.debug(f'Forward pass with model {tag}.')
preds = net.predict(ts_box.unsqueeze(0))[0]

Expand Down Expand Up @@ -224,8 +233,8 @@ def _recognize_box_line(self, line):
return rec.display_order(None)

def _recognize_baseline_line(self, line):
if self.tags_ignore is not None:
for tag in line.tags.values():
if self.have_tags and self.tags_ignore is not None:
for tag in line.tags.items():
if tag in self.tags_ignore:
logger.info(f'Ignoring line segment with tags {line.tags} based on {tag}.')
return BaselineOCRRecord('', [], [], line)
Expand Down Expand Up @@ -319,15 +328,18 @@ def rpred(network: 'TorchSeqRecognizer',
return mm_rpred(defaultdict(lambda: network), im, bounds, pad, bidi_reordering)


def _resolve_tags_to_model(tags: Sequence[Dict[str, str]],
def _resolve_tags_to_model(tags: Optional[Sequence[Dict[str, str]]],
model_map: Dict[Tuple[str, str], 'TorchSeqRecognizer'],
default: Optional['TorchSeqRecognizer'] = None) -> 'TorchSeqRecognizer':
"""
Resolves a sequence of tags
"""
for tag in tags.items():
if tag in model_map:
return tag, model_map[tag]
if default:
if not tags and default:
return ('type', 'default'), default
elif tags:
for tag in tags.items():
if tag in model_map:
return tag, model_map[tag]
elif tags and default:
return next(tags.values()), default
raise KrakenInputException(f'No model for tags {tags}')
7 changes: 5 additions & 2 deletions kraken/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,11 +164,14 @@ def serialize(results: 'Segmentation',
# set field to indicate the availability of baseline segmentation in
# addition to bounding boxes
line = {'index': idx,
'bbox': max_bbox([record.boundary] if record.type == 'baselines' else [record.bbox]),
'bbox': max_bbox([record.boundary]) if record.type == 'baselines' else record.bbox,
'cuts': record.cuts,
'confidences': record.confidences,
'recognition': [],
'boundary': [list(x) for x in record.boundary] if record.type == 'baselines' else record.bbox,
'boundary': [list(x) for x in record.boundary] if record.type == 'baselines' else [[record.bbox[0], record.bbox[1]],
[record.bbox[2], record.bbox[1]],
[record.bbox[2], record.bbox[3]],
[record.bbox[0], record.bbox[3]]],
'type': 'line'
}
if record.tags is not None:
Expand Down
2 changes: 1 addition & 1 deletion tests/resources/records.json

Large diffs are not rendered by default.

Loading

0 comments on commit 8ff27d1

Please sign in to comment.