From e76224e84b38e313f4fdef1adb9502d4a7b8e1b3 Mon Sep 17 00:00:00 2001 From: "Kim, Vinnam" Date: Fri, 8 Sep 2023 22:46:08 +0900 Subject: [PATCH 1/4] Add ThreadPool option to ModelTransform and SAMBboxToInstanceMask Signed-off-by: Kim, Vinnam --- .../18_bbox_to_instance_mask_using_sam.ipynb | 24 +++-- src/datumaro/components/dataset.py | 10 +- src/datumaro/components/hl_ops/__init__.py | 11 ++- src/datumaro/components/transformer.py | 76 ++++++++++++++-- .../sam_transforms/bbox_to_inst_mask.py | 37 +++++--- src/datumaro/util/multi_procs_util.py | 91 +++++++++++++++++++ tests/unit/components/test_transformer.py | 65 +++++++++++++ tests/unit/test_util.py | 40 ++++++++ tests/unit/transforms/test_sam_transforms.py | 4 +- 9 files changed, 326 insertions(+), 32 deletions(-) create mode 100644 src/datumaro/util/multi_procs_util.py create mode 100644 tests/unit/components/test_transformer.py diff --git a/notebooks/18_bbox_to_instance_mask_using_sam.ipynb b/notebooks/18_bbox_to_instance_mask_using_sam.ipynb index 3c864786a8..37d6efac85 100644 --- a/notebooks/18_bbox_to_instance_mask_using_sam.ipynb +++ b/notebooks/18_bbox_to_instance_mask_using_sam.ipynb @@ -153,14 +153,20 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Now, we apply `SAMBboxToInstanceMask` to the dataset.\n", - "This transform requires some arguments to execute properly.\n", - "`inference_server_type` is the type of inference server which SAM encoder and decoder are deployed.\n", - "In this example, we launched the OpenVINO™ Model Server instance, thus please choose `InferenceServerType.ovms`.\n", - "The gRPC endpoint address was `localhost:8001`.\n", - "Therefore, `host=\"localhost\"`, `port=8001`, and `protocol_type=ProtocolType.grpc` should be given.\n", - "Lastly, you can make `Polygon` output for the instance mask, but this time we assign `to_polygon` to `False`,\n", - "so that the output will be `Mask` annotation type." + "Now, we apply `SAMBboxToInstanceMask` to the dataset. This transform requires several arguments to execute properly.\n", + "\n", + "`inference_server_type` represents the type of inference server on which SAM encoder and decoder are deployed. In this example, we launched the OpenVINO™ Model Server instance. Therefore, please select `InferenceServerType.ovms`.\n", + "\n", + "The gRPC endpoint address was `localhost:8001`. To configure this, provide the following parameters:\n", + "- `host=\"localhost\"`\n", + "- `port=8001`\n", + "- `protocol_type=ProtocolType.grpc`\n", + "\n", + "You can also specify a `timeout=60.0` value, which represents the maximum seconds to wait for a response from the server instance.\n", + "\n", + "Additionally, you can choose to produce `Polygon` output for the instance mask. However, in this case, we have set `to_polygon` to `False`, resulting in an output of the `Mask` annotation type.\n", + "\n", + "Lastly, we've set `num_workers=0`. This means we will use synchronous iteration to send a model inference request to the server instance and wait for the inference results. If you need to handle multiple inference requests concurrently, you can increase this value to utilize a thread pool. This is particularly useful when dealing with server instances that have high throughput." ] }, { @@ -178,8 +184,10 @@ " inference_server_type=InferenceServerType.ovms,\n", " host=\"localhost\",\n", " port=8001,\n", + " timeout=60.0,\n", " protocol_type=ProtocolType.grpc,\n", " to_polygon=False,\n", + " num_workers=0,\n", " )" ] }, diff --git a/src/datumaro/components/dataset.py b/src/datumaro/components/dataset.py index bdb1276888..2f96c5cafe 100644 --- a/src/datumaro/components/dataset.py +++ b/src/datumaro/components/dataset.py @@ -443,6 +443,7 @@ def run_model( *, batch_size: int = 1, append_annotation: bool = False, + num_workers: int = 0, **kwargs, ) -> Dataset: """ @@ -454,6 +455,8 @@ def run_model( batch_size: The number of dataset items processed simultaneously by the model append_annotation: Whether append new annotation to existed annotations + num_workers: The number of worker threads to use for parallel inference. + Set to 0 for single-process mode. Default is 0. **kwargs: Parameters for the model Returns: self @@ -465,11 +468,16 @@ def run_model( launcher=model, batch_size=batch_size, append_annotation=append_annotation, + num_workers=num_workers, **kwargs, ) elif inspect.isclass(model) and isinstance(model, ModelTransform): return self.transform( - model, batch_size=batch_size, append_annotation=append_annotation, **kwargs + model, + batch_size=batch_size, + append_annotation=append_annotation, + num_workers=num_workers, + **kwargs, ) else: raise TypeError("Unexpected 'model' argument type: %s" % type(model)) diff --git a/src/datumaro/components/hl_ops/__init__.py b/src/datumaro/components/hl_ops/__init__.py index 368411d8f3..ef1054b48a 100644 --- a/src/datumaro/components/hl_ops/__init__.py +++ b/src/datumaro/components/hl_ops/__init__.py @@ -195,6 +195,7 @@ def run_model( *, batch_size: int = 1, append_annotation: bool = False, + num_workers: int = 0, **kwargs, ) -> IDataset: """ @@ -207,6 +208,8 @@ def run_model( batch_size: The number of dataset items processed simultaneously by the model append_annotation: Whether append new annotation to existed annotations + num_workers: The number of worker threads to use for parallel inference. + Set to 0 for single-process mode. Default is 0. **kwargs: Parameters for the model Returns: a wrapper around the input dataset, which is computed lazily @@ -220,11 +223,17 @@ def run_model( launcher=model, batch_size=batch_size, append_annotation=append_annotation, + num_workers=num_workers, **kwargs, ) elif inspect.isclass(model) and issubclass(model, ModelTransform): return HLOps.transform( - dataset, model, batch_size=batch_size, append_annotation=append_annotation, **kwargs + dataset, + model, + batch_size=batch_size, + append_annotation=append_annotation, + num_workers=num_workers, + **kwargs, ) else: raise TypeError(f"Unexpected model argument type: {type(model)}") diff --git a/src/datumaro/components/transformer.py b/src/datumaro/components/transformer.py index 3752ab864f..be27cb6ccc 100644 --- a/src/datumaro/components/transformer.py +++ b/src/datumaro/components/transformer.py @@ -1,7 +1,8 @@ # Copyright (C) 2019-2022 Intel Corporation # # SPDX-License-Identifier: MIT -from typing import Generator, List, Optional +from multiprocessing.pool import ThreadPool +from typing import Generator, Iterator, List, Optional import numpy as np @@ -10,6 +11,7 @@ from datumaro.components.dataset_base import DatasetBase, DatasetItem, IDataset from datumaro.components.launcher import Launcher from datumaro.util import is_method_redefined, take_by +from datumaro.util.multi_procs_util import consumer_generator class Transform(DatasetBase, CliPlugin): @@ -71,25 +73,65 @@ def __iter__(self): class ModelTransform(Transform): + """A transformation class for applying a model's inference to dataset items. + + This class takes an dataset, a launcher, and other optional parameters + to transform the dataset item from the model outputs by the launcher. + It can process items using multiple processes if specified, making it suitable for + parallelized inference tasks. + + Parameters: + extractor: The dataset extractor to obtain items from. + launcher: The launcher responsible for model inference. + batch_size: The batch size for processing items. Default is 1. + append_annotation: Whether to append inference annotations to existing annotations. + Default is False. + num_workers: The number of worker threads to use for parallel inference. + Set to 0 for single-process mode. Default is 0. + """ + def __init__( self, extractor: IDataset, launcher: Launcher, batch_size: int = 1, append_annotation: bool = False, + num_workers: int = 0, ): super().__init__(extractor) self._launcher = launcher self._batch_size = batch_size self._append_annotation = append_annotation - - def __iter__(self) -> Generator[DatasetItem, None, None]: - for batch in take_by(self._extractor, self._batch_size): - inference = self._launcher.launch( - [item for item in batch if self._launcher.type_check(item)] + if not (isinstance(num_workers, int) and num_workers >= 0): + raise ValueError( + f"num_workers should be a non negative integer, but it is {num_workers}" ) - - for item in self._yield_item(batch, inference): + self._num_workers = num_workers + + def __iter__(self) -> Iterator[DatasetItem]: + if self._num_workers == 0: + return self._iter_single_proc() + return self._iter_multi_procs() + + def _iter_multi_procs(self): + with ThreadPool(processes=self._num_workers) as pool: + + def _producer_gen(): + for batch in take_by(self._extractor, self._batch_size): + future = pool.apply_async( + func=self._process_batch, + args=(batch,), + ) + yield future + + with consumer_generator(producer_generator=_producer_gen) as consumer_gen: + for future in consumer_gen(): + for item in future.get(): + yield item + + def _iter_single_proc(self) -> Iterator[DatasetItem]: + for batch in take_by(self._extractor, self._batch_size): + for item in self._process_batch(batch=batch): yield item def _yield_item( @@ -101,6 +143,24 @@ def _yield_item( annotations = item.annotations + annotations yield self.wrap_item(item, annotations=annotations) + def _process_batch( + self, + batch: List[DatasetItem], + ) -> List[DatasetItem]: + inference = self._launcher.launch( + batch=[item for item in batch if self._launcher.type_check(item)] + ) + + return [ + Transform.wrap_item( + item, + annotations=item.annotations + annotations + if self._append_annotation + else annotations, + ) + for item, annotations in zip(batch, inference) + ] + def get_subset(self, name): subset = self._extractor.get_subset(name) return __class__(subset, self._launcher, self._batch_size) diff --git a/src/datumaro/plugins/sam_transforms/bbox_to_inst_mask.py b/src/datumaro/plugins/sam_transforms/bbox_to_inst_mask.py index ac2e47f822..309217e6fb 100644 --- a/src/datumaro/plugins/sam_transforms/bbox_to_inst_mask.py +++ b/src/datumaro/plugins/sam_transforms/bbox_to_inst_mask.py @@ -4,7 +4,7 @@ """Bbox-to-instance mask transform using Segment Anything Model""" import os.path as osp -from typing import Generator, List, Optional +from typing import List, Optional import datumaro.plugins.sam_transforms.interpreters.sam_decoder_for_bbox as sam_decoder_for_bbox_interp import datumaro.plugins.sam_transforms.interpreters.sam_encoder as sam_encoder_interp @@ -18,7 +18,6 @@ ProtocolType, TLSConfig, ) -from datumaro.util import take_by from datumaro.util.mask_tools import extract_contours __all__ = ["SAMBboxToInstanceMask"] @@ -44,6 +43,8 @@ class SAMBboxToInstanceMask(ModelTransform, CliPlugin): tls_config: Configuration required if the server instance is in the secure mode protocol_type: Communication protocol type with the server instance to_polygon: If true, the output `Mask` annotations will be converted to `Polygon` annotations. + num_workers: The number of worker threads to use for parallel inference. + Set to 0 for single-process mode. Default is 0. """ def __init__( @@ -56,6 +57,7 @@ def __init__( tls_config: Optional[TLSConfig] = None, protocol_type: ProtocolType = ProtocolType.grpc, to_polygon: bool = False, + num_workers: int = 0, ): if inference_server_type == InferenceServerType.ovms: launcher_cls = OVMSLauncher @@ -90,26 +92,35 @@ def __init__( launcher=self._sam_encoder_launcher, batch_size=1, append_annotation=False, + num_workers=num_workers, ) self._to_polygon = to_polygon - def __iter__(self) -> Generator[DatasetItem, None, None]: - for batch in take_by(self._extractor, self._batch_size): - batch = [item for item in batch if self._launcher.type_check(item)] - img_embeds = self._sam_encoder_launcher.launch(batch) + def _process_batch( + self, + batch: List[DatasetItem], + ) -> List[DatasetItem]: + print("Process batch") + img_embeds = self._sam_encoder_launcher.launch( + batch=[item for item in batch if self._sam_encoder_launcher.type_check(item)] + ) - for item, img_embed in zip(batch, img_embeds): - # Nested list of mask [[mask_0, ...]] - nested_masks: List[List[Mask]] = self._sam_decoder_launcher.launch( - [item.wrap(annotations=item.annotations + img_embed)], - stack=False, - ) + items = [] + for item, img_embed in zip(batch, img_embeds): + # Nested list of mask [[mask_0, ...]] + nested_masks: List[List[Mask]] = self._sam_decoder_launcher.launch( + [item.wrap(annotations=item.annotations + img_embed)], + stack=False, + ) - yield item.wrap( + items.append( + item.wrap( annotations=self._convert_to_polygon(nested_masks[0]) if self._to_polygon else nested_masks[0] ) + ) + return items @staticmethod def _convert_to_polygon(masks: List[Mask]): diff --git a/src/datumaro/util/multi_procs_util.py b/src/datumaro/util/multi_procs_util.py new file mode 100644 index 0000000000..d6bdc329f4 --- /dev/null +++ b/src/datumaro/util/multi_procs_util.py @@ -0,0 +1,91 @@ +# Copyright (C) 2023 Intel Corporation +# +# SPDX-License-Identifier: MIT + +import logging as log +from contextlib import contextmanager +from enum import IntEnum +from queue import Full, Queue +from threading import Condition, Thread +from typing import Any, Iterator, TypeVar + +__all__ = ["consumer_generator"] + + +class ProducerMessage(IntEnum): + START = 0 + END = 1 + + +Item = TypeVar("Item") + + +@contextmanager +def consumer_generator( + producer_generator: Iterator[Item], + queue_size: int = 100, + enqueue_timeout: float = 5.0, + join_timeout: float = 10.0, +) -> Iterator[Item]: + """Context manager that creates a generator to consume items produced by another generator. + + This context manager sets up a producer thread that generates items from the `producer_generator` + and enqueues them to be consumed by the consumer generator, which is also created by this function. + + Parameters: + producer_generator: A generator that produces items. + queue_size: The maximum size of the shared queue between the producer and consumer. + enqueue_timeout: The maximum time to wait for enqueuing an item to the queue if it's full. + join_timeout: The maximum time to wait for the producer thread to finish when exiting the context. + + Returns: + Iterator: A context for iterating over the generated items. + """ + queue = Queue(maxsize=queue_size) + lock = Condition() + is_terminated = False + + def _enqueue(item: Any, queue: Queue): + while True: + try: + queue.put(item, block=True, timeout=enqueue_timeout) + return + except Full: + with lock: + if is_terminated: + raise RuntimeError( + "Item to enqueue is left. However, the main process is terminated." + ) + + def _target(queue: Queue) -> None: + try: + _enqueue(ProducerMessage.START, queue) + + for item in producer_generator(): + _enqueue(item, queue) + + _enqueue(ProducerMessage.END, queue) + except RuntimeError as e: + log.error(e) + return + + producer = Thread(target=_target, args=(queue,)) + producer.start() + + def _generator(): + while True: + item = queue.get() + + if item == ProducerMessage.START: + continue + elif item == ProducerMessage.END: + return + + yield item + + try: + yield _generator + finally: + with lock: + is_terminated = True + producer.join(timeout=join_timeout) diff --git a/tests/unit/components/test_transformer.py b/tests/unit/components/test_transformer.py new file mode 100644 index 0000000000..400800e180 --- /dev/null +++ b/tests/unit/components/test_transformer.py @@ -0,0 +1,65 @@ +# Copyright (C) 2023 Intel Corporation +# +# SPDX-License-Identifier: MIT + +from typing import List, Tuple + +import pytest + +from datumaro.components.abstracts.model_interpreter import LauncherInputType, ModelPred, PrepInfo +from datumaro.components.annotation import Annotation +from datumaro.components.dataset import Dataset +from datumaro.components.dataset_base import DatasetItem +from datumaro.components.launcher import Launcher +from datumaro.components.transformer import ModelTransform + + +class MockLauncher(Launcher): + def preprocess(self, item: DatasetItem) -> Tuple[LauncherInputType, PrepInfo]: + return {"item": item}, None + + def infer(self, inputs: LauncherInputType) -> List[ModelPred]: + return [[Annotation(id=1)] for _ in inputs["item"]] + + def postprocess(self, pred: ModelPred, info: PrepInfo) -> List[Annotation]: + return pred + + +class ModelTransformTest: + @pytest.fixture + def fxt_dataset(self): + return Dataset.from_iterable( + [ + DatasetItem( + id=f"item_{i}", + annotations=[Annotation(id=0)], + ) + for i in range(10) + ] + ) + + @pytest.mark.parametrize("batch_size", [1, 10]) + @pytest.mark.parametrize("append_annotation", [True, False]) + @pytest.mark.parametrize("num_workers", [0, 2]) + def test_model_transform( + self, + fxt_dataset: Dataset, + batch_size, + append_annotation, + num_workers, + ): + transform = ModelTransform( + extractor=fxt_dataset, + launcher=MockLauncher(), + batch_size=batch_size, + append_annotation=append_annotation, + num_workers=num_workers, + ) + + for idx, item in enumerate(transform): + assert item.id == f"item_{idx}" + + if append_annotation: + assert item.annotations == [Annotation(id=0), Annotation(id=1)] + else: + assert item.annotations == [Annotation(id=1)] diff --git a/tests/unit/test_util.py b/tests/unit/test_util.py index f2059cb35e..8b3d23c140 100644 --- a/tests/unit/test_util.py +++ b/tests/unit/test_util.py @@ -7,12 +7,14 @@ import os.path as osp import platform from contextlib import suppress +from typing import Iterator from unittest import TestCase, mock import pytest from datumaro.util import is_method_redefined from datumaro.util.definitions import get_datumaro_cache_dir +from datumaro.util.multi_procs_util import consumer_generator from datumaro.util.os_util import walk from datumaro.util.scope import Scope, on_error_do, on_exit_do, scoped @@ -242,3 +244,41 @@ def test_get_datumaro_cache_dir( with caplog.at_level(logging.ERROR): get_datumaro_cache_dir(fxt_non_writable_path) assert len(caplog.records) == 1 + + +class MultiProcUtilTest: + @pytest.fixture + def fxt_producer_generator(self): + class TestObject: + def __init__(self, value: int) -> None: + self.value = value + + def test_func() -> Iterator[TestObject]: + for i in range(1000): + yield TestObject(i) + + return test_func + + def test_succeed(self, fxt_producer_generator): + with consumer_generator(producer_generator=fxt_producer_generator) as f: + for expect, actual in enumerate(f()): + assert expect == actual.value + + def test_raise_exception_in_main_thread( + self, fxt_producer_generator, caplog: pytest.LogCaptureFixture + ): + try: + with consumer_generator( + producer_generator=fxt_producer_generator, + enqueue_timeout=0.05, + join_timeout=0.1, + ) as f: + for expect, actual in enumerate(f()): + assert expect == actual.value + raise Exception() + except Exception: + assert any( + "Item to enqueue is left. However, the main process is terminated." + == record.message + for record in caplog.records + ) diff --git a/tests/unit/transforms/test_sam_transforms.py b/tests/unit/transforms/test_sam_transforms.py index 9885172169..0c9c229b42 100644 --- a/tests/unit/transforms/test_sam_transforms.py +++ b/tests/unit/transforms/test_sam_transforms.py @@ -126,7 +126,8 @@ def fxt_inference_server_type(self, request): def fxt_to_polygon(self, request): return request.param - def test_transform(self, fxt_dataset, fxt_inference_server_type, fxt_to_polygon): + @pytest.mark.parametrize("num_workers", [0, 2]) + def test_transform(self, fxt_dataset, fxt_inference_server_type, fxt_to_polygon, num_workers): if fxt_inference_server_type == InferenceServerType.ovms: launcher_str = "OVMSLauncher" elif fxt_inference_server_type == InferenceServerType.triton: @@ -145,6 +146,7 @@ def test_transform(self, fxt_dataset, fxt_inference_server_type, fxt_to_polygon) extractor=fxt_dataset, inference_server_type=fxt_inference_server_type, to_polygon=fxt_to_polygon, + num_workers=num_workers, ) mock_sam_encoder.launch.return_value = [[FeatureVector(vector=np.zeros([10]))]] From b1721b01f48ce1a9cb8e4c412b9215463efcb04e Mon Sep 17 00:00:00 2001 From: "Kim, Vinnam" Date: Mon, 11 Sep 2023 18:56:03 +0900 Subject: [PATCH 2/4] Update CHANGELOG.md Signed-off-by: Kim, Vinnam --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2d32c142d7..e0c9e7bd68 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -37,6 +37,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 () - Remove deprecates announced to be removed in 1.5.0 () +- Add multi-threading option to ModelTransform and SAMBboxToInstanceMask + () ### Bug fixes - Fix bugs for Tile transform From ff99cec1bbee8ffb4d66d54b67d8741a40c9297c Mon Sep 17 00:00:00 2001 From: "Kim, Vinnam" Date: Mon, 11 Sep 2023 20:08:55 +0900 Subject: [PATCH 3/4] Fix error Signed-off-by: Kim, Vinnam --- src/datumaro/components/transformer.py | 16 +++++--------- .../plugins/missing_annotation_detection.py | 21 +++++++++++++------ 2 files changed, 20 insertions(+), 17 deletions(-) diff --git a/src/datumaro/components/transformer.py b/src/datumaro/components/transformer.py index be27cb6ccc..becedd2da2 100644 --- a/src/datumaro/components/transformer.py +++ b/src/datumaro/components/transformer.py @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: MIT from multiprocessing.pool import ThreadPool -from typing import Generator, Iterator, List, Optional +from typing import Iterator, List, Optional import numpy as np @@ -134,15 +134,6 @@ def _iter_single_proc(self) -> Iterator[DatasetItem]: for item in self._process_batch(batch=batch): yield item - def _yield_item( - self, batch: List[DatasetItem], inference: List[List[Annotation]] - ) -> Generator[DatasetItem, None, None]: - for item, annotations in zip(batch, inference): - self._check_annotations(annotations) - if self._append_annotation: - annotations = item.annotations + annotations - yield self.wrap_item(item, annotations=annotations) - def _process_batch( self, batch: List[DatasetItem], @@ -151,8 +142,11 @@ def _process_batch( batch=[item for item in batch if self._launcher.type_check(item)] ) + for annotations in inference: + self._check_annotations(annotations) + return [ - Transform.wrap_item( + self.wrap_item( item, annotations=item.annotations + annotations if self._append_annotation diff --git a/src/datumaro/plugins/missing_annotation_detection.py b/src/datumaro/plugins/missing_annotation_detection.py index 5277bee488..c60488ceb0 100644 --- a/src/datumaro/plugins/missing_annotation_detection.py +++ b/src/datumaro/plugins/missing_annotation_detection.py @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: MIT -from typing import Generator, List, Optional, Set +from typing import List, Optional, Set from datumaro.components.abstracts.merger import IMatcherContext from datumaro.components.annotation import Annotation, AnnotationType, LabelCategories @@ -83,18 +83,27 @@ def get_any_label_name(self, ann: Annotation, label_id: int) -> str: ), } - def _yield_item( - self, batch: List[DatasetItem], inference: List[List[Annotation]] - ) -> Generator[DatasetItem, None, None]: - for item, annotations in zip(batch, inference): + def _process_batch( + self, + batch: List[DatasetItem], + ) -> List[DatasetItem]: + inference = self._launcher.launch( + batch=[item for item in batch if self._launcher.type_check(item)] + ) + + for annotations in inference: self._check_annotations(annotations) - yield self.wrap_item( + + return [ + self.wrap_item( item, annotations=self._find_missing_anns( gt_anns=item.annotations, pseudo_anns=self._apply_score_threshold(annotations), ), ) + for item, annotations in zip(batch, inference) + ] def _apply_score_threshold(self, annotations: List[Annotation]) -> List[Annotation]: if self._score_threshold is None: From 865fb471695a9635ddd37b76383e1adc746aac1c Mon Sep 17 00:00:00 2001 From: "Kim, Vinnam" Date: Tue, 12 Sep 2023 15:38:23 +0900 Subject: [PATCH 4/4] Fix wrong typing and inplace bbox annotations only Signed-off-by: Kim, Vinnam --- src/datumaro/components/transformer.py | 4 +-- .../sam_transforms/bbox_to_inst_mask.py | 30 +++++++++++++------ src/datumaro/util/multi_procs_util.py | 10 +++---- tests/unit/test_util.py | 8 ++--- 4 files changed, 32 insertions(+), 20 deletions(-) diff --git a/src/datumaro/components/transformer.py b/src/datumaro/components/transformer.py index becedd2da2..c5d743bbc3 100644 --- a/src/datumaro/components/transformer.py +++ b/src/datumaro/components/transformer.py @@ -124,8 +124,8 @@ def _producer_gen(): ) yield future - with consumer_generator(producer_generator=_producer_gen) as consumer_gen: - for future in consumer_gen(): + with consumer_generator(producer_generator=_producer_gen()) as consumer_gen: + for future in consumer_gen: for item in future.get(): yield item diff --git a/src/datumaro/plugins/sam_transforms/bbox_to_inst_mask.py b/src/datumaro/plugins/sam_transforms/bbox_to_inst_mask.py index 309217e6fb..43d2784577 100644 --- a/src/datumaro/plugins/sam_transforms/bbox_to_inst_mask.py +++ b/src/datumaro/plugins/sam_transforms/bbox_to_inst_mask.py @@ -8,7 +8,7 @@ import datumaro.plugins.sam_transforms.interpreters.sam_decoder_for_bbox as sam_decoder_for_bbox_interp import datumaro.plugins.sam_transforms.interpreters.sam_encoder as sam_encoder_interp -from datumaro.components.annotation import Mask, Polygon +from datumaro.components.annotation import Bbox, Mask, Polygon from datumaro.components.cli_plugin import CliPlugin from datumaro.components.dataset_base import DatasetItem, IDataset from datumaro.components.transformer import ModelTransform @@ -100,26 +100,38 @@ def _process_batch( self, batch: List[DatasetItem], ) -> List[DatasetItem]: - print("Process batch") img_embeds = self._sam_encoder_launcher.launch( batch=[item for item in batch if self._sam_encoder_launcher.type_check(item)] ) items = [] for item, img_embed in zip(batch, img_embeds): + item_to_decode = item.wrap(annotations=item.annotations + img_embed) + + if not any(isinstance(ann, Bbox) for ann in item_to_decode.annotations): + item_to_decode.annotations.pop() # Pop the added image embedding + items.append(item_to_decode) + continue + # Nested list of mask [[mask_0, ...]] nested_masks: List[List[Mask]] = self._sam_decoder_launcher.launch( - [item.wrap(annotations=item.annotations + img_embed)], + [item_to_decode], stack=False, ) - items.append( - item.wrap( - annotations=self._convert_to_polygon(nested_masks[0]) - if self._to_polygon - else nested_masks[0] - ) + # Pop the added image embedding + item_to_decode.annotations.pop() + # Leave non-bbox annotations only + item_to_decode.annotations = [ + ann for ann in item_to_decode.annotations if not isinstance(ann, Bbox) + ] + + item_to_decode.annotations += ( + self._convert_to_polygon(nested_masks[0]) if self._to_polygon else nested_masks[0] ) + + items.append(item_to_decode) + return items @staticmethod diff --git a/src/datumaro/util/multi_procs_util.py b/src/datumaro/util/multi_procs_util.py index d6bdc329f4..95268b952c 100644 --- a/src/datumaro/util/multi_procs_util.py +++ b/src/datumaro/util/multi_procs_util.py @@ -7,7 +7,7 @@ from enum import IntEnum from queue import Full, Queue from threading import Condition, Thread -from typing import Any, Iterator, TypeVar +from typing import Any, Generator, Iterator, TypeVar __all__ = ["consumer_generator"] @@ -26,7 +26,7 @@ def consumer_generator( queue_size: int = 100, enqueue_timeout: float = 5.0, join_timeout: float = 10.0, -) -> Iterator[Item]: +) -> Generator[Iterator[Item], None, None]: """Context manager that creates a generator to consume items produced by another generator. This context manager sets up a producer thread that generates items from the `producer_generator` @@ -61,7 +61,7 @@ def _target(queue: Queue) -> None: try: _enqueue(ProducerMessage.START, queue) - for item in producer_generator(): + for item in producer_generator: _enqueue(item, queue) _enqueue(ProducerMessage.END, queue) @@ -72,7 +72,7 @@ def _target(queue: Queue) -> None: producer = Thread(target=_target, args=(queue,)) producer.start() - def _generator(): + def _generator() -> Iterator[Item]: while True: item = queue.get() @@ -84,7 +84,7 @@ def _generator(): yield item try: - yield _generator + yield _generator() finally: with lock: is_terminated = True diff --git a/tests/unit/test_util.py b/tests/unit/test_util.py index 8b3d23c140..118a7af21a 100644 --- a/tests/unit/test_util.py +++ b/tests/unit/test_util.py @@ -260,8 +260,8 @@ def test_func() -> Iterator[TestObject]: return test_func def test_succeed(self, fxt_producer_generator): - with consumer_generator(producer_generator=fxt_producer_generator) as f: - for expect, actual in enumerate(f()): + with consumer_generator(producer_generator=fxt_producer_generator()) as f: + for expect, actual in enumerate(f): assert expect == actual.value def test_raise_exception_in_main_thread( @@ -269,11 +269,11 @@ def test_raise_exception_in_main_thread( ): try: with consumer_generator( - producer_generator=fxt_producer_generator, + producer_generator=fxt_producer_generator(), enqueue_timeout=0.05, join_timeout=0.1, ) as f: - for expect, actual in enumerate(f()): + for expect, actual in enumerate(f): assert expect == actual.value raise Exception() except Exception: