Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiltsov-max committed Aug 12, 2020
1 parent 68193e4 commit 0e670a5
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 89 deletions.
19 changes: 10 additions & 9 deletions datumaro/datumaro/cli/contexts/project/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,15 +566,15 @@ def diff_command(args):

return 0

def build_diff2_parser(parser_ctor=argparse.ArgumentParser):
parser = parser_ctor(help="Compare projects",
def build_ediff_parser(parser_ctor=argparse.ArgumentParser):
parser = parser_ctor(help="Compare projects for equality",
description="""
Compares two projects for exact equality.|n
Compares two projects for equality.|n
|n
Examples:|n
- Compare two projects, exclude annotation group |n
|s|s|sand the 'is_crowd' attribute from comparison:|n
|s|sdiff2 other/project/ -if group -ia is_crowd
|s|sediff other/project/ -if group -ia is_crowd
""",
formatter_class=MultilineFormatter)

Expand All @@ -584,17 +584,18 @@ def build_diff2_parser(parser_ctor=argparse.ArgumentParser):
help="Ignore an item attribute (repeatable)")
parser.add_argument('-ia', '--ignore-attr', action='append',
help="Ignore an annotation attribute (repeatable)")
parser.add_argument('-if', '--ignore-field', action='append',
help="Ignore an annotation field (repeatable)")
parser.add_argument('-if', '--ignore-field',
action='append', default=['id', 'group'],
help="Ignore an annotation field (repeatable, default: %(default)s)")
parser.add_argument('--all', action='store_true',
help="Include matches in the output")
parser.add_argument('-p', '--project', dest='project_dir', default='.',
help="Directory of the first project to be compared (default: current dir)")
parser.set_defaults(command=diff2_command)
parser.set_defaults(command=ediff_command)

return parser

def diff2_command(args):
def ediff_command(args):
first_project = load_project(args.project_dir)
second_project = load_project(args.other_project_dir)

Expand Down Expand Up @@ -815,7 +816,7 @@ def build_parser(parser_ctor=argparse.ArgumentParser):
add_subparser(subparsers, 'extract', build_extract_parser)
add_subparser(subparsers, 'merge', build_merge_parser)
add_subparser(subparsers, 'diff', build_diff_parser)
add_subparser(subparsers, 'diff2', build_diff2_parser)
add_subparser(subparsers, 'ediff', build_ediff_parser)
add_subparser(subparsers, 'transform', build_transform_parser)
add_subparser(subparsers, 'info', build_info_parser)
add_subparser(subparsers, 'stats', build_stats_parser)
Expand Down
2 changes: 1 addition & 1 deletion datumaro/datumaro/components/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def wrap(item, **kwargs):
@attrs
class Categories:
attributes = attrib(factory=set, validator=default_if_none(set),
kw_only=True)
kw_only=True, eq=False)

@attrs
class LabelCategories(Categories):
Expand Down
41 changes: 21 additions & 20 deletions datumaro/datumaro/components/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from datumaro.components.extractor import (AnnotationType, Bbox, Label,
LabelCategories)
from datumaro.components.project import Dataset
from datumaro.util import find, ensure_cls
from datumaro.util import find, ensure_cls, filter_dict
from datumaro.util.annotation_util import (segment_iou, bbox_iou,
mean_bbox, OKS, find_instances, max_bbox, smooth_line)

Expand Down Expand Up @@ -1049,7 +1049,7 @@ def __attrs_post_init__(self):


@staticmethod
def match_datasets(a, b):
def _match_datasets(a, b):
a_items = set((item.id, item.subset) for item in a)
b_items = set((item.id, item.subset) for item in b)

Expand All @@ -1058,7 +1058,7 @@ def match_datasets(a, b):
b_unmatched = b_items - a_items
return matches, a_unmatched, b_unmatched

def compare_categories(self, a, b):
def _compare_categories(self, a, b):
test = self._test

errors = []
Expand Down Expand Up @@ -1096,19 +1096,17 @@ def compare_categories(self, a, b):
errors.append({'type': 'points', 'message': str(e)})
return errors

