From 4219c464598dc358e2bf9ea7731918e921ed1087 Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Mon, 29 Jul 2024 12:51:13 +0200 Subject: [PATCH] Add multi-model capability to CLI driver --- kraken/kraken.py | 60 ++++++++++++++++++++++++++---------------------- 1 file changed, 33 insertions(+), 27 deletions(-) diff --git a/kraken/kraken.py b/kraken/kraken.py index 22429d05d..c4b854f5d 100644 --- a/kraken/kraken.py +++ b/kraken/kraken.py @@ -114,7 +114,7 @@ def binarizer(threshold, zoom, escale, border, perc, range, low, high, input, ou message('\u2713', fg='green') -def segmenter(legacy, model, text_direction, scale, maxcolseps, black_colseps, +def segmenter(legacy, models, text_direction, scale, maxcolseps, black_colseps, remove_hlines, pad, mask, device, input, output) -> None: import json @@ -151,7 +151,7 @@ def segmenter(legacy, model, text_direction, scale, maxcolseps, black_colseps, pad=pad, mask=mask) else: - res = blla.segment(im, text_direction, mask=mask, model=model, device=device, + res = blla.segment(im, text_direction, mask=mask, model=models, device=device, raise_on_error=ctx.meta['raise_failed'], autocast=ctx.meta["autocast"]) except Exception: if ctx.meta['raise_failed']: @@ -474,7 +474,7 @@ def binarize(ctx, threshold, zoom, escale, border, perc, range, low, high): @cli.command('segment') @click.pass_context -@click.option('-i', '--model', default=None, show_default=True, +@click.option('-i', '--model', default=None, show_default=True, multiple=True, help='Baseline detection model to use') @click.option('-x/-bl', '--boxes/--baseline', default=True, show_default=True, help='Switch between legacy box segmenter and neural baseline segmenter') @@ -501,42 +501,48 @@ def segment(ctx, model, boxes, text_direction, scale, maxcolseps, """ from kraken.containers import ProcessingStep + print(model) + if model and boxes: logger.warning(f'Baseline model ({model}) given but legacy segmenter selected. Forcing to -bl.') boxes = False if boxes is False: if not model: - model = SEGMENTATION_DEFAULT_MODEL + model = [SEGMENTATION_DEFAULT_MODEL] ctx.meta['steps'].append(ProcessingStep(id=str(uuid.uuid4()), category='processing', description='Baseline and region segmentation', - settings={'model': os.path.basename(model), + settings={'model': [os.path.basename(m) for m in model], 'text_direction': text_direction})) - # first try to find the segmentation model by its given name, - # then look in the kraken config folder - location = None - search = [model, os.path.join(click.get_app_dir(APP_NAME), model)] - for loc in search: - if os.path.isfile(loc): - location = loc - break - if not location: - raise click.BadParameter(f'No model for {model} found') + # first try to find the segmentation models by their given names, then + # look in the kraken config folder + locations = [] + for m in model: + location = None + search = [m, os.path.join(click.get_app_dir(APP_NAME), m)] + for loc in search: + if os.path.isfile(loc): + location = loc + locations.append(loc) + break + if not location: + raise click.BadParameter(f'No model for {m} found') - from kraken.lib.vgsl import TorchVGSLModel - message(f'Loading ANN {model}\t', nl=False) - try: - model = TorchVGSLModel.load_model(location) - model.to(ctx.meta['device']) - except Exception: - if ctx.meta['raise_failed']: - raise - message('\u2717', fg='red') - ctx.exit(1) - message('\u2713', fg='green') + from kraken.lib.vgsl import TorchVGSLModel + models = [] + for loc in locations: + message(f'Loading ANN {loc}\t', nl=False) + try: + models.append(TorchVGSLModel.load_model(loc).to(ctx.meta['device'])) + except Exception: + if ctx.meta['raise_failed']: + raise + message('\u2717', fg='red') + ctx.exit(1) + message('\u2713', fg='green') else: ctx.meta['steps'].append(ProcessingStep(id=str(uuid.uuid4()), category='processing', @@ -548,7 +554,7 @@ def segment(ctx, model, boxes, text_direction, scale, maxcolseps, 'remove_hlines': remove_hlines, 'pad': pad})) - return partial(segmenter, boxes, model, text_direction, scale, maxcolseps, + return partial(segmenter, boxes, models, text_direction, scale, maxcolseps, black_colseps, remove_hlines, pad, mask, ctx.meta['device'])