Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make remap labels more accurate #203

Merged
merged 2 commits into from
Apr 5, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 25 additions & 12 deletions datumaro/plugins/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
)
from datumaro.components.cli_plugin import CliPlugin
import datumaro.util.mask_tools as mask_tools
from datumaro.util import parse_str_enum_value, NOTSET
from datumaro.util.annotation_util import find_group_leader, find_instances


Expand Down Expand Up @@ -433,7 +434,22 @@ def transform_item(self, item):
class RemapLabels(Transform, CliPlugin):
"""
Changes labels in the dataset.|n
|n
A label can be:|n
- renamed (and joined with existing) -|n
|s|swhen specified '--label <old_name>:<new_name>'|n
- deleted - when specified '--label <name>:' or default action is 'delete'|n
|s|sand the label is not mentioned in the list. When a label|n
|s|sis deleted, all the associated annotations are removed|n
- kept unchanged - when specified '--label <name>:<name>'|n
|s|sor default action is 'keep' and the label is not mentioned in the list|n
Annotations with no label are managed by the default action policy.|n
|n
Examples:|n
- Remove the 'person' label (and corresponding annotations):|n
|s|sremap_labels -l person: --default keep|n
- Rename 'person' to 'pedestrian' and 'human' to 'pedestrian', join:|n
|s|sremap_labels -l person:pedestrian -l human:pedestrian --default keep|n
- Rename 'person' to 'car' and 'cat' to 'dog', keep 'bus', remove others:|n
|s|sremap_labels -l person:car -l bus:bus -l cat:dog --default delete
"""
Expand Down Expand Up @@ -463,9 +479,9 @@ def build_cmdline_parser(cls, **kwargs):
def __init__(self, extractor, mapping, default=None):
super().__init__(extractor)

assert isinstance(default, (str, self.DefaultAction))
if isinstance(default, str):
default = self.DefaultAction[default]
default = parse_str_enum_value(default, self.DefaultAction,
self.DefaultAction.keep)
self._default_action = default

assert isinstance(mapping, (dict, list))
if isinstance(mapping, list):
Expand Down Expand Up @@ -503,10 +519,10 @@ def _make_label_id_map(self, src_label_cat, label_mapping, default_action):
dst_label_cat = LabelCategories(attributes=src_label_cat.attributes)
id_mapping = {}
for src_index, src_label in enumerate(src_label_cat.items):
dst_label = label_mapping.get(src_label.name)
if not dst_label and default_action == self.DefaultAction.keep:
dst_label = label_mapping.get(src_label.name, NOTSET)
if dst_label is NOTSET and default_action == self.DefaultAction.keep:
dst_label = src_label.name # keep unspecified as is
if not dst_label:
elif not dst_label or dst_label is NOTSET:
continue

dst_index = dst_label_cat.find(dst_label)[0]
Expand All @@ -518,7 +534,7 @@ def _make_label_id_map(self, src_label_cat, label_mapping, default_action):
if log.getLogger().isEnabledFor(log.DEBUG):
log.debug("Label mapping:")
for src_id, src_label in enumerate(src_label_cat.items):
if id_mapping.get(src_id):
if id_mapping.get(src_id) is not None:
log.debug("#%s '%s' -> #%s '%s'",
src_id, src_label.name, id_mapping[src_id],
dst_label_cat.items[id_mapping[src_id]].name
Expand All @@ -535,14 +551,11 @@ def categories(self):
def transform_item(self, item):
annotations = []
for ann in item.annotations:
if ann.type in { AnnotationType.label, AnnotationType.mask,
AnnotationType.points, AnnotationType.polygon,
AnnotationType.polyline, AnnotationType.bbox
} and ann.label is not None:
if getattr(ann, 'label') is not None:
conv_label = self._map_id(ann.label)
if conv_label is not None:
annotations.append(ann.wrap(label=conv_label))
else:
elif self._default_action is self.DefaultAction.keep:
annotations.append(ann.wrap())
return item.wrap(annotations=annotations)

Expand Down
45 changes: 28 additions & 17 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,15 +336,18 @@ def test_remap_labels(self):
Bbox(1, 2, 3, 4, label=2),
Mask(image=np.array([1]), label=3),

# Should be kept
# Should be deleted
Polygon([1, 1, 2, 2, 3, 4], label=4),
PolyLine([1, 3, 4, 2, 5, 6])

# Should be kept
PolyLine([1, 3, 4, 2, 5, 6]),
Bbox(4, 3, 2, 1, label=5),
])
], categories={
AnnotationType.label: LabelCategories.from_iterable(
'label%s' % i for i in range(5)),
'label%s' % i for i in range(6)),
AnnotationType.mask: MaskCategories(
colormap=mask_tools.generate_colormap(5)),
colormap=mask_tools.generate_colormap(6)),
})

dst_dataset = Dataset.from_iterable([
Expand All @@ -353,37 +356,45 @@ def test_remap_labels(self):
Bbox(1, 2, 3, 4, label=0),
Mask(image=np.array([1]), label=1),

Polygon([1, 1, 2, 2, 3, 4], label=2),
PolyLine([1, 3, 4, 2, 5, 6], label=None)
PolyLine([1, 3, 4, 2, 5, 6], label=None),
Bbox(4, 3, 2, 1, label=2),
]),
], categories={
AnnotationType.label: LabelCategories.from_iterable(
['label0', 'label9', 'label4']),
['label0', 'label9', 'label5']),
AnnotationType.mask: MaskCategories(colormap={
k: v for k, v in mask_tools.generate_colormap(5).items()
if k in { 0, 1, 3, 4 }
k: v for k, v in mask_tools.generate_colormap(6).items()
if k in { 0, 1, 3, 5 }
})
})

actual = transforms.RemapLabels(src_dataset, mapping={
'label1': 'label9',
'label2': 'label0',
'label3': 'label9',
'label1': 'label9', # rename & join with new label9 (from label3)
'label2': 'label0', # rename & join with existing label0
'label3': 'label9', # rename & join with new label9 (from label1)
'label4': '', # delete the label and associated annotations
# 'label5' - unchanged
}, default='keep')

compare_datasets(self, dst_dataset, actual)

def test_remap_labels_delete_unspecified(self):
source_dataset = Dataset.from_iterable([
DatasetItem(id=1, annotations=[ Label(0) ])
], categories=['label0'])
DatasetItem(id=1, annotations=[
Label(0, id=0), # will be removed
Label(1, id=1),
Bbox(1, 2, 3, 4, label=None),
])
], categories=['label0', 'label1'])

target_dataset = Dataset.from_iterable([
DatasetItem(id=1),
], categories=[])
DatasetItem(id=1, annotations=[
Label(0, id=1),
]),
], categories=['label1'])

actual = transforms.RemapLabels(source_dataset,
mapping={}, default='delete')
mapping={ 'label1': 'label1' }, default='delete')

compare_datasets(self, target_dataset, actual)

Expand Down