From 36b3424810c8eaec2ba62e784d579eb3d2a716bb Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Sat, 20 Apr 2024 15:15:33 +0200 Subject: [PATCH 1/3] Update pytorch/lightning to 2.2 --- conda/meta.yaml | 6 +- environment.yml | 6 +- environment_cuda.yml | 6 +- .../hyperparameters/tune_pretraining.py | 89 ------------------- .../contrib/hyperparameters/tune_training.py | 58 ------------ kraken/ketos/__init__.py | 4 +- kraken/lib/pretrain/model.py | 10 +-- kraken/lib/progress.py | 2 +- kraken/lib/ro/model.py | 6 +- kraken/lib/train.py | 32 +++---- setup.cfg | 6 +- 11 files changed, 39 insertions(+), 186 deletions(-) delete mode 100644 kraken/contrib/hyperparameters/tune_pretraining.py delete mode 100644 kraken/contrib/hyperparameters/tune_training.py diff --git a/conda/meta.yaml b/conda/meta.yaml index 1e687f559..4a06a0a91 100644 --- a/conda/meta.yaml +++ b/conda/meta.yaml @@ -21,7 +21,7 @@ requirements: - scipy~=1.11.0 - jinja2~=3.0 - torchvision - - pytorch>=1.12.0 + - pytorch~=2.2.0 - cudatoolkit - jsonschema - scikit-image~=0.21.0 @@ -30,9 +30,9 @@ requirements: - pyvips - coremltools - pyarrow - - lightning~=2.0 + - lightning~=2.2 - torchmetrics>=1.1.0 - - conda-forge::threadpoolctl~=3.2.0 + - conda-forge::threadpoolctl~=3.4.0 - albumentations - rich about: diff --git a/environment.yml b/environment.yml index 410e1a087..4787271dc 100644 --- a/environment.yml +++ b/environment.yml @@ -14,7 +14,7 @@ dependencies: - scipy~=1.10.0 - jinja2~=3.0 - conda-forge::torchvision-cpu>=0.5.0 - - conda-forge::pytorch-cpu~=2.0.0 + - conda-forge::pytorch-cpu~=2.2.0 - jsonschema - scikit-learn~=1.2.1 - scikit-image~=0.21.0 @@ -23,9 +23,9 @@ dependencies: - imagemagick>=7.1.0 - pyarrow - importlib-resources>=1.3.0 - - conda-forge::lightning~=2.0.0 + - conda-forge::lightning~=2.2.0 - conda-forge::torchmetrics>=1.1.0 - - conda-forge::threadpoolctl~=3.2 + - conda-forge::threadpoolctl~=3.4 - pip - albumentations - rich diff --git a/environment_cuda.yml b/environment_cuda.yml index 7844ba138..ebdda92de 100644 --- a/environment_cuda.yml +++ b/environment_cuda.yml @@ -14,7 +14,7 @@ dependencies: - scipy~=1.10.0 - jinja2~=3.0 - conda-forge::torchvision>=0.5.0 - - conda-forge::pytorch~=2.0.0 + - conda-forge::pytorch~=2.2.0 - cudatoolkit>=9.2 - jsonschema - scikit-learn~=1.2.1 @@ -24,9 +24,9 @@ dependencies: - imagemagick>=7.1.0 - pyarrow - importlib-resources>=1.3.0 - - conda-forge::lightning~=2.0.0 + - conda-forge::lightning~=2.2.0 - conda-forge::torchmetrics>=1.1.0 - - conda-forge::threadpoolctl~=3.2 + - conda-forge::threadpoolctl~=3.4 - pip - albumentations - rich diff --git a/kraken/contrib/hyperparameters/tune_pretraining.py b/kraken/contrib/hyperparameters/tune_pretraining.py deleted file mode 100644 index 5564b521d..000000000 --- a/kraken/contrib/hyperparameters/tune_pretraining.py +++ /dev/null @@ -1,89 +0,0 @@ -#!/usr/bin/env python -""" -A script for a grid search over pretraining hyperparameters. -""" -from functools import partial - -import click -import pytorch_lightning as pl -from pytorch_lightning import seed_everything -from ray import tune -from ray.tune.integration.pytorch_lightning import TuneReportCallback - -from kraken.ketos.util import _validate_manifests -from kraken.lib.default_specs import (RECOGNITION_PRETRAIN_HYPER_PARAMS, - RECOGNITION_SPEC) -from kraken.lib.pretrain.model import (PretrainDataModule, - RecognitionPretrainModel) - -config = {'lrate': tune.loguniform(1e-8, 1e-2), - 'num_negatives': tune.qrandint(1, 4, 1), - 'mask_prob': tune.loguniform(0.01, 0.2), - 'mask_width': tune.qrandint(2, 8, 2)} - -resources_per_trial = {"cpu": 8, "gpu": 0.5} - - -def train_tune(config, training_data=None, epochs=100, spec=RECOGNITION_SPEC): - - hyper_params = RECOGNITION_PRETRAIN_HYPER_PARAMS.copy() - hyper_params.update(config) - - model = RecognitionPretrainModel(hyper_params=hyper_params, - output='./model', - spec=spec) - - data_module = PretrainDataModule(batch_size=hyper_params.pop('batch_size'), - pad=hyper_params.pop('pad'), - augment=hyper_params.pop('augment'), - training_data=training_data, - num_workers=resources_per_trial['cpu'], - height=model.height, - width=model.width, - channels=model.channels, - format_type='binary') - - callback = TuneReportCallback({'loss': 'CE'}, on='validation_end') - trainer = pl.Trainer(max_epochs=epochs, - accelerator='gpu', - devices=1, - callbacks=[callback], - enable_progress_bar=False) - trainer.fit(model, datamodule=data_module) - - -@click.command() -@click.option('-v', '--verbose', default=0, count=True) -@click.option('-s', '--seed', default=42, type=click.INT, - help='Seed for numpy\'s and torch\'s RNG. Set to a fixed value to ' - 'ensure reproducible random splits of data') -@click.option('-o', '--output', show_default=True, type=click.Path(), default='pretrain_hyper', help='output directory') -@click.option('-n', '--num-samples', show_default=True, type=int, default=100, help='Number of samples to train') -@click.option('-N', '--epochs', show_default=True, type=int, default=10, help='Maximum number of epochs to train per sample') -@click.option('-s', '--spec', show_default=True, default=RECOGNITION_SPEC, help='VGSL spec of the network to train.') -@click.option('-t', '--training-files', show_default=True, default=None, multiple=True, - callback=_validate_manifests, type=click.File(mode='r', lazy=True), - help='File(s) with additional paths to training data') -@click.argument('files', nargs=-1) -def cli(verbose, seed, output, num_samples, epochs, spec, training_files, files): - - files = list(files) - - if training_files: - files.extend(training_files) - - if not files: - raise click.UsageError('No training data was provided to the search command. Use `-t` or the `files` argument.') - - seed_everything(seed, workers=True) - - analysis = tune.run(partial(train_tune, - training_data=files, - epochs=epochs, - spec=spec), local_dir=output, num_samples=num_samples, resources_per_trial=resources_per_trial, config=config) - - click.echo("Best hyperparameters found were: ", analysis.get_best_config(metric='accuracy', mode='max')) - - -if __name__ == '__main__': - cli() diff --git a/kraken/contrib/hyperparameters/tune_training.py b/kraken/contrib/hyperparameters/tune_training.py deleted file mode 100644 index e123c3754..000000000 --- a/kraken/contrib/hyperparameters/tune_training.py +++ /dev/null @@ -1,58 +0,0 @@ -#!/usr/bin/env python -""" -A script for a grid search over pretraining hyperparameters. -""" -import sys -from functools import partial - -import pytorch_lightning as pl -from ray import tune -from ray.tune.integration.pytorch_lightning import TuneReportCallback - -from kraken.lib.default_spec import (RECOGNITION_PRETRAIN_HYPER_PARAMS, - RECOGNITION_SPEC) -from kraken.lib.pretrain.model import (PretrainDataModule, - RecognitionPretrainModel) - -config = {'lrate': tune.loguniform(1e-8, 1e-2), - 'num_negatives': tune.qrandint(2, 100, 8), - 'mask_prob': tune.loguniform(0.01, 0.2), - 'mask_width': tune.qrandint(2, 8, 2)} - -resources_per_trial = {"cpu": 8, "gpu": 0.5} - - -def train_tune(config, training_data=None, epochs=100): - - hyper_params = RECOGNITION_PRETRAIN_HYPER_PARAMS.copy() - hyper_params.update(config) - - model = RecognitionPretrainModel(hyper_params=hyper_params, - output='model', - spec=RECOGNITION_SPEC) - - _ = PretrainDataModule(batch_size=hyper_params.pop('batch_size'), - pad=hyper_params.pop('pad'), - augment=hyper_params.pop('augment'), - training_data=training_data, - num_workers=resources_per_trial['cpu'], - height=model.height, - width=model.width, - channels=model.channels, - format_type='binary') - - callback = TuneReportCallback({'loss': 'CE'}, on='validation_end') - trainer = pl.Trainer(max_epochs=epochs, - gpus=1, - callbacks=[callback], - enable_progress_bar=False) - trainer.fit(model) - - -analysis = tune.run(partial(train_tune, training_data=sys.argv[2:]), - local_dir=sys.argv[1], - num_samples=100, - resources_per_trial=resources_per_trial, - config=config) - -print("Best hyperparameters found were: ", analysis.get_best_config(metric='accuracy', mode='max')) diff --git a/kraken/ketos/__init__.py b/kraken/ketos/__init__.py index 9edb8d005..4537691b3 100644 --- a/kraken/ketos/__init__.py +++ b/kraken/ketos/__init__.py @@ -60,10 +60,10 @@ def cli(ctx, verbose, seed, deterministic): ctx.meta['deterministic'] = False if not deterministic else 'warn' if seed: - from pytorch_lightning import seed_everything + from lightning.pytorch import seed_everything seed_everything(seed, workers=True) elif deterministic: - from pytorch_lightning import seed_everything + from lightning.pytorch import seed_everything seed_everything(42, workers=True) ctx.meta['verbose'] = verbose diff --git a/kraken/lib/pretrain/model.py b/kraken/lib/pretrain/model.py index d86d2a7c1..cd1d12e32 100644 --- a/kraken/lib/pretrain/model.py +++ b/kraken/lib/pretrain/model.py @@ -36,11 +36,11 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Union import numpy as np -import pytorch_lightning as pl +import lightning as L import torch import torch.nn.functional as F -from pytorch_lightning.callbacks import EarlyStopping -from pytorch_lightning.utilities.memory import (garbage_collection_cuda, +from lightning.pytorch.callbacks import EarlyStopping +from lightning.pytorch.utilities.memory import (garbage_collection_cuda, is_oom_error) from torch.optim import lr_scheduler from torch.utils.data import DataLoader, Subset, random_split @@ -73,7 +73,7 @@ def _star_fun(fun, kwargs): return None -class PretrainDataModule(pl.LightningDataModule): +class PretrainDataModule(L.LightningDataModule): def __init__(self, training_data: Union[Sequence[Union['PathLike', str]], Sequence[Dict[str, Any]]] = None, evaluation_data: Optional[Union[Sequence[Union['PathLike', str]], Sequence[Dict[str, Any]]]] = None, @@ -266,7 +266,7 @@ def setup(self, stage: Optional[str] = None): self.val_set.dataset.no_encode() -class RecognitionPretrainModel(pl.LightningModule): +class RecognitionPretrainModel(L.LightningModule): def __init__(self, hyper_params: Dict[str, Any] = None, output: str = 'model', diff --git a/kraken/lib/progress.py b/kraken/lib/progress.py index 25201a9be..1ba65a258 100644 --- a/kraken/lib/progress.py +++ b/kraken/lib/progress.py @@ -18,7 +18,7 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Union -from pytorch_lightning.callbacks.progress.rich_progress import ( +from lightning.pytorch.callbacks.progress.rich_progress import ( CustomProgress, MetricsTextColumn, RichProgressBar) from rich import get_console, reconfigure from rich.default_styles import DEFAULT_STYLES diff --git a/kraken/lib/ro/model.py b/kraken/lib/ro/model.py index cba78bd13..c9c661afa 100644 --- a/kraken/lib/ro/model.py +++ b/kraken/lib/ro/model.py @@ -23,10 +23,10 @@ Sequence, Union) import numpy as np -import pytorch_lightning as pl +import lightning as L import torch import torch.nn.functional as F -from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor +from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor from torch.optim import lr_scheduler from torch.utils.data import DataLoader, Subset @@ -64,7 +64,7 @@ def spearman_footrule_distance(s, t): return (s - t).abs().sum() / (0.5 * (len(s) ** 2 - (len(s) % 2))) -class ROModel(pl.LightningModule): +class ROModel(L.LightningModule): def __init__(self, hyper_params: Dict[str, Any] = None, output: str = 'model', diff --git a/kraken/lib/train.py b/kraken/lib/train.py index 65b65b1e1..31d6532a8 100644 --- a/kraken/lib/train.py +++ b/kraken/lib/train.py @@ -23,10 +23,10 @@ from functools import partial import numpy as np -import pytorch_lightning as pl +import lightning as L import torch import torch.nn.functional as F -from pytorch_lightning.callbacks import (BaseFinetuning, Callback, +from lightning.pytorch.callbacks import (BaseFinetuning, Callback, EarlyStopping, LearningRateMonitor) from torch.optim import lr_scheduler from torch.utils.data import DataLoader, Subset, random_split @@ -66,21 +66,21 @@ def _validation_worker_init_fn(worker_id): results when validating. Temporarily increase the logging level for lightning because otherwise it will display a message at info level about the seed being changed. """ - from pytorch_lightning import seed_everything + from lightning.pytorch import seed_everything level = logging.getLogger("lightning_fabric.utilities.seed").level logging.getLogger("lightning_fabric.utilities.seed").setLevel(logging.WARN) seed_everything(42) logging.getLogger("lightning_fabric.utilities.seed").setLevel(level) -class KrakenTrainer(pl.Trainer): +class KrakenTrainer(L.Trainer): def __init__(self, enable_progress_bar: bool = True, enable_summary: bool = True, min_epochs: int = 5, max_epochs: int = 100, freeze_backbone=-1, - pl_logger: Union[pl.loggers.logger.Logger, str, None] = None, + pl_logger: Union[L.loggers.logger.Logger, str, None] = None, log_dir: Optional['PathLike'] = None, *args, **kwargs): @@ -93,14 +93,14 @@ def __init__(self, kwargs['callbacks'] = [kwargs['callbacks']] if pl_logger: - if 'logger' in kwargs and isinstance(kwargs['logger'], pl.loggers.logger.Logger): + if 'logger' in kwargs and isinstance(kwargs['logger'], L.loggers.logger.Logger): logger.debug('Experiment logger has been provided outside KrakenTrainer as `logger`') - elif isinstance(pl_logger, pl.loggers.logger.Logger): + elif isinstance(pl_logger, L.loggers.logger.Logger): logger.debug('Experiment logger has been provided outside KrakenTrainer as `pl_logger`') kwargs['logger'] = pl_logger elif pl_logger == 'tensorboard': logger.debug('Creating default experiment logger') - kwargs['logger'] = pl.loggers.TensorBoardLogger(log_dir) + kwargs['logger'] = L.loggers.TensorBoardLogger(log_dir) else: logger.error('`pl_logger` was set, but %s is not an accepted value', pl_logger) raise ValueError(f'{pl_logger} is not acceptable as logger') @@ -113,7 +113,7 @@ def __init__(self, kwargs['callbacks'].append(progress_bar_cb) if enable_summary: - from pytorch_lightning.callbacks import RichModelSummary + from lightning.pytorch.callbacks import RichModelSummary summary_cb = RichModelSummary(max_depth=2) kwargs['callbacks'].append(summary_cb) kwargs['enable_model_summary'] = False @@ -146,10 +146,10 @@ def freeze_before_training(self, pl_module): def finetune_function(self, pl_module, current_epoch, optimizer): pass - def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_train_start(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None: self.freeze(pl_module.net[:-1]) - def on_train_batch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch, batch_idx) -> None: + def on_train_batch_start(self, trainer: "L.Trainer", pl_module: "L.LightningModule", batch, batch_idx) -> None: """ Called for each training batch. """ @@ -162,7 +162,7 @@ def on_train_batch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo current_param_groups = optimizer.param_groups self._store(pl_module, opt_idx, num_param_groups, current_param_groups) - def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_train_epoch_start(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None: """Called when the epoch begins.""" pass @@ -171,7 +171,7 @@ class KrakenSetOneChannelMode(Callback): """ Callback that sets the one_channel_mode of the model after the first epoch. """ - def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_train_epoch_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None: # fill one_channel_mode after 1 iteration over training data set if not trainer.sanity_checking and trainer.current_epoch == 0 and trainer.model.nn.model_type == 'recognition': ds = getattr(pl_module, 'train_set', None) @@ -187,7 +187,7 @@ class KrakenSaveModel(Callback): """ Kraken's own serialization callback instead of pytorch's. """ - def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_validation_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None: if not trainer.sanity_checking: trainer.model.nn.hyper_params['completed_epochs'] += 1 metric = float(trainer.logged_metrics['val_metric']) if 'val_metric' in trainer.logged_metrics else -1.0 @@ -199,7 +199,7 @@ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModul trainer.model.best_model = f'{trainer.model.output}_{trainer.model.best_epoch}.mlmodel' -class RecognitionModel(pl.LightningModule): +class RecognitionModel(L.LightningModule): def __init__(self, hyper_params: Dict[str, Any] = None, output: str = 'model', @@ -706,7 +706,7 @@ def lr_scheduler_step(self, scheduler, metric): scheduler.step(metric) -class SegmentationModel(pl.LightningModule): +class SegmentationModel(L.LightningModule): def __init__(self, hyper_params: Dict = None, load_hyper_parameters: bool = False, diff --git a/setup.cfg b/setup.cfg index 982d5a815..4f880f868 100644 --- a/setup.cfg +++ b/setup.cfg @@ -53,14 +53,14 @@ install_requires = jinja2~=3.0 python-bidi torchvision>=0.5.0 - torch~=2.0.1 + torch~=2.2.0 scikit-learn~=1.2.1 scikit-image~=0.21.0 shapely~=1.8.5 pyarrow - lightning~=2.0.0 + lightning~=2.2.0 torchmetrics>=1.1.0 - threadpoolctl~=3.2.0 + threadpoolctl~=3.4.0 importlib-resources>=1.3.0 rich From 51d593a51ee4b51af174fa4100da9ee59f661da1 Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Sat, 20 Apr 2024 16:07:27 +0200 Subject: [PATCH 2/3] fix deps --- conda/meta.yaml | 2 +- environment.yml | 2 +- environment_cuda.yml | 2 +- kraken/lib/train.py | 8 ++++---- setup.cfg | 2 +- 5 files changed, 8 insertions(+), 8 deletions(-) diff --git a/conda/meta.yaml b/conda/meta.yaml index 4a06a0a91..1620e62cc 100644 --- a/conda/meta.yaml +++ b/conda/meta.yaml @@ -21,7 +21,7 @@ requirements: - scipy~=1.11.0 - jinja2~=3.0 - torchvision - - pytorch~=2.2.0 + - pytorch~=2.1.0 - cudatoolkit - jsonschema - scikit-image~=0.21.0 diff --git a/environment.yml b/environment.yml index 4787271dc..242b8426c 100644 --- a/environment.yml +++ b/environment.yml @@ -14,7 +14,7 @@ dependencies: - scipy~=1.10.0 - jinja2~=3.0 - conda-forge::torchvision-cpu>=0.5.0 - - conda-forge::pytorch-cpu~=2.2.0 + - conda-forge::pytorch-cpu~=2.1.0 - jsonschema - scikit-learn~=1.2.1 - scikit-image~=0.21.0 diff --git a/environment_cuda.yml b/environment_cuda.yml index ebdda92de..83b75f850 100644 --- a/environment_cuda.yml +++ b/environment_cuda.yml @@ -14,7 +14,7 @@ dependencies: - scipy~=1.10.0 - jinja2~=3.0 - conda-forge::torchvision>=0.5.0 - - conda-forge::pytorch~=2.2.0 + - conda-forge::pytorch~=2.1.0 - cudatoolkit>=9.2 - jsonschema - scikit-learn~=1.2.1 diff --git a/kraken/lib/train.py b/kraken/lib/train.py index 31d6532a8..ec9685aff 100644 --- a/kraken/lib/train.py +++ b/kraken/lib/train.py @@ -80,7 +80,7 @@ def __init__(self, min_epochs: int = 5, max_epochs: int = 100, freeze_backbone=-1, - pl_logger: Union[L.loggers.logger.Logger, str, None] = None, + pl_logger: Union[L.pytorch.loggers.logger.Logger, str, None] = None, log_dir: Optional['PathLike'] = None, *args, **kwargs): @@ -93,14 +93,14 @@ def __init__(self, kwargs['callbacks'] = [kwargs['callbacks']] if pl_logger: - if 'logger' in kwargs and isinstance(kwargs['logger'], L.loggers.logger.Logger): + if 'logger' in kwargs and isinstance(kwargs['logger'], L.pytorch.loggers.logger.Logger): logger.debug('Experiment logger has been provided outside KrakenTrainer as `logger`') - elif isinstance(pl_logger, L.loggers.logger.Logger): + elif isinstance(pl_logger, L.pytorch.loggers.logger.Logger): logger.debug('Experiment logger has been provided outside KrakenTrainer as `pl_logger`') kwargs['logger'] = pl_logger elif pl_logger == 'tensorboard': logger.debug('Creating default experiment logger') - kwargs['logger'] = L.loggers.TensorBoardLogger(log_dir) + kwargs['logger'] = L.pytorch.loggers.TensorBoardLogger(log_dir) else: logger.error('`pl_logger` was set, but %s is not an accepted value', pl_logger) raise ValueError(f'{pl_logger} is not acceptable as logger') diff --git a/setup.cfg b/setup.cfg index 4f880f868..5f675f35e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -53,7 +53,7 @@ install_requires = jinja2~=3.0 python-bidi torchvision>=0.5.0 - torch~=2.2.0 + torch~=2.1.0 scikit-learn~=1.2.1 scikit-image~=0.21.0 shapely~=1.8.5 From d72570d3adbadc9e56c9cd7d782e7607012c7b3e Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Sat, 20 Apr 2024 17:16:26 +0200 Subject: [PATCH 3/3] do not use workers in ketos compile tests --- tests/test_newpolygons.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_newpolygons.py b/tests/test_newpolygons.py index 18d7faacd..6eb431727 100644 --- a/tests/test_newpolygons.py +++ b/tests/test_newpolygons.py @@ -380,7 +380,7 @@ def test_ketos_new_arrow(self): mfp2 = str(Path(tempdir) / "model2") self._test_ketoscli( - args=['compile', '-f', 'xml', '-o', dset, self.segmented_img], + args=['compile', '--workers', '0', '-f', 'xml', '-o', dset, self.segmented_img], expect_legacy=False, patching_dir="kraken.lib.arrow_dataset", ) @@ -399,7 +399,7 @@ def test_ketos_new_arrow_force_legacy(self): mfp2 = str(Path(tempdir) / "model2") self._test_ketoscli( - args=['compile', '--legacy-polygons', '-f', 'xml', '-o', dset, self.segmented_img], + args=['compile', '--workers', '0', '--legacy-polygons', '-f', 'xml', '-o', dset, self.segmented_img], expect_legacy=True, patching_dir="kraken.lib.arrow_dataset", ) @@ -428,7 +428,7 @@ def test_ketos_new_arrow_old_model(self): mfp2 = str(Path(tempdir) / "model2") self._test_ketoscli( - args=['compile', '-f', 'xml', '-o', dset, self.segmented_img], + args=['compile', '--workers', '0', '-f', 'xml', '-o', dset, self.segmented_img], expect_legacy=False, patching_dir="kraken.lib.arrow_dataset", ) @@ -445,7 +445,7 @@ def test_ketos_mixed_arrow_train_new(self): mfp = str(Path(tempdir) / "model") self._test_ketoscli( - args=['compile', '-f', 'xml', '-o', dset, self.segmented_img, self.arrow_data], + args=['compile', '--workers', '0', '-f', 'xml', '-o', dset, self.segmented_img, self.arrow_data], expect_legacy=False, patching_dir="kraken.lib.arrow_dataset", )