diff --git a/datumaro/datumaro/plugins/transforms.py b/datumaro/datumaro/plugins/transforms.py index 7449d4fdeb14..6c418387b767 100644 --- a/datumaro/datumaro/plugins/transforms.py +++ b/datumaro/datumaro/plugins/transforms.py @@ -9,7 +9,7 @@ import pycocotools.mask as mask_utils from datumaro.components.extractor import (Transform, AnnotationType, - Mask, RleMask, Polygon) + Mask, RleMask, Polygon, Bbox) from datumaro.components.cli_plugin import CliPlugin import datumaro.util.mask_tools as mask_tools @@ -237,6 +237,25 @@ def convert_mask(mask): for p in polygons ] +class ShapesToBoxes(Transform, CliPlugin): + def transform_item(self, item): + annotations = [] + for ann in item.annotations: + if ann.type in { AnnotationType.mask, AnnotationType.polygon, + AnnotationType.polyline, AnnotationType.points, + }: + annotations.append(self.convert_shape(ann)) + else: + annotations.append(ann) + + return self.wrap_item(item, annotations=annotations) + + @staticmethod + def convert_shape(shape): + bbox = shape.get_bbox() + return Bbox(*bbox, label=shape.label, z_order=shape.z_order, + id=shape.id, attributes=shape.attributes, group=shape.group) + class Reindex(Transform, CliPlugin): @classmethod def build_cmdline_parser(cls, **kwargs): @@ -253,3 +272,34 @@ def __init__(self, extractor, start=1): def __iter__(self): for i, item in enumerate(self._extractor): yield self.wrap_item(item, id=i + self._start) + + +class MapSubsets(Transform, CliPlugin): + @staticmethod + def _mapping_arg(s): + parts = s.split(':') + if len(parts) != 2: + import argparse + raise argparse.ArgumentTypeError() + return parts + + @classmethod + def build_cmdline_parser(cls, **kwargs): + parser = super().build_cmdline_parser(**kwargs) + parser.add_argument('-s', '--subset', action='append', + type=cls._mapping_arg, dest='mapping', + help="Subset mapping of the form: 'src:dst' (repeatable)") + return parser + + def __init__(self, extractor, mapping=None): + super().__init__(extractor) + + if mapping is None: + mapping = {} + elif not isinstance(mapping, dict): + mapping = dict(tuple(m) for m in mapping) + self._mapping = mapping + + def transform_item(self, item): + return self.wrap_item(item, + subset=self._mapping.get(item.subset, item.subset)) \ No newline at end of file diff --git a/datumaro/tests/test_transforms.py b/datumaro/tests/test_transforms.py index e5f0600af3d4..41daf1734938 100644 --- a/datumaro/tests/test_transforms.py +++ b/datumaro/tests/test_transforms.py @@ -3,7 +3,7 @@ from unittest import TestCase from datumaro.components.extractor import (Extractor, DatasetItem, - Mask, Polygon + Mask, Polygon, PolyLine, Points, Bbox ) from datumaro.util.test_utils import compare_datasets import datumaro.plugins.transforms as transforms @@ -185,4 +185,62 @@ def __iter__(self): actual = transforms.MergeInstanceSegments(SrcExtractor(), include_polygons=True) - compare_datasets(self, DstExtractor(), actual) \ No newline at end of file + compare_datasets(self, DstExtractor(), actual) + + def test_map_subsets(self): + class SrcExtractor(Extractor): + def __iter__(self): + return iter([ + DatasetItem(id=1, subset='a'), + DatasetItem(id=2, subset='b'), + DatasetItem(id=3, subset='c'), + ]) + + class DstExtractor(Extractor): + def __iter__(self): + return iter([ + DatasetItem(id=1, subset=''), + DatasetItem(id=2, subset='a'), + DatasetItem(id=3, subset='c'), + ]) + + actual = transforms.MapSubsets(SrcExtractor(), + { 'a': '', 'b': 'a' }) + compare_datasets(self, DstExtractor(), actual) + + def test_shapes_to_boxes(self): + class SrcExtractor(Extractor): + def __iter__(self): + return iter([ + DatasetItem(id=1, image=np.zeros((5, 5, 3)), + annotations=[ + Mask(np.array([ + [0, 0, 1, 1, 1], + [0, 0, 0, 0, 1], + [1, 0, 0, 0, 1], + [1, 0, 0, 0, 0], + [1, 1, 1, 0, 0]], + ), id=1), + Polygon([1, 1, 4, 1, 4, 4, 1, 4], id=2), + PolyLine([1, 1, 2, 1, 2, 2, 1, 2], id=3), + Points([2, 2, 4, 2, 4, 4, 2, 4], id=4), + ] + ), + ]) + + class DstExtractor(Extractor): + def __iter__(self): + return iter([ + DatasetItem(id=1, image=np.zeros((5, 5, 3)), + annotations=[ + Bbox(0, 0, 4, 4, id=1), + Bbox(1, 1, 3, 3, id=2), + Bbox(1, 1, 1, 1, id=3), + Bbox(2, 2, 2, 2, id=4), + ] + ), + ]) + + actual = transforms.ShapesToBoxes(SrcExtractor()) + compare_datasets(self, DstExtractor(), actual) +