Skip to content

Commit

Permalink
Add multi-model capability to CLI driver
Browse files Browse the repository at this point in the history
  • Loading branch information
mittagessen committed Jul 29, 2024
1 parent c5b870b commit 4219c46
Showing 1 changed file with 33 additions and 27 deletions.
60 changes: 33 additions & 27 deletions kraken/kraken.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def binarizer(threshold, zoom, escale, border, perc, range, low, high, input, ou
message('\u2713', fg='green')


def segmenter(legacy, model, text_direction, scale, maxcolseps, black_colseps,
def segmenter(legacy, models, text_direction, scale, maxcolseps, black_colseps,
remove_hlines, pad, mask, device, input, output) -> None:
import json

Expand Down Expand Up @@ -151,7 +151,7 @@ def segmenter(legacy, model, text_direction, scale, maxcolseps, black_colseps,
pad=pad,
mask=mask)
else:
res = blla.segment(im, text_direction, mask=mask, model=model, device=device,
res = blla.segment(im, text_direction, mask=mask, model=models, device=device,
raise_on_error=ctx.meta['raise_failed'], autocast=ctx.meta["autocast"])
except Exception:
if ctx.meta['raise_failed']:
Expand Down Expand Up @@ -474,7 +474,7 @@ def binarize(ctx, threshold, zoom, escale, border, perc, range, low, high):

@cli.command('segment')
@click.pass_context
@click.option('-i', '--model', default=None, show_default=True,
@click.option('-i', '--model', default=None, show_default=True, multiple=True,
help='Baseline detection model to use')
@click.option('-x/-bl', '--boxes/--baseline', default=True, show_default=True,
help='Switch between legacy box segmenter and neural baseline segmenter')
Expand All @@ -501,42 +501,48 @@ def segment(ctx, model, boxes, text_direction, scale, maxcolseps,
"""
from kraken.containers import ProcessingStep

print(model)

if model and boxes:
logger.warning(f'Baseline model ({model}) given but legacy segmenter selected. Forcing to -bl.')
boxes = False

if boxes is False:
if not model:
model = SEGMENTATION_DEFAULT_MODEL
model = [SEGMENTATION_DEFAULT_MODEL]
ctx.meta['steps'].append(ProcessingStep(id=str(uuid.uuid4()),
category='processing',
description='Baseline and region segmentation',
settings={'model': os.path.basename(model),
settings={'model': [os.path.basename(m) for m in model],
'text_direction': text_direction}))

# first try to find the segmentation model by its given name,
# then look in the kraken config folder
location = None
search = [model, os.path.join(click.get_app_dir(APP_NAME), model)]
for loc in search:
if os.path.isfile(loc):
location = loc
break
if not location:
raise click.BadParameter(f'No model for {model} found')
# first try to find the segmentation models by their given names, then
# look in the kraken config folder
locations = []
for m in model:
location = None
search = [m, os.path.join(click.get_app_dir(APP_NAME), m)]
for loc in search:
if os.path.isfile(loc):
location = loc
locations.append(loc)
break
if not location:
raise click.BadParameter(f'No model for {m} found')

from kraken.lib.vgsl import TorchVGSLModel
message(f'Loading ANN {model}\t', nl=False)
try:
model = TorchVGSLModel.load_model(location)
model.to(ctx.meta['device'])
except Exception:
if ctx.meta['raise_failed']:
raise
message('\u2717', fg='red')
ctx.exit(1)

message('\u2713', fg='green')
from kraken.lib.vgsl import TorchVGSLModel
models = []
for loc in locations:
message(f'Loading ANN {loc}\t', nl=False)
try:
models.append(TorchVGSLModel.load_model(loc).to(ctx.meta['device']))
except Exception:
if ctx.meta['raise_failed']:
raise
message('\u2717', fg='red')
ctx.exit(1)
message('\u2713', fg='green')
else:
ctx.meta['steps'].append(ProcessingStep(id=str(uuid.uuid4()),
category='processing',
Expand All @@ -548,7 +554,7 @@ def segment(ctx, model, boxes, text_direction, scale, maxcolseps,
'remove_hlines': remove_hlines,
'pad': pad}))

return partial(segmenter, boxes, model, text_direction, scale, maxcolseps,
return partial(segmenter, boxes, models, text_direction, scale, maxcolseps,
black_colseps, remove_hlines, pad, mask, ctx.meta['device'])


Expand Down

0 comments on commit 4219c46

Please sign in to comment.