def compare_annotations(self, a, b):
def _compare_annotations(self, a, b):
ignored_fields = self.ignored_fields
ignored_attrs = self.ignored_attrs

a_fields = { k: v for k, v in vars(a).items() if k not in ignored_fields }
b_fields = { k: v for k, v in vars(b).items() if k not in ignored_fields }

a_fields['attributes'] = { k: v for k, v in a_fields['attributes'].items()
if k not in ignored_attrs }
b_fields['attributes'] = { k: v for k, v in b_fields['attributes'].items()
if k not in ignored_attrs }
a_fields = { k: None for k in vars(a) if k in ignored_fields}
b_fields = { k: None for k in vars(b) if k in ignored_fields}
if 'attributes' not in ignored_fields:
a_fields['attributes'] = filter_dict(a.attributes, ignored_attrs)
b_fields['attributes'] = filter_dict(b.attributes, ignored_attrs)

result = a_fields == b_fields
result = a.wrap(**a_fields) == b.wrap(**b_fields)

return result

Expand All @@ -1117,22 +1115,25 @@ def compare_datasets(self, a, b):

errors = []

errors.append(self.compare_categories(a.categories(), b.categories()))
errors.extend(self._compare_categories(a.categories(), b.categories()))

matched = []
unmatched = []

items, a_extra_items, b_extra_items = self.match_datasets(a, b)
items, a_extra_items, b_extra_items = self._match_datasets(a, b)

if a.categories().get(AnnotationType.label) != \
b.categories().get(AnnotationType.label):
return matched, unmatched, a_extra_items, b_extra_items, errors

for item_id in items:
item_a = a.get(*item_id)
item_b = b.get(*item_id)

