From fd8ff3594241691430c1737a38e0a8b30fd3e6fa Mon Sep 17 00:00:00 2001 From: Zhiltsov Max Date: Tue, 18 Aug 2020 17:38:21 +0300 Subject: [PATCH] Add exact diff command --- .../datumaro/cli/contexts/project/__init__.py | 100 ++++++-- .../datumaro/cli/contexts/project/diff.py | 2 +- datumaro/datumaro/components/comparator.py | 113 --------- datumaro/datumaro/components/extractor.py | 2 +- datumaro/datumaro/components/operations.py | 229 ++++++++++++++++++ datumaro/datumaro/util/__init__.py | 11 + datumaro/datumaro/util/test_utils.py | 3 +- datumaro/tests/test_diff.py | 157 +++++++----- 8 files changed, 421 insertions(+), 196 deletions(-) delete mode 100644 datumaro/datumaro/components/comparator.py diff --git a/datumaro/datumaro/cli/contexts/project/__init__.py b/datumaro/datumaro/cli/contexts/project/__init__.py index e6d5809b5416..d28b1bfc41e6 100644 --- a/datumaro/datumaro/cli/contexts/project/__init__.py +++ b/datumaro/datumaro/cli/contexts/project/__init__.py @@ -4,25 +4,26 @@ # SPDX-License-Identifier: MIT import argparse -from enum import Enum import json import logging as log import os import os.path as osp import shutil +from enum import Enum -from datumaro.components.project import Project, Environment, \ - PROJECT_DEFAULT_CONFIG as DEFAULT_CONFIG -from datumaro.components.comparator import Comparator +from datumaro.components.cli_plugin import CliPlugin from datumaro.components.dataset_filter import DatasetItemEncoder from datumaro.components.extractor import AnnotationType -from datumaro.components.cli_plugin import CliPlugin -from datumaro.components.operations import \ - compute_image_statistics, compute_ann_statistics +from datumaro.components.operations import (DistanceComparator, + ExactComparator, compute_ann_statistics, compute_image_statistics, mean_std) +from datumaro.components.project import \ + PROJECT_DEFAULT_CONFIG as DEFAULT_CONFIG +from datumaro.components.project import Environment, Project + +from ...util import (CliException, MultilineFormatter, add_subparser, + make_file_name) +from ...util.project import generate_next_file_name, load_project from .diff import DiffVisualizer -from ...util import add_subparser, CliException, MultilineFormatter, \ - make_file_name -from ...util.project import load_project, generate_next_file_name def build_create_parser(parser_ctor=argparse.ArgumentParser): @@ -503,12 +504,12 @@ def merge_command(args): def build_diff_parser(parser_ctor=argparse.ArgumentParser): parser = parser_ctor(help="Compare projects", description=""" - Compares two projects.|n + Compares two projects, match annotations by distance.|n |n Examples:|n - - Compare two projects, consider bboxes matching if their IoU > 0.7,|n + - Compare two projects, match boxes if IoU > 0.7,|n |s|s|s|sprint results to Tensorboard: - |s|sdiff path/to/other/project -o diff/ -f tensorboard --iou-thresh 0.7 + |s|sdiff path/to/other/project -o diff/ -v tensorboard --iou-thresh 0.7 """, formatter_class=MultilineFormatter) @@ -516,7 +517,7 @@ def build_diff_parser(parser_ctor=argparse.ArgumentParser): help="Directory of the second project to be compared") parser.add_argument('-o', '--output-dir', dest='dst_dir', default=None, help="Directory to save comparison results (default: do not save)") - parser.add_argument('-f', '--format', + parser.add_argument('-v', '--visualizer', default=DiffVisualizer.DEFAULT_FORMAT, choices=[f.name for f in DiffVisualizer.Format], help="Output format (default: %(default)s)") @@ -536,9 +537,7 @@ def diff_command(args): first_project = load_project(args.project_dir) second_project = load_project(args.other_project_dir) - comparator = Comparator( - iou_threshold=args.iou_thresh, - conf_threshold=args.conf_thresh) + comparator = DistanceComparator(iou_threshold=args.iou_thresh) dst_dir = args.dst_dir if dst_dir: @@ -556,7 +555,7 @@ def diff_command(args): dst_dir_existed = osp.exists(dst_dir) try: visualizer = DiffVisualizer(save_dir=dst_dir, comparator=comparator, - output_format=args.format) + output_format=args.visualizer) visualizer.save_dataset_diff( first_project.make_dataset(), second_project.make_dataset()) @@ -567,6 +566,70 @@ def diff_command(args): return 0 +def build_ediff_parser(parser_ctor=argparse.ArgumentParser): + parser = parser_ctor(help="Compare projects for equality", + description=""" + Compares two projects for equality.|n + |n + Examples:|n + - Compare two projects, exclude annotation group |n + |s|s|sand the 'is_crowd' attribute from comparison:|n + |s|sediff other/project/ -if group -ia is_crowd + """, + formatter_class=MultilineFormatter) + + parser.add_argument('other_project_dir', + help="Directory of the second project to be compared") + parser.add_argument('-iia', '--ignore-item-attr', action='append', + help="Ignore an item attribute (repeatable)") + parser.add_argument('-ia', '--ignore-attr', action='append', + help="Ignore an annotation attribute (repeatable)") + parser.add_argument('-if', '--ignore-field', + action='append', default=['id', 'group'], + help="Ignore an annotation field (repeatable, default: %(default)s)") + parser.add_argument('--all', action='store_true', + help="Include matches in the output") + parser.add_argument('-p', '--project', dest='project_dir', default='.', + help="Directory of the first project to be compared (default: current dir)") + parser.set_defaults(command=ediff_command) + + return parser + +def ediff_command(args): + first_project = load_project(args.project_dir) + second_project = load_project(args.other_project_dir) + + comparator = ExactComparator( + ignored_fields=args.ignore_field or [], + ignored_attrs=args.ignore_attr or [], + ignored_item_attrs=args.ignore_item_attr or []) + matches, mismatches, a_extra, b_extra, errors = \ + comparator.compare_datasets( + first_project.make_dataset(), second_project.make_dataset()) + output = { + "mismatches": mismatches, + "a_extra_items": sorted(a_extra), + "b_extra_items": sorted(b_extra), + "errors": errors, + } + if args.all: + output["matches"] = matches + + output_file = generate_next_file_name('diff', ext='.json') + with open(output_file, 'w') as f: + json.dump(output, f, indent=4, sort_keys=True) + + print("Found:") + print("The first project has %s unmatched items" % len(a_extra)) + print("The second project has %s unmatched items" % len(b_extra)) + print("%s item conflicts" % len(errors)) + print("%s matching annotations" % len(matches)) + print("%s mismatching annotations" % len(mismatches)) + + log.info("Output has been saved to '%s'" % output_file) + + return 0 + def build_transform_parser(parser_ctor=argparse.ArgumentParser): builtins = sorted(Environment().transforms.items) @@ -753,6 +816,7 @@ def build_parser(parser_ctor=argparse.ArgumentParser): add_subparser(subparsers, 'extract', build_extract_parser) add_subparser(subparsers, 'merge', build_merge_parser) add_subparser(subparsers, 'diff', build_diff_parser) + add_subparser(subparsers, 'ediff', build_ediff_parser) add_subparser(subparsers, 'transform', build_transform_parser) add_subparser(subparsers, 'info', build_info_parser) add_subparser(subparsers, 'stats', build_stats_parser) diff --git a/datumaro/datumaro/cli/contexts/project/diff.py b/datumaro/datumaro/cli/contexts/project/diff.py index 785c6c8ecde7..571908f66794 100644 --- a/datumaro/datumaro/cli/contexts/project/diff.py +++ b/datumaro/datumaro/cli/contexts/project/diff.py @@ -217,7 +217,7 @@ def save_item_bbox_diff(self, item_a, item_b, diff): _, mispred, a_unmatched, b_unmatched = diff if 0 < len(a_unmatched) + len(b_unmatched) + len(mispred): - img_a = item_a.image.copy() + img_a = item_a.image.data.copy() img_b = img_a.copy() for a_bbox, b_bbox in mispred: self.draw_bbox(img_a, a_bbox, (0, 255, 0)) diff --git a/datumaro/datumaro/components/comparator.py b/datumaro/datumaro/components/comparator.py deleted file mode 100644 index 842a3963a989..000000000000 --- a/datumaro/datumaro/components/comparator.py +++ /dev/null @@ -1,113 +0,0 @@ - -# Copyright (C) 2019 Intel Corporation -# -# SPDX-License-Identifier: MIT - -from itertools import zip_longest -import numpy as np - -from datumaro.components.extractor import AnnotationType, LabelCategories - - -class Comparator: - def __init__(self, - iou_threshold=0.5, conf_threshold=0.9): - self.iou_threshold = iou_threshold - self.conf_threshold = conf_threshold - - @staticmethod - def iou(box_a, box_b): - return box_a.iou(box_b) - - # pylint: disable=no-self-use - def compare_dataset_labels(self, extractor_a, extractor_b): - a_label_cat = extractor_a.categories().get(AnnotationType.label) - b_label_cat = extractor_b.categories().get(AnnotationType.label) - if not a_label_cat and not b_label_cat: - return None - if not a_label_cat: - a_label_cat = LabelCategories() - if not b_label_cat: - b_label_cat = LabelCategories() - - mismatches = [] - for a_label, b_label in zip_longest(a_label_cat.items, b_label_cat.items): - if a_label != b_label: - mismatches.append((a_label, b_label)) - return mismatches - # pylint: enable=no-self-use - - def compare_item_labels(self, item_a, item_b): - conf_threshold = self.conf_threshold - - a_labels = set([ann.label for ann in item_a.annotations \ - if ann.type is AnnotationType.label and \ - conf_threshold < ann.attributes.get('score', 1)]) - b_labels = set([ann.label for ann in item_b.annotations \ - if ann.type is AnnotationType.label and \ - conf_threshold < ann.attributes.get('score', 1)]) - - a_unmatched = a_labels - b_labels - b_unmatched = b_labels - a_labels - matches = a_labels & b_labels - - return matches, a_unmatched, b_unmatched - - def compare_item_bboxes(self, item_a, item_b): - iou_threshold = self.iou_threshold - conf_threshold = self.conf_threshold - - a_boxes = [ann for ann in item_a.annotations \ - if ann.type is AnnotationType.bbox and \ - conf_threshold < ann.attributes.get('score', 1)] - b_boxes = [ann for ann in item_b.annotations \ - if ann.type is AnnotationType.bbox and \ - conf_threshold < ann.attributes.get('score', 1)] - a_boxes.sort(key=lambda ann: 1 - ann.attributes.get('score', 1)) - b_boxes.sort(key=lambda ann: 1 - ann.attributes.get('score', 1)) - - # a_matches: indices of b_boxes matched to a bboxes - # b_matches: indices of a_boxes matched to b bboxes - a_matches = -np.ones(len(a_boxes), dtype=int) - b_matches = -np.ones(len(b_boxes), dtype=int) - - iou_matrix = np.array([ - [self.iou(a, b) for b in b_boxes] for a in a_boxes - ]) - - # matches: boxes we succeeded to match completely - # mispred: boxes we succeeded to match, having label mismatch - matches = [] - mispred = [] - - for a_idx, a_bbox in enumerate(a_boxes): - if len(b_boxes) == 0: - break - matched_b = a_matches[a_idx] - iou_max = max(iou_matrix[a_idx, matched_b], iou_threshold) - for b_idx, b_bbox in enumerate(b_boxes): - if 0 <= b_matches[b_idx]: # assign a_bbox with max conf - continue - iou = iou_matrix[a_idx, b_idx] - if iou < iou_max: - continue - iou_max = iou - matched_b = b_idx - - if matched_b < 0: - continue - a_matches[a_idx] = matched_b - b_matches[matched_b] = a_idx - - b_bbox = b_boxes[matched_b] - - if a_bbox.label == b_bbox.label: - matches.append( (a_bbox, b_bbox) ) - else: - mispred.append( (a_bbox, b_bbox) ) - - # *_umatched: boxes of (*) we failed to match - a_unmatched = [a_boxes[i] for i, m in enumerate(a_matches) if m < 0] - b_unmatched = [b_boxes[i] for i, m in enumerate(b_matches) if m < 0] - - return matches, mispred, a_unmatched, b_unmatched diff --git a/datumaro/datumaro/components/extractor.py b/datumaro/datumaro/components/extractor.py index d7991cd121e0..573f8d4ff409 100644 --- a/datumaro/datumaro/components/extractor.py +++ b/datumaro/datumaro/components/extractor.py @@ -46,7 +46,7 @@ def wrap(item, **kwargs): @attrs class Categories: attributes = attrib(factory=set, validator=default_if_none(set), - kw_only=True) + kw_only=True, eq=False) @attrs class LabelCategories(Categories): diff --git a/datumaro/datumaro/components/operations.py b/datumaro/datumaro/components/operations.py index 9e63d3a7e84e..e27a296e357e 100644 --- a/datumaro/datumaro/components/operations.py +++ b/datumaro/datumaro/components/operations.py @@ -1003,3 +1003,232 @@ def get_label(ann): } for c, (bin_min, bin_max) in zip(hist, zip(bins[:-1], bins[1:]))] return stats + +@attrs +class DistanceComparator: + iou_threshold = attrib(converter=float, default=0.5) + + @staticmethod + def match_datasets(a, b): + a_items = set((item.id, item.subset) for item in a) + b_items = set((item.id, item.subset) for item in b) + + matches = a_items & b_items + a_unmatched = a_items - b_items + b_unmatched = b_items - a_items + return matches, a_unmatched, b_unmatched + + @staticmethod + def match_classes(a, b): + a_label_cat = a.categories().get(AnnotationType.label, LabelCategories()) + b_label_cat = b.categories().get(AnnotationType.label, LabelCategories()) + + a_labels = set(c.name for c in a_label_cat) + b_labels = set(c.name for c in b_label_cat) + + matches = a_labels & b_labels + a_unmatched = a_labels - b_labels + b_unmatched = b_labels - a_labels + return matches, a_unmatched, b_unmatched + + def match_annotations(self, item_a, item_b): + return { t: self._match_ann_type(t, item_a, item_b) } + + def _match_ann_type(self, t, *args): + if t == AnnotationType.label: + return self.match_labels(*args) + elif t == AnnotationType.bbox: + return self.match_boxes(*args) + elif t == AnnotationType.polygon: + return self.match_polygons(*args) + elif t == AnnotationType.mask: + return self.match_masks(*args) + elif t == AnnotationType.points: + return self.match_points(*args) + elif t == AnnotationType.polyline: + return self.match_lines(*args) + else: + raise NotImplementedError("Unexpected annotation type %s" % t) + + @staticmethod + def _get_ann_type(t, item): + return get_ann_type(item.annotations, t) + + def match_labels(self, item_a, item_b): + a_labels = set(a.label for a in + self._get_ann_type(AnnotationType.label, item_a)) + b_labels = set(a.label for a in + self._get_ann_type(AnnotationType.label, item_b)) + + matches = a_labels & b_labels + a_unmatched = a_labels - b_labels + b_unmatched = b_labels - a_labels + return matches, a_unmatched, b_unmatched + + def _match_segments(self, t, item_a, item_b): + a_boxes = self._get_ann_type(t, item_a) + b_boxes = self._get_ann_type(t, item_b) + return match_segments(a_boxes, b_boxes, dist_thresh=self.iou_threshold) + + def match_polygons(self, item_a, item_b): + return self._match_segments(AnnotationType.polygon, item_a, item_b) + + def match_masks(self, item_a, item_b): + return self._match_segments(AnnotationType.mask, item_a, item_b) + + def match_boxes(self, item_a, item_b): + return self._match_segments(AnnotationType.bbox, item_a, item_b) + + def match_points(self, item_a, item_b): + a_points = self._get_ann_type(AnnotationType.points, item_a) + b_points = self._get_ann_type(AnnotationType.points, item_b) + + instance_map = {} + for s in sources: + s_instances = find_instances(s) + for inst in s_instances: + inst_bbox = max_bbox(inst) + for ann in inst: + instance_map[id(ann)] = [inst, inst_bbox] + matcher = PointsMatcher(instance_map=instance_map) + distance = lambda a, b: matcher.distance(a, b) + + return match_segments(a_points, b_points, + dist_thresh=self.iou_threshold, distance=distance) + + def match_lines(self, item_a, item_b): + a_lines = self._get_ann_type(AnnotationType.polyline, item_a) + b_lines = self._get_ann_type(AnnotationType.polyline, item_b) + + matcher = LineMatcher() + distance = lambda a, b: matcher.distance(a, b) + + return match_segments(a_lines, b_lines, + dist_thresh=self.iou_threshold, distance=distance) + +@attrs +class ExactComparator: + ignored_fields = attrib(kw_only=True, factory=set, converter=set) + ignored_attrs = attrib(kw_only=True, factory=set, converter=set) + ignored_item_attrs = attrib(kw_only=True, factory=set, converter=set) + + _test = attrib(init=False, type=TestCase) + + def __attrs_post_init__(self): + self._test = TestCase() + self._test.maxDiff = None + + + @staticmethod + def _match_datasets(a, b): + a_items = set((item.id, item.subset) for item in a) + b_items = set((item.id, item.subset) for item in b) + + matches = a_items & b_items + a_unmatched = a_items - b_items + b_unmatched = b_items - a_items + return matches, a_unmatched, b_unmatched + + def _compare_categories(self, a, b): + test = self._test + + errors = [] + try: + test.assertEqual( + sorted(a, key=lambda t: t.value), + sorted(b, key=lambda t: t.value) + ) + except AssertionError as e: + errors.append({'type': 'categories', 'message': str(e)}) + + if AnnotationType.label in a: + try: + test.assertEqual( + a[AnnotationType.label].items, + b[AnnotationType.label].items, + ) + except AssertionError as e: + errors.append({'type': 'labels', 'message': str(e)}) + if AnnotationType.mask in a: + try: + test.assertEqual( + a[AnnotationType.mask].colormap, + b[AnnotationType.mask].colormap, + ) + except AssertionError as e: + errors.append({'type': 'colormap', 'message': str(e)}) + if AnnotationType.points in a: + try: + test.assertEqual( + a[AnnotationType.points].items, + b[AnnotationType.points].items, + ) + except AssertionError as e: + errors.append({'type': 'points', 'message': str(e)}) + return errors + + def _compare_annotations(self, a, b): + ignored_fields = self.ignored_fields + ignored_attrs = self.ignored_attrs + + a_fields = { k: None for k in vars(a) if k in ignored_fields} + b_fields = { k: None for k in vars(b) if k in ignored_fields} + if 'attributes' not in ignored_fields: + a_fields['attributes'] = filter_dict(a.attributes, ignored_attrs) + b_fields['attributes'] = filter_dict(b.attributes, ignored_attrs) + + result = a.wrap(**a_fields) == b.wrap(**b_fields) + + return result + + def compare_datasets(self, a, b): + test = self._test + + errors = [] + + errors.extend(self._compare_categories(a.categories(), b.categories())) + + matched = [] + unmatched = [] + + items, a_extra_items, b_extra_items = self._match_datasets(a, b) + + if a.categories().get(AnnotationType.label) != \ + b.categories().get(AnnotationType.label): + return matched, unmatched, a_extra_items, b_extra_items, errors + + for item_id in items: + item_a = a.get(*item_id) + item_b = b.get(*item_id) + + try: + test.assertEqual( + filter_dict(item_a.attributes, self.ignored_item_attrs), + filter_dict(item_b.attributes, self.ignored_item_attrs) + ) + except AssertionError as e: + errors.append({'type': 'item_attr', + 'item': item_id, 'message': str(e)}) + + b_annotations = item_b.annotations[:] + for ann_a in item_a.annotations: + ann_b_candidates = [x for x in item_b.annotations + if x.type == ann_a.type] + + ann_b = find(enumerate(self._compare_annotations(ann_a, x) + for x in ann_b_candidates), lambda x: x[1]) + if ann_b is None: + unmatched.append({ + 'item': item_id, 'source': 'a', 'ann': str(ann_a), + }) + continue + else: + ann_b = ann_b_candidates[ann_b[0]] + + b_annotations.remove(ann_b) # avoid repeats + matched.append({'item': item_id, 'a': str(ann_a), 'b': str(ann_b)}) + + for ann_b in b_annotations: + unmatched.append({'item': item_id, 'source': 'b', 'ann': str(ann_b)}) + + return matched, unmatched, a_extra_items, b_extra_items, errors \ No newline at end of file diff --git a/datumaro/datumaro/util/__init__.py b/datumaro/datumaro/util/__init__.py index 293bb5f62f34..dd3e0c210334 100644 --- a/datumaro/datumaro/util/__init__.py +++ b/datumaro/datumaro/util/__init__.py @@ -88,3 +88,14 @@ def str_to_bool(s): return False else: raise ValueError("Can't convert value '%s' to bool" % s) + +def ensure_cls(c): + def converter(arg): + if isinstance(arg, c): + return arg + else: + return c(**arg) + return converter + +def filter_dict(d, exclude_keys): + return { k: v for k, v in d.items() if k not in exclude_keys } \ No newline at end of file diff --git a/datumaro/datumaro/util/test_utils.py b/datumaro/datumaro/util/test_utils.py index f93a74ce1b37..62973ca5a0ae 100644 --- a/datumaro/datumaro/util/test_utils.py +++ b/datumaro/datumaro/util/test_utils.py @@ -100,8 +100,7 @@ def compare_datasets(test, expected, actual, ignored_attrs=None): ann_b = find(ann_b_matches, lambda x: _compare_annotations(x, ann_a, ignored_attrs=ignored_attrs)) if ann_b is None: - test.assertEqual(ann_a, ann_b, - 'ann %s, candidates %s' % (ann_a, ann_b_matches)) + test.fail('ann %s, candidates %s' % (ann_a, ann_b_matches)) item_b.annotations.remove(ann_b) # avoid repeats def compare_datasets_strict(test, expected, actual): diff --git a/datumaro/tests/test_diff.py b/datumaro/tests/test_diff.py index 9ad9c1de6fdf..4ea145af58ae 100644 --- a/datumaro/tests/test_diff.py +++ b/datumaro/tests/test_diff.py @@ -1,123 +1,97 @@ -from unittest import TestCase +import numpy as np + +from datumaro.components.extractor import DatasetItem, Label, Bbox, Caption, Mask, Points +from datumaro.components.project import Dataset +from datumaro.components.operations import DistanceComparator, ExactComparator -from datumaro.components.extractor import DatasetItem, Label, Bbox -from datumaro.components.comparator import Comparator +from unittest import TestCase -class DiffTest(TestCase): +class DistanceComparatorTest(TestCase): def test_no_bbox_diff_with_same_item(self): detections = 3 anns = [ - Bbox(i * 10, 10, 10, 10, label=i, - attributes={'score': (1.0 + i) / detections}) \ - for i in range(detections) + Bbox(i * 10, 10, 10, 10, label=i) + for i in range(detections) ] item = DatasetItem(id=0, annotations=anns) iou_thresh = 0.5 - conf_thresh = 0.5 - comp = Comparator( - iou_threshold=iou_thresh, conf_threshold=conf_thresh) + comp = DistanceComparator(iou_threshold=iou_thresh) - result = comp.compare_item_bboxes(item, item) + result = comp.match_boxes(item, item) matches, mispred, a_greater, b_greater = result self.assertEqual(0, len(mispred)) self.assertEqual(0, len(a_greater)) self.assertEqual(0, len(b_greater)) - self.assertEqual(len([it for it in item.annotations \ - if conf_thresh < it.attributes['score']]), - len(matches)) + self.assertEqual(len(item.annotations), len(matches)) for a_bbox, b_bbox in matches: self.assertLess(iou_thresh, a_bbox.iou(b_bbox)) self.assertEqual(a_bbox.label, b_bbox.label) - self.assertLess(conf_thresh, a_bbox.attributes['score']) - self.assertLess(conf_thresh, b_bbox.attributes['score']) def test_can_find_bbox_with_wrong_label(self): detections = 3 class_count = 2 item1 = DatasetItem(id=1, annotations=[ - Bbox(i * 10, 10, 10, 10, label=i, - attributes={'score': (1.0 + i) / detections}) \ - for i in range(detections) + Bbox(i * 10, 10, 10, 10, label=i) + for i in range(detections) ]) item2 = DatasetItem(id=2, annotations=[ - Bbox(i * 10, 10, 10, 10, label=(i + 1) % class_count, - attributes={'score': (1.0 + i) / detections}) \ - for i in range(detections) + Bbox(i * 10, 10, 10, 10, label=(i + 1) % class_count) + for i in range(detections) ]) iou_thresh = 0.5 - conf_thresh = 0.5 - comp = Comparator( - iou_threshold=iou_thresh, conf_threshold=conf_thresh) + comp = DistanceComparator(iou_threshold=iou_thresh) - result = comp.compare_item_bboxes(item1, item2) + result = comp.match_boxes(item1, item2) matches, mispred, a_greater, b_greater = result - self.assertEqual(len([it for it in item1.annotations \ - if conf_thresh < it.attributes['score']]), - len(mispred)) + self.assertEqual(len(item1.annotations), len(mispred)) self.assertEqual(0, len(a_greater)) self.assertEqual(0, len(b_greater)) self.assertEqual(0, len(matches)) for a_bbox, b_bbox in mispred: self.assertLess(iou_thresh, a_bbox.iou(b_bbox)) self.assertEqual((a_bbox.label + 1) % class_count, b_bbox.label) - self.assertLess(conf_thresh, a_bbox.attributes['score']) - self.assertLess(conf_thresh, b_bbox.attributes['score']) def test_can_find_missing_boxes(self): detections = 3 class_count = 2 item1 = DatasetItem(id=1, annotations=[ - Bbox(i * 10, 10, 10, 10, label=i, - attributes={'score': (1.0 + i) / detections}) \ - for i in range(detections) if i % 2 == 0 + Bbox(i * 10, 10, 10, 10, label=i) + for i in range(detections) if i % 2 == 0 ]) item2 = DatasetItem(id=2, annotations=[ - Bbox(i * 10, 10, 10, 10, label=(i + 1) % class_count, - attributes={'score': (1.0 + i) / detections}) \ - for i in range(detections) if i % 2 == 1 + Bbox(i * 10, 10, 10, 10, label=(i + 1) % class_count) + for i in range(detections) if i % 2 == 1 ]) iou_thresh = 0.5 - conf_thresh = 0.5 - comp = Comparator( - iou_threshold=iou_thresh, conf_threshold=conf_thresh) + comp = DistanceComparator(iou_threshold=iou_thresh) - result = comp.compare_item_bboxes(item1, item2) + result = comp.match_boxes(item1, item2) matches, mispred, a_greater, b_greater = result self.assertEqual(0, len(mispred)) - self.assertEqual(len([it for it in item1.annotations \ - if conf_thresh < it.attributes['score']]), - len(a_greater)) - self.assertEqual(len([it for it in item2.annotations \ - if conf_thresh < it.attributes['score']]), - len(b_greater)) + self.assertEqual(len(item1.annotations), len(a_greater)) + self.assertEqual(len(item2.annotations), len(b_greater)) self.assertEqual(0, len(matches)) def test_no_label_diff_with_same_item(self): detections = 3 anns = [ - Label(i, attributes={'score': (1.0 + i) / detections}) \ - for i in range(detections) + Label(i) for i in range(detections) ] item = DatasetItem(id=1, annotations=anns) - conf_thresh = 0.5 - comp = Comparator(conf_threshold=conf_thresh) - - result = comp.compare_item_labels(item, item) + result = DistanceComparator().match_labels(item, item) matches, a_greater, b_greater = result self.assertEqual(0, len(a_greater)) self.assertEqual(0, len(b_greater)) - self.assertEqual(len([it for it in item.annotations \ - if conf_thresh < it.attributes['score']]), - len(matches)) + self.assertEqual(len(item.annotations), len(matches)) def test_can_find_wrong_label(self): item1 = DatasetItem(id=1, annotations=[ @@ -131,12 +105,73 @@ def test_can_find_wrong_label(self): Label(4), ]) - conf_thresh = 0.5 - comp = Comparator(conf_threshold=conf_thresh) - - result = comp.compare_item_labels(item1, item2) + result = DistanceComparator().match_labels(item1, item2) matches, a_greater, b_greater = result self.assertEqual(2, len(a_greater)) self.assertEqual(2, len(b_greater)) - self.assertEqual(1, len(matches)) \ No newline at end of file + self.assertEqual(1, len(matches)) + +class ExactComparatorTest(TestCase): + def test_class_comparison(self): + a = Dataset.from_iterable([], categories=['a', 'b', 'c']) + b = Dataset.from_iterable([], categories=['b', 'c']) + + comp = ExactComparator() + _, _, _, _, errors = comp.compare_datasets(a, b) + + self.assertEqual(1, len(errors), errors) + + def test_item_comparison(self): + a = Dataset.from_iterable([ + DatasetItem(id=1, subset='train'), + DatasetItem(id=2, subset='test', attributes={'x': 1}), + ], categories=['a', 'b', 'c']) + + b = Dataset.from_iterable([ + DatasetItem(id=2, subset='test'), + DatasetItem(id=3), + ], categories=['a', 'b', 'c']) + + comp = ExactComparator() + _, _, a_extra_items, b_extra_items, errors = comp.compare_datasets(a, b) + + self.assertEqual({('1', 'train')}, a_extra_items) + self.assertEqual({('3', '')}, b_extra_items) + self.assertEqual(1, len(errors), errors) + + def test_annotation_comparison(self): + a = Dataset.from_iterable([ + DatasetItem(id=1, annotations=[ + Caption('hello'), # unmatched + Caption('world', group=5), + Label(2, attributes={ 'x': 1, 'y': '2', }), + Bbox(1, 2, 3, 4, label=4, z_order=1, attributes={ + 'score': 1.0, + }), + Bbox(5, 6, 7, 8, group=5), + Points([1, 2, 2, 0, 1, 1], label=0, z_order=4), + Mask(label=3, z_order=2, image=np.ones((2, 3))), + ]), + ], categories=['a', 'b', 'c', 'd']) + + b = Dataset.from_iterable([ + DatasetItem(id=1, annotations=[ + Caption('world', group=5), + Label(2, attributes={ 'x': 1, 'y': '2', }), + Bbox(1, 2, 3, 4, label=4, z_order=1, attributes={ + 'score': 1.0, + }), + Bbox(5, 6, 7, 8, group=5), + Bbox(5, 6, 7, 8, group=5), # unmatched + Points([1, 2, 2, 0, 1, 1], label=0, z_order=4), + Mask(label=3, z_order=2, image=np.ones((2, 3))), + ]), + ], categories=['a', 'b', 'c', 'd']) + + comp = ExactComparator() + matched, unmatched, _, _, errors = comp.compare_datasets(a, b) + + self.assertEqual(6, len(matched), matched) + self.assertEqual(2, len(unmatched), unmatched) + self.assertEqual(0, len(errors), errors) \ No newline at end of file