diff --git a/kraken/ketos/recognition.py b/kraken/ketos/recognition.py index ba324e810..04b6a30b0 100644 --- a/kraken/ketos/recognition.py +++ b/kraken/ketos/recognition.py @@ -24,7 +24,6 @@ from typing import List -from kraken.lib.progress import KrakenProgressBar from kraken.lib.exceptions import KrakenInputException from kraken.lib.default_specs import RECOGNITION_HYPER_PARAMS, RECOGNITION_SPEC from .util import _validate_manifests, _expand_gt, message, to_ptl_device @@ -390,6 +389,7 @@ def test(ctx, batch_size, model, evaluation_files, device, pad, workers, ImageInputTransforms, ArrowIPCRecognitionDataset, collate_sequences) + from kraken.lib.progress import KrakenProgressBar logger.info('Building test set from {} line images'.format(len(test_set) + len(evaluation_files))) diff --git a/kraken/ketos/repo.py b/kraken/ketos/repo.py index 32a49ac5e..52c9db8e3 100644 --- a/kraken/ketos/repo.py +++ b/kraken/ketos/repo.py @@ -22,8 +22,6 @@ import click import logging -from kraken.lib.progress import KrakenDownloadProgressBar - from .util import message logging.captureWarnings(True) @@ -52,6 +50,7 @@ def publish(ctx, metadata, access_token, private, model): from kraken import repo from kraken.lib import models + from kraken.lib.progress import KrakenDownloadProgressBar with pkg_resources.resource_stream('kraken', 'metadata.schema.json') as fp: schema = json.load(fp) diff --git a/kraken/ketos/ro.py b/kraken/ketos/ro.py index 1dcc0856f..17fc8cf09 100644 --- a/kraken/ketos/ro.py +++ b/kraken/ketos/ro.py @@ -25,7 +25,6 @@ from PIL import Image from typing import Dict -from kraken.lib.progress import KrakenProgressBar from kraken.lib.exceptions import KrakenInputException from kraken.lib.default_specs import READING_ORDER_HYPER_PARAMS @@ -152,8 +151,9 @@ def rotrain(ctx, batch_size, output, load, freq, quit, epochs, min_epochs, lag, """ import shutil - from kraken.lib.train import KrakenTrainer from kraken.lib.ro import ROModel + from kraken.lib.train import KrakenTrainer + from kraken.lib.progress import KrakenProgressBar if not (0 <= freq <= 1) and freq % 1.0 != 0: raise click.BadOptionUsage('freq', 'freq needs to be either in the interval [0,1.0] or a positive integer.') diff --git a/kraken/ketos/segmentation.py b/kraken/ketos/segmentation.py index dfcb7c739..e356d08df 100644 --- a/kraken/ketos/segmentation.py +++ b/kraken/ketos/segmentation.py @@ -24,7 +24,6 @@ from PIL import Image -from kraken.lib.progress import KrakenProgressBar from kraken.lib.exceptions import KrakenInputException from kraken.lib.default_specs import SEGMENTATION_HYPER_PARAMS, SEGMENTATION_SPEC @@ -230,6 +229,7 @@ def segtrain(ctx, output, spec, line_width, pad, load, freq, quit, epochs, import shutil from kraken.lib.train import SegmentationModel, KrakenTrainer + from kraken.lib.progress import KrakenProgressBar if resize != 'fail' and not load: raise click.BadOptionUsage('resize', 'resize option requires loading an existing model') diff --git a/kraken/ketos/transcription.py b/kraken/ketos/transcription.py index 490c0ac4e..dd4402c66 100644 --- a/kraken/ketos/transcription.py +++ b/kraken/ketos/transcription.py @@ -27,7 +27,6 @@ from typing import IO, Any, cast from bidi.algorithm import get_display -from kraken.lib.progress import KrakenProgressBar from .util import message logging.captureWarnings(True) @@ -68,6 +67,7 @@ def extract(ctx, binarize, normalization, normalize_whitespace, reorder, from lxml import html, etree from kraken import binarization + from kraken.lib.progress import KrakenProgressBar try: os.mkdir(output) @@ -172,6 +172,7 @@ def transcription(ctx, text_direction, scale, bw, maxcolseps, from kraken import binarization from kraken.lib import models + from kraken.lib.progress import KrakenProgressBar ti = transcribe.TranscriptionInterface(font, font_style) diff --git a/kraken/kraken.py b/kraken/kraken.py index 0f8fea361..38647eee8 100644 --- a/kraken/kraken.py +++ b/kraken/kraken.py @@ -24,16 +24,16 @@ import dataclasses import pkg_resources -from typing import Dict, Union, List, cast, Any, IO, Callable +from PIL import Image from pathlib import Path -from rich.traceback import install from functools import partial -from PIL import Image +from rich.traceback import install +from threadpoolctl import threadpool_limits +from typing import Dict, Union, List, cast, Any, IO, Callable import click from kraken.lib import log -from kraken.lib.progress import KrakenProgressBar, KrakenDownloadProgressBar warnings.simplefilter('ignore', UserWarning) @@ -118,8 +118,7 @@ def segmenter(legacy, model, text_direction, scale, maxcolseps, black_colseps, remove_hlines, pad, mask, device, input, output) -> None: import json - from kraken import pageseg - from kraken import blla + from kraken import blla, pageseg ctx = click.get_current_context() @@ -183,8 +182,10 @@ def recognizer(model, pad, no_segmentation, bidi_reordering, tags_ignore, input, import uuid import dataclasses - from kraken.containers import Segmentation, BBoxLine from kraken import rpred + from kraken.containers import Segmentation, BBoxLine + + from kraken.lib.progress import KrakenProgressBar ctx = click.get_current_context() @@ -301,8 +302,10 @@ def recognizer(model, pad, no_segmentation, bidi_reordering, tags_ignore, input, help='Raises the exception that caused processing to fail in the case of an error') @click.option('-2', '--autocast', default=False, show_default=True, flag_value=True, help='On compatible devices, uses autocast for `segment` which lower the memory usage.') +@click.option('--threads', default=1, show_default=True, type=click.IntRange(1), + help='Size of thread pools for intra-op parallelization') def cli(input, batch_input, suffix, verbose, format_type, pdf_format, - serializer, template, device, raise_on_error, autocast): + serializer, template, device, raise_on_error, autocast, threads): """ Base command for recognition functionality. @@ -345,6 +348,8 @@ def process_pipeline(subcommands, input, batch_input, suffix, verbose, format_ty import uuid import tempfile + from kraken.lib.progress import KrakenProgressBar + ctx = click.get_current_context() input = list(input) @@ -563,9 +568,7 @@ def _validate_mm(ctx, param, value): show_default=True, type=click.Choice(['horizontal-tb', 'vertical-lr', 'vertical-rl']), help='Sets principal text direction in serialization output') -@click.option('--threads', default=1, show_default=True, type=click.IntRange(1), - help='Number of threads to use for OpenMP parallelization.') -def ocr(ctx, model, pad, reorder, base_dir, no_segmentation, text_direction, threads): +def ocr(ctx, model, pad, reorder, base_dir, no_segmentation, text_direction): """ Recognizes text in line images. """ @@ -607,8 +610,6 @@ def ocr(ctx, model, pad, reorder, base_dir, no_segmentation, text_direction, thr nn = defaultdict(lambda: nm['default']) # type: Dict[str, models.TorchSeqRecognizer] nn.update(nm) nm = nn - # thread count is global so setting it once is sufficient - nm[k].nn.set_num_threads(threads) ctx.meta['steps'].append({'category': 'processing', 'description': 'Text line recognition', @@ -661,6 +662,7 @@ def list_models(ctx): Lists models in the repository. """ from kraken import repo + from kraken.lib.progress import KrakenProgressBar with KrakenProgressBar() as progress: download_task = progress.add_task('Retrieving model list', total=0, visible=True if not ctx.meta['verbose'] else False) @@ -678,6 +680,7 @@ def get(ctx, model_id): Retrieves a model from the repository. """ from kraken import repo + from kraken.lib.progress import KrakenDownloadProgressBar try: os.makedirs(click.get_app_dir(APP_NAME))