From 0e670a560f5afee3c4be5f6d3ddf005e3d79612a Mon Sep 17 00:00:00 2001 From: Zhiltsov Max Date: Wed, 12 Aug 2020 18:20:41 +0300 Subject: [PATCH] add tests --- .../datumaro/cli/contexts/project/__init__.py | 19 +-- datumaro/datumaro/components/extractor.py | 2 +- datumaro/datumaro/components/operations.py | 41 +++--- datumaro/datumaro/util/__init__.py | 3 + datumaro/tests/test_diff.py | 127 ++++++++++-------- 5 files changed, 103 insertions(+), 89 deletions(-) diff --git a/datumaro/datumaro/cli/contexts/project/__init__.py b/datumaro/datumaro/cli/contexts/project/__init__.py index d36c40dc5f5d..18d18bc883d0 100644 --- a/datumaro/datumaro/cli/contexts/project/__init__.py +++ b/datumaro/datumaro/cli/contexts/project/__init__.py @@ -566,15 +566,15 @@ def diff_command(args): return 0 -def build_diff2_parser(parser_ctor=argparse.ArgumentParser): - parser = parser_ctor(help="Compare projects", +def build_ediff_parser(parser_ctor=argparse.ArgumentParser): + parser = parser_ctor(help="Compare projects for equality", description=""" - Compares two projects for exact equality.|n + Compares two projects for equality.|n |n Examples:|n - Compare two projects, exclude annotation group |n |s|s|sand the 'is_crowd' attribute from comparison:|n - |s|sdiff2 other/project/ -if group -ia is_crowd + |s|sediff other/project/ -if group -ia is_crowd """, formatter_class=MultilineFormatter) @@ -584,17 +584,18 @@ def build_diff2_parser(parser_ctor=argparse.ArgumentParser): help="Ignore an item attribute (repeatable)") parser.add_argument('-ia', '--ignore-attr', action='append', help="Ignore an annotation attribute (repeatable)") - parser.add_argument('-if', '--ignore-field', action='append', - help="Ignore an annotation field (repeatable)") + parser.add_argument('-if', '--ignore-field', + action='append', default=['id', 'group'], + help="Ignore an annotation field (repeatable, default: %(default)s)") parser.add_argument('--all', action='store_true', help="Include matches in the output") parser.add_argument('-p', '--project', dest='project_dir', default='.', help="Directory of the first project to be compared (default: current dir)") - parser.set_defaults(command=diff2_command) + parser.set_defaults(command=ediff_command) return parser -def diff2_command(args): +def ediff_command(args): first_project = load_project(args.project_dir) second_project = load_project(args.other_project_dir) @@ -815,7 +816,7 @@ def build_parser(parser_ctor=argparse.ArgumentParser): add_subparser(subparsers, 'extract', build_extract_parser) add_subparser(subparsers, 'merge', build_merge_parser) add_subparser(subparsers, 'diff', build_diff_parser) - add_subparser(subparsers, 'diff2', build_diff2_parser) + add_subparser(subparsers, 'ediff', build_ediff_parser) add_subparser(subparsers, 'transform', build_transform_parser) add_subparser(subparsers, 'info', build_info_parser) add_subparser(subparsers, 'stats', build_stats_parser) diff --git a/datumaro/datumaro/components/extractor.py b/datumaro/datumaro/components/extractor.py index 95d43570d512..cda20f3d0c62 100644 --- a/datumaro/datumaro/components/extractor.py +++ b/datumaro/datumaro/components/extractor.py @@ -46,7 +46,7 @@ def wrap(item, **kwargs): @attrs class Categories: attributes = attrib(factory=set, validator=default_if_none(set), - kw_only=True) + kw_only=True, eq=False) @attrs class LabelCategories(Categories): diff --git a/datumaro/datumaro/components/operations.py b/datumaro/datumaro/components/operations.py index 339be1f74dca..b6a02c850ed2 100644 --- a/datumaro/datumaro/components/operations.py +++ b/datumaro/datumaro/components/operations.py @@ -17,7 +17,7 @@ from datumaro.components.extractor import (AnnotationType, Bbox, Label, LabelCategories) from datumaro.components.project import Dataset -from datumaro.util import find, ensure_cls +from datumaro.util import find, ensure_cls, filter_dict from datumaro.util.annotation_util import (segment_iou, bbox_iou, mean_bbox, OKS, find_instances, max_bbox, smooth_line) @@ -1049,7 +1049,7 @@ def __attrs_post_init__(self): @staticmethod - def match_datasets(a, b): + def _match_datasets(a, b): a_items = set((item.id, item.subset) for item in a) b_items = set((item.id, item.subset) for item in b) @@ -1058,7 +1058,7 @@ def match_datasets(a, b): b_unmatched = b_items - a_items return matches, a_unmatched, b_unmatched - def compare_categories(self, a, b): + def _compare_categories(self, a, b): test = self._test errors = [] @@ -1096,19 +1096,17 @@ def compare_categories(self, a, b): errors.append({'type': 'points', 'message': str(e)}) return errors - def compare_annotations(self, a, b): + def _compare_annotations(self, a, b): ignored_fields = self.ignored_fields ignored_attrs = self.ignored_attrs - a_fields = { k: v for k, v in vars(a).items() if k not in ignored_fields } - b_fields = { k: v for k, v in vars(b).items() if k not in ignored_fields } - - a_fields['attributes'] = { k: v for k, v in a_fields['attributes'].items() - if k not in ignored_attrs } - b_fields['attributes'] = { k: v for k, v in b_fields['attributes'].items() - if k not in ignored_attrs } + a_fields = { k: None for k in vars(a) if k in ignored_fields} + b_fields = { k: None for k in vars(b) if k in ignored_fields} + if 'attributes' not in ignored_fields: + a_fields['attributes'] = filter_dict(a.attributes, ignored_attrs) + b_fields['attributes'] = filter_dict(b.attributes, ignored_attrs) - result = a_fields == b_fields + result = a.wrap(**a_fields) == b.wrap(**b_fields) return result @@ -1117,22 +1115,25 @@ def compare_datasets(self, a, b): errors = [] - errors.append(self.compare_categories(a.categories(), b.categories())) + errors.extend(self._compare_categories(a.categories(), b.categories())) matched = [] unmatched = [] - items, a_extra_items, b_extra_items = self.match_datasets(a, b) + items, a_extra_items, b_extra_items = self._match_datasets(a, b) + + if a.categories().get(AnnotationType.label) != \ + b.categories().get(AnnotationType.label): + return matched, unmatched, a_extra_items, b_extra_items, errors + for item_id in items: item_a = a.get(*item_id) item_b = b.get(*item_id) try: test.assertEqual( - { k: v for k, v in item_a.attributes.items() - if k not in self.ignored_item_attrs }, - { k: v for k, v in item_b.attributes.items() - if k not in self.ignored_item_attrs } + filter_dict(item_a.attributes, self.ignored_item_attrs), + filter_dict(item_b.attributes, self.ignored_item_attrs) ) except AssertionError as e: errors.append({'type': 'item_attr', @@ -1143,7 +1144,7 @@ def compare_datasets(self, a, b): ann_b_candidates = [x for x in item_b.annotations if x.type == ann_a.type] - ann_b = find(enumerate(self.compare_annotations(ann_a, x) + ann_b = find(enumerate(self._compare_annotations(ann_a, x) for x in ann_b_candidates), lambda x: x[1]) if ann_b is None: unmatched.append({ @@ -1159,4 +1160,4 @@ def compare_datasets(self, a, b): for ann_b in b_annotations: unmatched.append({'item': item_id, 'source': 'b', 'ann': str(ann_b)}) - return matched, unmatched, a_extra_items, b_extra_items, errors + return matched, unmatched, a_extra_items, b_extra_items, errors \ No newline at end of file diff --git a/datumaro/datumaro/util/__init__.py b/datumaro/datumaro/util/__init__.py index 6c194ba3b09a..dd3e0c210334 100644 --- a/datumaro/datumaro/util/__init__.py +++ b/datumaro/datumaro/util/__init__.py @@ -96,3 +96,6 @@ def converter(arg): else: return c(**arg) return converter + +def filter_dict(d, exclude_keys): + return { k: v for k, v in d.items() if k not in exclude_keys } \ No newline at end of file diff --git a/datumaro/tests/test_diff.py b/datumaro/tests/test_diff.py index 187269d8d9f1..4ea145af58ae 100644 --- a/datumaro/tests/test_diff.py +++ b/datumaro/tests/test_diff.py @@ -1,8 +1,11 @@ -from unittest import TestCase +import numpy as np -from datumaro.components.extractor import DatasetItem, Label, Bbox +from datumaro.components.extractor import DatasetItem, Label, Bbox, Caption, Mask, Points +from datumaro.components.project import Dataset from datumaro.components.operations import DistanceComparator, ExactComparator +from unittest import TestCase + class DistanceComparatorTest(TestCase): def test_no_bbox_diff_with_same_item(self): @@ -109,60 +112,66 @@ def test_can_find_wrong_label(self): self.assertEqual(2, len(b_greater)) self.assertEqual(1, len(matches)) -# class ExactComparatorTest(TestCase): -# def test_ - - - - - -# label_categories = LabelCategories() -# for i in range(5): -# label_categories.add('cat' + str(i)) - -# mask_categories = MaskCategories( -# generate_colormap(len(label_categories.items))) - -# points_categories = PointsCategories() -# for index, _ in enumerate(label_categories.items): -# points_categories.add(index, ['cat1', 'cat2'], joints=[[0, 1]]) - -# return Dataset.from_iterable([ -# DatasetItem(id=100, subset='train', image=np.ones((10, 6, 3)), -# annotations=[ -# Caption('hello', id=1), -# Caption('world', id=2, group=5), -# Label(2, id=3, attributes={ -# 'x': 1, -# 'y': '2', -# }), -# Bbox(1, 2, 3, 4, label=4, id=4, z_order=1, attributes={ -# 'score': 1.0, -# }), -# Bbox(5, 6, 7, 8, id=5, group=5), -# Points([1, 2, 2, 0, 1, 1], label=0, id=5, z_order=4), -# Mask(label=3, id=5, z_order=2, image=np.ones((2, 3))), -# ]), -# DatasetItem(id=21, subset='train', -# annotations=[ -# Caption('test'), -# Label(2), -# Bbox(1, 2, 3, 4, label=5, id=42, group=42) -# ]), - -# DatasetItem(id=2, subset='val', -# annotations=[ -# PolyLine([1, 2, 3, 4, 5, 6, 7, 8], id=11, z_order=1), -# Polygon([1, 2, 3, 4, 5, 6, 7, 8], id=12, z_order=4), -# ]), - -# DatasetItem(id=42, subset='test', -# attributes={'a1': 5, 'a2': '42'}), - -# DatasetItem(id=42), -# DatasetItem(id=43, image=Image(path='1/b/c.qq', size=(2, 4))), -# ], categories={ -# AnnotationType.label: label_categories, -# AnnotationType.mask: mask_categories, -# AnnotationType.points: points_categories, -# }) \ No newline at end of file +class ExactComparatorTest(TestCase): + def test_class_comparison(self): + a = Dataset.from_iterable([], categories=['a', 'b', 'c']) + b = Dataset.from_iterable([], categories=['b', 'c']) + + comp = ExactComparator() + _, _, _, _, errors = comp.compare_datasets(a, b) + + self.assertEqual(1, len(errors), errors) + + def test_item_comparison(self): + a = Dataset.from_iterable([ + DatasetItem(id=1, subset='train'), + DatasetItem(id=2, subset='test', attributes={'x': 1}), + ], categories=['a', 'b', 'c']) + + b = Dataset.from_iterable([ + DatasetItem(id=2, subset='test'), + DatasetItem(id=3), + ], categories=['a', 'b', 'c']) + + comp = ExactComparator() + _, _, a_extra_items, b_extra_items, errors = comp.compare_datasets(a, b) + + self.assertEqual({('1', 'train')}, a_extra_items) + self.assertEqual({('3', '')}, b_extra_items) + self.assertEqual(1, len(errors), errors) + + def test_annotation_comparison(self): + a = Dataset.from_iterable([ + DatasetItem(id=1, annotations=[ + Caption('hello'), # unmatched + Caption('world', group=5), + Label(2, attributes={ 'x': 1, 'y': '2', }), + Bbox(1, 2, 3, 4, label=4, z_order=1, attributes={ + 'score': 1.0, + }), + Bbox(5, 6, 7, 8, group=5), + Points([1, 2, 2, 0, 1, 1], label=0, z_order=4), + Mask(label=3, z_order=2, image=np.ones((2, 3))), + ]), + ], categories=['a', 'b', 'c', 'd']) + + b = Dataset.from_iterable([ + DatasetItem(id=1, annotations=[ + Caption('world', group=5), + Label(2, attributes={ 'x': 1, 'y': '2', }), + Bbox(1, 2, 3, 4, label=4, z_order=1, attributes={ + 'score': 1.0, + }), + Bbox(5, 6, 7, 8, group=5), + Bbox(5, 6, 7, 8, group=5), # unmatched + Points([1, 2, 2, 0, 1, 1], label=0, z_order=4), + Mask(label=3, z_order=2, image=np.ones((2, 3))), + ]), + ], categories=['a', 'b', 'c', 'd']) + + comp = ExactComparator() + matched, unmatched, _, _, errors = comp.compare_datasets(a, b) + + self.assertEqual(6, len(matched), matched) + self.assertEqual(2, len(unmatched), unmatched) + self.assertEqual(0, len(errors), errors) \ No newline at end of file