try:
test.assertEqual(
{ k: v for k, v in item_a.attributes.items()
if k not in self.ignored_item_attrs },
{ k: v for k, v in item_b.attributes.items()
if k not in self.ignored_item_attrs }
filter_dict(item_a.attributes, self.ignored_item_attrs),
filter_dict(item_b.attributes, self.ignored_item_attrs)
)
except AssertionError as e:
errors.append({'type': 'item_attr',
Expand All @@ -1143,7 +1144,7 @@ def compare_datasets(self, a, b):
ann_b_candidates = [x for x in item_b.annotations
if x.type == ann_a.type]

ann_b = find(enumerate(self.compare_annotations(ann_a, x)
ann_b = find(enumerate(self._compare_annotations(ann_a, x)
for x in ann_b_candidates), lambda x: x[1])
if ann_b is None:
unmatched.append({
Expand All @@ -1159,4 +1160,4 @@ def compare_datasets(self, a, b):
for ann_b in b_annotations:
unmatched.append({'item': item_id, 'source': 'b', 'ann': str(ann_b)})

return matched, unmatched, a_extra_items, b_extra_items, errors
return matched, unmatched, a_extra_items, b_extra_items, errors
3 changes: 3 additions & 0 deletions datumaro/datumaro/util/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,6 @@ def converter(arg):
else:
return c(**arg)
return converter

def filter_dict(d, exclude_keys):
return { k: v for k, v in d.items() if k not in exclude_keys }
127 changes: 68 additions & 59 deletions datumaro/tests/test_diff.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from unittest import TestCase
import numpy as np

from datumaro.components.extractor import DatasetItem, Label, Bbox
from datumaro.components.extractor import DatasetItem, Label, Bbox, Caption, Mask, Points
from datumaro.components.project import Dataset
from datumaro.components.operations import DistanceComparator, ExactComparator

from unittest import TestCase


class DistanceComparatorTest(TestCase):
def test_no_bbox_diff_with_same_item(self):
Expand Down Expand Up @@ -109,60 +112,66 @@ def test_can_find_wrong_label(self):
self.assertEqual(2, len(b_greater))
self.assertEqual(1, len(matches))

# class ExactComparatorTest(TestCase):
# def test_





# label_categories = LabelCategories()
# for i in range(5):
# label_categories.add('cat' + str(i))

# mask_categories = MaskCategories(
# generate_colormap(len(label_categories.items)))

# points_categories = PointsCategories()
# for index, _ in enumerate(label_categories.items):
# points_categories.add(index, ['cat1', 'cat2'], joints=[[0, 1]])

# return Dataset.from_iterable([
# DatasetItem(id=100, subset='train', image=np.ones((10, 6, 3)),
# annotations=[
# Caption('hello', id=1),
# Caption('world', id=2, group=5),
# Label(2, id=3, attributes={
# 'x': 1,
# 'y': '2',
# }),
# Bbox(1, 2, 3, 4, label=4, id=4, z_order=1, attributes={
# 'score': 1.0,
# }),
# Bbox(5, 6, 7, 8, id=5, group=5),
# Points([1, 2, 2, 0, 1, 1], label=0, id=5, z_order=4),
# Mask(label=3, id=5, z_order=2, image=np.ones((2, 3))),
# ]),
# DatasetItem(id=21, subset='train',
# annotations=[
# Caption('test'),
# Label(2),
# Bbox(1, 2, 3, 4, label=5, id=42, group=42)
# ]),

# DatasetItem(id=2, subset='val',
# annotations=[
# PolyLine([1, 2, 3, 4, 5, 6, 7, 8], id=11, z_order=1),
# Polygon([1, 2, 3, 4, 5, 6, 7, 8], id=12, z_order=4),
# ]),

# DatasetItem(id=42, subset='test',
# attributes={'a1': 5, 'a2': '42'}),

# DatasetItem(id=42),
# DatasetItem(id=43, image=Image(path='1/b/c.qq', size=(2, 4))),
# ], categories={
# AnnotationType.label: label_categories,
# AnnotationType.mask: mask_categories,
# AnnotationType.points: points_categories,
# })
class ExactComparatorTest(TestCase):
def test_class_comparison(self):
a = Dataset.from_iterable([], categories=['a', 'b', 'c'])
b = Dataset.from_iterable([], categories=['b', 'c'])

comp = ExactComparator()
_, _, _, _, errors = comp.compare_datasets(a, b)

self.assertEqual(1, len(errors), errors)

def test_item_comparison(self):
a = Dataset.from_iterable([
DatasetItem(id=1, subset='train'),
DatasetItem(id=2, subset='test', attributes={'x': 1}),
], categories=['a', 'b', 'c'])

b = Dataset.from_iterable([
DatasetItem(id=2, subset='test'),
DatasetItem(id=3),
], categories=['a', 'b', 'c'])

comp = ExactComparator()
_, _, a_extra_items, b_extra_items, errors = comp.compare_datasets(a, b)

self.assertEqual({('1', 'train')}, a_extra_items)
self.assertEqual({('3', '')}, b_extra_items)
self.assertEqual(1, len(errors), errors)

def test_annotation_comparison(self):
a = Dataset.from_iterable([
DatasetItem(id=1, annotations=[
Caption('hello'), # unmatched
Caption('world', group=5),
Label(2, attributes={ 'x': 1, 'y': '2', }),
Bbox(1, 2, 3, 4, label=4, z_order=1, attributes={
'score': 1.0,
}),
Bbox(5, 6, 7, 8, group=5),
Points([1, 2, 2, 0, 1, 1], label=0, z_order=4),
Mask(label=3, z_order=2, image=np.ones((2, 3))),
]),
], categories=['a', 'b', 'c', 'd'])

b = Dataset.from_iterable([
DatasetItem(id=1, annotations=[
Caption('world', group=5),
Label(2, attributes={ 'x': 1, 'y': '2', }),
Bbox(1, 2, 3, 4, label=4, z_order=1, attributes={
'score': 1.0,
}),
Bbox(5, 6, 7, 8, group=5),
Bbox(5, 6, 7, 8, group=5), # unmatched
Points([1, 2, 2, 0, 1, 1], label=0, z_order=4),
Mask(label=3, z_order=2, image=np.ones((2, 3))),
]),
], categories=['a', 'b', 'c', 'd'])

comp = ExactComparator()
matched, unmatched, _, _, errors = comp.compare_datasets(a, b)

self.assertEqual(6, len(matched), matched)
self.assertEqual(2, len(unmatched), unmatched)
self.assertEqual(0, len(errors), errors)

0 comments on commit 0e670a5

Please sign in to comment.