From 50d442c8a245f2443a2b015de4295b3ebb19ebac Mon Sep 17 00:00:00 2001 From: Wonju Lee Date: Wed, 17 Apr 2024 21:14:35 +0900 Subject: [PATCH 1/8] temp --- src/datumaro/components/annotation.py | 50 +++++++++++++++++++++++++++ src/datumaro/util/annotation_util.py | 33 +++++++++++++++++- 2 files changed, 82 insertions(+), 1 deletion(-) diff --git a/src/datumaro/components/annotation.py b/src/datumaro/components/annotation.py index a0553a00f2..262dff0139 100644 --- a/src/datumaro/components/annotation.py +++ b/src/datumaro/components/annotation.py @@ -48,6 +48,7 @@ class AnnotationType(IntEnum): hash_key = 12 feature_vector = 13 tabular = 14 + rotated_bbox = 15 COORDINATE_ROUNDING_DIGITS = 2 @@ -845,6 +846,55 @@ def wrap(item, **kwargs): return attr.evolve(item, **d) +@attrs(slots=True, init=False, order=False) +class RotatedBbox(_Shape): + _type = AnnotationType.rotated_bbox + + def __init__(self, x, y, w, h, r, *args, **kwargs): + kwargs.pop("points", None) # comes from wrap() + # points = x1, y1, x2, y2, x3, y3, x4, y4 + self.__attrs_init__([x, y, x + w, y + h, r], *args, **kwargs) + + @property + def x(self): + return self.points[0] + + @property + def y(self): + return self.points[1] + + @property + def w(self): + return self.points[2] - self.points[0] + + @property + def h(self): + return self.points[3] - self.points[1] + + @property + def r(self): + return self.points[4] + + def get_area(self): + return self.w * self.h + + def get_bbox(self): + return [self.x, self.y, self.w, self.h] + + def as_polygon(self) -> List[float]: + return self.points + + def iou(self, other: _Shape) -> Union[float, Literal[-1]]: + from datumaro.util.annotation_util import bbox_iou + + return bbox_iou(self.get_bbox(), other.get_bbox()) + + def wrap(item, **kwargs): + d = {"x": item.x, "y": item.y, "w": item.w, "h": item.h, "r": item.r} + d.update(kwargs) + return attr.evolve(item, **d) + + @attrs(slots=True, order=False) class PointsCategories(Categories): """ diff --git a/src/datumaro/util/annotation_util.py b/src/datumaro/util/annotation_util.py index ff371d5a55..3236c9f12e 100644 --- a/src/datumaro/util/annotation_util.py +++ b/src/datumaro/util/annotation_util.py @@ -2,13 +2,21 @@ # # SPDX-License-Identifier: MIT +import math from itertools import groupby from typing import Callable, Dict, Iterable, NewType, Optional, Sequence, Tuple, Union import numpy as np from typing_extensions import Literal -from datumaro.components.annotation import AnnotationType, LabelCategories, Mask, RleMask, _Shape +from datumaro.components.annotation import ( + AnnotationType, + LabelCategories, + Mask, + Points, + RleMask, + _Shape, +) from datumaro.util.mask_tools import mask_to_rle @@ -289,3 +297,26 @@ def map_id(src_id): return id_mapping.get(src_id, fallback) return map_id, id_mapping, source_labels, target_labels + + +def points_to_rotated_bbox(points: Points): + """Convert 8 points representing a rotated bounding box to [top_left_x, top_left_y, width, height, rotation].""" + # Extract individual coordinates from the flat list + x1, y1, x2, y2, x3, y3, _, _ = points # [x1, y1, x1 + w, y1, x1 + w, y1 + h, x1, y1 + h] + + # Calculate rotation angle + angle = math.atan2(y2 - y1, x2 - x1) + + # Calculate the center of the bounding box + center_x = (x1 + x3) / 2 + center_y = (y1 + y3) / 2 + + # Calculate width and height + width = (x3 - x1) / math.cos(angle) + height = (y3 - y1) / math.sin(angle) + + # Calculate top-left corner coordinates + top_left_x = center_x - width / 2 + top_left_y = center_y - height / 2 + + return [top_left_x, top_left_y, width, height, angle] From 3a8b710e547fdbadc690d4b41c66500cd050e958 Mon Sep 17 00:00:00 2001 From: Wonju Lee Date: Wed, 17 Apr 2024 22:54:35 +0900 Subject: [PATCH 2/8] add matcher & merger --- src/datumaro/components/annotations/matcher.py | 6 ++++++ src/datumaro/components/annotations/merger.py | 7 +++++++ src/datumaro/components/merge/intersect_merge.py | 3 +++ 3 files changed, 16 insertions(+) diff --git a/src/datumaro/components/annotations/matcher.py b/src/datumaro/components/annotations/matcher.py index a32bc4fa5e..40af5bb2ff 100644 --- a/src/datumaro/components/annotations/matcher.py +++ b/src/datumaro/components/annotations/matcher.py @@ -367,3 +367,9 @@ def match_annotations(self, sources): class TabularMatcher(AnnotationMatcher): def match_annotations(self, sources): raise NotImplementedError() + + +@attrs +class RotatedBboxMatcher(ShapeMatcher): + def distance(self, a, b): + return OKS(a, b, sigma=self.sigma) diff --git a/src/datumaro/components/annotations/merger.py b/src/datumaro/components/annotations/merger.py index 1e9fbf9bce..c1c356f81b 100644 --- a/src/datumaro/components/annotations/merger.py +++ b/src/datumaro/components/annotations/merger.py @@ -21,6 +21,7 @@ MaskMatcher, PointsMatcher, PolygonMatcher, + RotatedBboxMatcher, ShapeMatcher, TabularMatcher, ) @@ -29,6 +30,7 @@ "AnnotationMerger", "LabelMerger", "BboxMerger", + "RotatedBboxMerger", "PolygonMerger", "MaskMerger", "PointsMerger", @@ -203,3 +205,8 @@ class FeatureVectorMerger(AnnotationMerger, FeatureVectorMatcher): @attrs class TabularMerger(AnnotationMerger, TabularMatcher): pass + + +@attrs +class RotatedBboxMerger(_ShapeMerger, RotatedBboxMatcher): + pass diff --git a/src/datumaro/components/merge/intersect_merge.py b/src/datumaro/components/merge/intersect_merge.py index e6187432a3..ed3352a815 100644 --- a/src/datumaro/components/merge/intersect_merge.py +++ b/src/datumaro/components/merge/intersect_merge.py @@ -29,6 +29,7 @@ MaskMerger, PointsMerger, PolygonMerger, + RotatedBboxMerger, TabularMerger, ) from datumaro.components.dataset_base import DatasetItem, IDataset @@ -452,6 +453,8 @@ def _for_type(t, **kwargs): return _make(FeatureVectorMerger, **kwargs) elif t is AnnotationType.tabular: return _make(TabularMerger, **kwargs) + elif t is AnnotationType.rotated_bbox: + return _make(RotatedBboxMerger, **kwargs) else: raise NotImplementedError("Type %s is not supported" % t) From f2791c428c1ce9eeaedbddfa0fc3f75a96f80a43 Mon Sep 17 00:00:00 2001 From: Wonju Lee Date: Fri, 19 Apr 2024 01:41:09 +0900 Subject: [PATCH 3/8] add matcher & merger --- src/datumaro/components/annotation.py | 42 ++++++++++++++++--- .../components/annotations/matcher.py | 6 ++- src/datumaro/components/task.py | 5 ++- tests/unit/test_ops.py | 16 +++++++ 4 files changed, 59 insertions(+), 10 deletions(-) diff --git a/src/datumaro/components/annotation.py b/src/datumaro/components/annotation.py index 262dff0139..de89952821 100644 --- a/src/datumaro/components/annotation.py +++ b/src/datumaro/components/annotation.py @@ -4,6 +4,7 @@ from __future__ import annotations +import math from enum import IntEnum from functools import partial from itertools import zip_longest @@ -852,8 +853,7 @@ class RotatedBbox(_Shape): def __init__(self, x, y, w, h, r, *args, **kwargs): kwargs.pop("points", None) # comes from wrap() - # points = x1, y1, x2, y2, x3, y3, x4, y4 - self.__attrs_init__([x, y, x + w, y + h, r], *args, **kwargs) + self.__attrs_init__([x, y, w, h, r], *args, **kwargs) @property def x(self): @@ -865,11 +865,11 @@ def y(self): @property def w(self): - return self.points[2] - self.points[0] + return self.points[2] @property def h(self): - return self.points[3] - self.points[1] + return self.points[3] @property def r(self): @@ -879,10 +879,40 @@ def get_area(self): return self.w * self.h def get_bbox(self): - return [self.x, self.y, self.w, self.h] + points = self.as_polygon() + xs = [p for p in points[0::2]] + ys = [p for p in points[1::2]] + + return [min(xs), min(ys), max(xs) - min(xs), max(ys) - min(ys)] + + def get_rotated_bbox(self): + return [self.x, self.y, self.w, self.h, self.r] def as_polygon(self) -> List[float]: - return self.points + """Convert [center_x, center_y, width, height, rotation] to 8 coordinates for a rotated bounding box.""" + + half_width = self.w / 2 + half_height = self.h / 2 + rot = np.deg2rad(self.r) + + # Calculate coordinates of the four corners + corners = np.array( + [ + [-half_width, -half_height], + [half_width, -half_height], + [half_width, half_height], + [-half_width, half_height], + ] + ) + + # Rotate the corners + transformed = [] + for corner in corners: + x = corner[0] * math.cos(rot) - corner[1] * math.sin(rot) + self.x + y = corner[0] * math.sin(rot) + corner[1] * math.cos(rot) + self.y + transformed.extend([x, y]) + + return transformed def iou(self, other: _Shape) -> Union[float, Literal[-1]]: from datumaro.util.annotation_util import bbox_iou diff --git a/src/datumaro/components/annotations/matcher.py b/src/datumaro/components/annotations/matcher.py index 40af5bb2ff..5e59456bfd 100644 --- a/src/datumaro/components/annotations/matcher.py +++ b/src/datumaro/components/annotations/matcher.py @@ -9,7 +9,7 @@ from datumaro.components.abstracts import IMergerContext from datumaro.components.abstracts.merger import IMatcherContext -from datumaro.components.annotation import Annotation +from datumaro.components.annotation import Annotation, Points from datumaro.util.annotation_util import ( OKS, approximate_line, @@ -371,5 +371,7 @@ def match_annotations(self, sources): @attrs class RotatedBboxMatcher(ShapeMatcher): + sigma: Optional[list] = attrib(default=None) + def distance(self, a, b): - return OKS(a, b, sigma=self.sigma) + return OKS(Points(a.as_polygon()), Points(b.as_polygon()), sigma=self.sigma) diff --git a/src/datumaro/components/task.py b/src/datumaro/components/task.py index 5d02e59ac8..0adc74dc00 100644 --- a/src/datumaro/components/task.py +++ b/src/datumaro/components/task.py @@ -42,8 +42,7 @@ def __init__(self): AnnotationType.points, }, TaskType.detection_rotated: { - AnnotationType.label, - AnnotationType.polygon, + AnnotationType.rotated_bbox, }, TaskType.detection_3d: {AnnotationType.label, AnnotationType.cuboid_3d}, TaskType.segmentation_semantic: { @@ -53,6 +52,7 @@ def __init__(self): TaskType.segmentation_instance: { AnnotationType.label, AnnotationType.bbox, + AnnotationType.rotated_bbox, AnnotationType.ellipse, AnnotationType.polygon, AnnotationType.points, @@ -65,6 +65,7 @@ def __init__(self): TaskType.mixed: { AnnotationType.label, AnnotationType.bbox, + AnnotationType.rotated_bbox, AnnotationType.cuboid_3d, AnnotationType.ellipse, AnnotationType.polygon, diff --git a/tests/unit/test_ops.py b/tests/unit/test_ops.py index 45d97ab44e..53266f78b4 100644 --- a/tests/unit/test_ops.py +++ b/tests/unit/test_ops.py @@ -17,6 +17,7 @@ PointsCategories, Polygon, PolyLine, + RotatedBbox, ) from datumaro.components.dataset import Dataset from datumaro.components.dataset_base import DEFAULT_SUBSET_NAME, DatasetItem @@ -222,6 +223,7 @@ def test_can_match_shapes(self): Points([5, 6], label=0, group=1), Points([6, 8], label=1, group=1), PolyLine([1, 1, 2, 1, 3, 1]), + RotatedBbox(4, 5, 2, 4, 20, label=2), ], ), ], @@ -251,6 +253,7 @@ def test_can_match_shapes(self): Points([5.5, 6.5], label=0, group=2), Points([6, 8], label=1, group=2), PolyLine([1, 1.5, 2, 1.5]), + RotatedBbox(2, 4, 2, 4, 10, label=1), ], ), ], @@ -280,6 +283,7 @@ def test_can_match_shapes(self): Bbox(3, 6, 2, 3, label=2, z_order=4, group=3), Points([4.5, 5.5], label=0, group=3), PolyLine([1, 1.25, 3, 1, 4, 2]), + RotatedBbox(2, 4, 2, 4, 10, label=1), ], ), ], @@ -313,6 +317,8 @@ def test_can_match_shapes(self): Points([5, 6], label=0, group=1), Points([6, 8], label=1, group=1), PolyLine([1, 1.25, 3, 1, 4, 2]), + RotatedBbox(4, 5, 2, 4, 20, label=2), + RotatedBbox(2, 4, 2, 4, 10, label=1), ], ), ], @@ -330,11 +336,21 @@ def test_can_match_shapes(self): sources={2}, ann=source0.get("1").annotations[5], ), + NoMatchingAnnError( + item_id=("1", DEFAULT_SUBSET_NAME), + sources={0}, + ann=source1.get("1").annotations[6], + ), NoMatchingAnnError( item_id=("1", DEFAULT_SUBSET_NAME), sources={1, 2}, ann=source0.get("1").annotations[0], ), + NoMatchingAnnError( + item_id=("1", DEFAULT_SUBSET_NAME), + sources={1, 2}, + ann=source0.get("1").annotations[7], + ), ], sorted( (e for e in merger.errors if isinstance(e, NoMatchingAnnError)), From 8b334fdda144596da00663a03df0ef7b99022fde Mon Sep 17 00:00:00 2001 From: Wonju Lee Date: Fri, 19 Apr 2024 18:48:11 +0900 Subject: [PATCH 4/8] fix unit tests --- tests/unit/operations/test_statistics.py | 46 ++++++++++++------------ 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/tests/unit/operations/test_statistics.py b/tests/unit/operations/test_statistics.py index b12ea16b6b..bb92c53308 100644 --- a/tests/unit/operations/test_statistics.py +++ b/tests/unit/operations/test_statistics.py @@ -10,7 +10,7 @@ import numpy as np import pytest -from datumaro.components.annotation import Bbox, Caption, Ellipse, Label, Mask, Points +from datumaro.components.annotation import Bbox, Caption, Ellipse, Label, Mask, Points, RotatedBbox from datumaro.components.dataset import Dataset from datumaro.components.dataset_base import DatasetItem from datumaro.components.errors import DatumaroError @@ -222,6 +222,16 @@ def test_stats(self): "occluded": False, }, ), + RotatedBbox( + 4, + 4, + 2, + 2, + 20, + attributes={ + "tiny": True, + }, + ), ], ), DatasetItem(id=3), @@ -232,7 +242,7 @@ def test_stats(self): expected = { "images count": 4, - "annotations count": 11, + "annotations count": 12, "unannotated images count": 2, "unannotated images": ["3", "2.2"], "annotations by type": { @@ -257,6 +267,9 @@ def test_stats(self): "caption": { "count": 2, }, + "rotated_bbox": { + "count": 1, + }, "cuboid_3d": {"count": 0}, "super_resolution_annotation": {"count": 0}, "depth_annotation": {"count": 0}, @@ -375,27 +388,13 @@ def _get_stats_template(label_names: list): "unannotated images count": 0, "unannotated images": [], "annotations by type": { - "label": { - "count": 0, - }, - "polygon": { - "count": 0, - }, - "polyline": { - "count": 0, - }, - "bbox": { - "count": 0, - }, - "mask": { - "count": 0, - }, - "points": { - "count": 0, - }, - "caption": { - "count": 0, - }, + "label": {"count": 0}, + "polygon": {"count": 0}, + "polyline": {"count": 0}, + "bbox": {"count": 0}, + "mask": {"count": 0}, + "points": {"count": 0}, + "caption": {"count": 0}, "cuboid_3d": {"count": 0}, "super_resolution_annotation": {"count": 0}, "depth_annotation": {"count": 0}, @@ -403,6 +402,7 @@ def _get_stats_template(label_names: list): "hash_key": {"count": 0}, "feature_vector": {"count": 0}, "tabular": {"count": 0}, + "rotated_bbox": {"count": 0}, "unknown": {"count": 0}, }, "annotations": { From 9d6e3b70f982d8aacbc1df2ab5daf1eec05399dc Mon Sep 17 00:00:00 2001 From: Wonju Lee Date: Fri, 19 Apr 2024 21:01:42 +0900 Subject: [PATCH 5/8] add polygon conversions --- src/datumaro/components/annotation.py | 75 ++++++++++++++++----------- src/datumaro/util/annotation_util.py | 33 +----------- tests/unit/test_annotation.py | 17 +++++- 3 files changed, 63 insertions(+), 62 deletions(-) diff --git a/src/datumaro/components/annotation.py b/src/datumaro/components/annotation.py index de89952821..9c2a1e81b4 100644 --- a/src/datumaro/components/annotation.py +++ b/src/datumaro/components/annotation.py @@ -851,16 +851,33 @@ def wrap(item, **kwargs): class RotatedBbox(_Shape): _type = AnnotationType.rotated_bbox - def __init__(self, x, y, w, h, r, *args, **kwargs): + def __init__(self, cx, cy, w, h, r, *args, **kwargs): kwargs.pop("points", None) # comes from wrap() - self.__attrs_init__([x, y, w, h, r], *args, **kwargs) + self.__attrs_init__([cx, cy, w, h, r], *args, **kwargs) + + @classmethod + def from_polygon(cls, points: List[Tuple[float, float]]): + assert len(points) == 4, "polygon for a rotated bbox should have only 4 coordinates." + + # Calculate rotation angle + rot = math.atan2(points[1][1] - points[0][1], points[1][0] - points[0][0]) + + # Calculate the center of the bounding box + cx = (points[0][0] + points[2][0]) / 2 + cy = (points[0][1] + points[2][1]) / 2 + + # Calculate the width and height + width = math.sqrt((points[1][0] - points[0][0]) ** 2 + (points[1][1] - points[0][1]) ** 2) + height = math.sqrt((points[2][0] - points[1][0]) ** 2 + (points[2][1] - points[1][1]) ** 2) + + return cls(cx=cx, cy=cy, w=width, h=height, r=math.degrees(rot)) @property - def x(self): + def cx(self): return self.points[0] @property - def y(self): + def cy(self): return self.points[1] @property @@ -886,33 +903,33 @@ def get_bbox(self): return [min(xs), min(ys), max(xs) - min(xs), max(ys) - min(ys)] def get_rotated_bbox(self): - return [self.x, self.y, self.w, self.h, self.r] - - def as_polygon(self) -> List[float]: - """Convert [center_x, center_y, width, height, rotation] to 8 coordinates for a rotated bounding box.""" - - half_width = self.w / 2 - half_height = self.h / 2 - rot = np.deg2rad(self.r) - - # Calculate coordinates of the four corners - corners = np.array( - [ - [-half_width, -half_height], - [half_width, -half_height], - [half_width, half_height], - [-half_width, half_height], - ] - ) + return [self.cx, self.cy, self.w, self.h, self.r] + + def as_polygon(self) -> List[Tuple[float, float]]: + """Convert [center_x, center_y, width, height, rotation] to 4 coordinates for a rotated bounding box.""" + + def _rotate_point(x, y, angle): + """Rotate a point around another point.""" + angle_rad = math.radians(angle) + cos_theta = math.cos(angle_rad) + sin_theta = math.sin(angle_rad) + nx = cos_theta * x - sin_theta * y + ny = sin_theta * x + cos_theta * y + return nx, ny + + # Calculate corner points of the rectangle + corners = [ + (-self.w / 2, -self.h / 2), + (self.w / 2, -self.h / 2), + (self.w / 2, self.h / 2), + (-self.w / 2, self.h / 2), + ] - # Rotate the corners - transformed = [] - for corner in corners: - x = corner[0] * math.cos(rot) - corner[1] * math.sin(rot) + self.x - y = corner[0] * math.sin(rot) + corner[1] * math.cos(rot) + self.y - transformed.extend([x, y]) + # Rotate each corner point + rotated_corners = [_rotate_point(p[0], p[1], self.r) for p in corners] - return transformed + # Translate the rotated points to the original position + return [(p[0] + self.cx, p[1] + self.cy) for p in rotated_corners] def iou(self, other: _Shape) -> Union[float, Literal[-1]]: from datumaro.util.annotation_util import bbox_iou diff --git a/src/datumaro/util/annotation_util.py b/src/datumaro/util/annotation_util.py index 3236c9f12e..ff371d5a55 100644 --- a/src/datumaro/util/annotation_util.py +++ b/src/datumaro/util/annotation_util.py @@ -2,21 +2,13 @@ # # SPDX-License-Identifier: MIT -import math from itertools import groupby from typing import Callable, Dict, Iterable, NewType, Optional, Sequence, Tuple, Union import numpy as np from typing_extensions import Literal -from datumaro.components.annotation import ( - AnnotationType, - LabelCategories, - Mask, - Points, - RleMask, - _Shape, -) +from datumaro.components.annotation import AnnotationType, LabelCategories, Mask, RleMask, _Shape from datumaro.util.mask_tools import mask_to_rle @@ -297,26 +289,3 @@ def map_id(src_id): return id_mapping.get(src_id, fallback) return map_id, id_mapping, source_labels, target_labels - - -def points_to_rotated_bbox(points: Points): - """Convert 8 points representing a rotated bounding box to [top_left_x, top_left_y, width, height, rotation].""" - # Extract individual coordinates from the flat list - x1, y1, x2, y2, x3, y3, _, _ = points # [x1, y1, x1 + w, y1, x1 + w, y1 + h, x1, y1 + h] - - # Calculate rotation angle - angle = math.atan2(y2 - y1, x2 - x1) - - # Calculate the center of the bounding box - center_x = (x1 + x3) / 2 - center_y = (y1 + y3) / 2 - - # Calculate width and height - width = (x3 - x1) / math.cos(angle) - height = (y3 - y1) / math.sin(angle) - - # Calculate top-left corner coordinates - top_left_x = center_x - width / 2 - top_left_y = center_y - height / 2 - - return [top_left_x, top_left_y, width, height, angle] diff --git a/tests/unit/test_annotation.py b/tests/unit/test_annotation.py index c9679e2e55..38d193d1ca 100644 --- a/tests/unit/test_annotation.py +++ b/tests/unit/test_annotation.py @@ -8,7 +8,7 @@ import pytest import shapely.geometry as sg -from datumaro.components.annotation import Ellipse, HashKey +from datumaro.components.annotation import Ellipse, HashKey, RotatedBbox class EllipseTest: @@ -52,3 +52,18 @@ def fxt_hashkeys_diff(self): def test_compare_hashkey(self, fxt_hashkeys, expected, request): hashkey1, hashkey2 = request.getfixturevalue(fxt_hashkeys) assert (expected, hashkey1 == hashkey2) + + +class RotatedBboxTest: + @pytest.fixture + def fxt_rot_bbox(self): + coords = np.random.randint(0, 180, size=(5,), dtype=np.uint8) + return RotatedBbox(coords[0], coords[1], coords[2], coords[3], coords[4]) + + @pytest.mark.parametrize("fxt_ann", ["fxt_rot_bbox"]) + def test_create_polygon(self, fxt_ann, request): + fxt_rot_bbox = request.getfixturevalue(fxt_ann) + polygon = fxt_rot_bbox.as_polygon() + + expected = RotatedBbox.from_polygon(polygon) + assert fxt_rot_bbox == expected From 671603c7e1f1f4d53e2e29f7f7506204656ad389 Mon Sep 17 00:00:00 2001 From: Wonju Lee Date: Fri, 19 Apr 2024 22:27:47 +0900 Subject: [PATCH 6/8] update yolo obb --- src/datumaro/components/annotation.py | 10 +++--- .../components/annotations/matcher.py | 5 ++- .../plugins/data_formats/roboflow/base.py | 6 ++-- tests/unit/data_formats/test_roboflow.py | 34 ++++++++++++++----- 4 files changed, 37 insertions(+), 18 deletions(-) diff --git a/src/datumaro/components/annotation.py b/src/datumaro/components/annotation.py index 9c2a1e81b4..a9edd00037 100644 --- a/src/datumaro/components/annotation.py +++ b/src/datumaro/components/annotation.py @@ -856,7 +856,7 @@ def __init__(self, cx, cy, w, h, r, *args, **kwargs): self.__attrs_init__([cx, cy, w, h, r], *args, **kwargs) @classmethod - def from_polygon(cls, points: List[Tuple[float, float]]): + def from_polygon(cls, points: List[Tuple[float, float]], *args, **kwargs): assert len(points) == 4, "polygon for a rotated bbox should have only 4 coordinates." # Calculate rotation angle @@ -870,7 +870,7 @@ def from_polygon(cls, points: List[Tuple[float, float]]): width = math.sqrt((points[1][0] - points[0][0]) ** 2 + (points[1][1] - points[0][1]) ** 2) height = math.sqrt((points[2][0] - points[1][0]) ** 2 + (points[2][1] - points[1][1]) ** 2) - return cls(cx=cx, cy=cy, w=width, h=height, r=math.degrees(rot)) + return cls(cx=cx, cy=cy, w=width, h=height, r=math.degrees(rot), *args, **kwargs) @property def cx(self): @@ -896,9 +896,9 @@ def get_area(self): return self.w * self.h def get_bbox(self): - points = self.as_polygon() - xs = [p for p in points[0::2]] - ys = [p for p in points[1::2]] + polygon = self.as_polygon() + xs = [pt[0] for pt in polygon] + ys = [pt[1] for pt in polygon] return [min(xs), min(ys), max(xs) - min(xs), max(ys) - min(ys)] diff --git a/src/datumaro/components/annotations/matcher.py b/src/datumaro/components/annotations/matcher.py index 5e59456bfd..db9322722a 100644 --- a/src/datumaro/components/annotations/matcher.py +++ b/src/datumaro/components/annotations/matcher.py @@ -374,4 +374,7 @@ class RotatedBboxMatcher(ShapeMatcher): sigma: Optional[list] = attrib(default=None) def distance(self, a, b): - return OKS(Points(a.as_polygon()), Points(b.as_polygon()), sigma=self.sigma) + a = Points([p for pt in a.as_polygon() for p in pt]) + b = Points([p for pt in b.as_polygon() for p in pt]) + + return OKS(a, b, sigma=self.sigma) diff --git a/src/datumaro/plugins/data_formats/roboflow/base.py b/src/datumaro/plugins/data_formats/roboflow/base.py index 96c1a622ec..5070821dc7 100644 --- a/src/datumaro/plugins/data_formats/roboflow/base.py +++ b/src/datumaro/plugins/data_formats/roboflow/base.py @@ -16,7 +16,7 @@ Bbox, Label, LabelCategories, - Polygon, + RotatedBbox, ) from datumaro.components.dataset import DatasetItem from datumaro.components.dataset_base import SubsetBase @@ -171,8 +171,8 @@ def _parse_annotations( x4 = self._parse_field(x4, float, "x4") y4 = self._parse_field(y4, float, "y4") annotations.append( - Polygon( - points=[x1, y1, x2, y2, x3, y3, x4, y4], + RotatedBbox.from_polygon( + points=[(x1, y1), (x2, y2), (x3, y3), (x4, y4)], label=label_id, id=idx, group=idx, diff --git a/tests/unit/data_formats/test_roboflow.py b/tests/unit/data_formats/test_roboflow.py index c0ddadade9..4db669c01a 100644 --- a/tests/unit/data_formats/test_roboflow.py +++ b/tests/unit/data_formats/test_roboflow.py @@ -7,7 +7,7 @@ import numpy as np import pytest -from datumaro.components.annotation import Bbox, Label, Polygon +from datumaro.components.annotation import Bbox, Label, Polygon, RotatedBbox from datumaro.components.dataset import Dataset from datumaro.components.dataset_base import DatasetItem from datumaro.components.environment import DEFAULT_ENVIRONMENT @@ -174,8 +174,12 @@ def fxt_yolo_obb_dataset(): subset="train", media=Image.from_numpy(data=np.ones((5, 10, 3))), annotations=[ - Polygon( - points=[0, 0, 0, 2, 2, 2, 2, 0], + RotatedBbox( + 1, + 1, + 2, + 2, + 90, label=0, group=0, id=0, @@ -187,8 +191,12 @@ def fxt_yolo_obb_dataset(): subset="train", media=Image.from_numpy(data=np.ones((5, 10, 3))), annotations=[ - Polygon( - points=[1, 1, 1, 5, 5, 5, 5, 1], + RotatedBbox( + 3, + 3, + 4, + 4, + 90, label=1, group=0, id=0, @@ -200,14 +208,22 @@ def fxt_yolo_obb_dataset(): subset="val", media=Image.from_numpy(data=np.ones((5, 10, 3))), annotations=[ - Polygon( - points=[0, 0, 0, 1, 1, 1, 1, 0], + RotatedBbox( + 0.5, + 0.5, + 1, + 1, + 90, label=0, group=0, id=0, ), - Polygon( - points=[1, 2, 1, 5, 10, 5, 10, 2], + RotatedBbox( + 5.5, + 3.5, + 3, + 9, + 90, label=1, group=1, id=1, From 2bbfbaeee81b4c8f89a67ccc937dce9c7c7d4aef Mon Sep 17 00:00:00 2001 From: Wonju Lee Date: Mon, 22 Apr 2024 22:30:44 +0900 Subject: [PATCH 7/8] from_rectangle --- src/datumaro/components/annotation.py | 2 +- src/datumaro/plugins/data_formats/roboflow/base.py | 2 +- tests/unit/data_formats/test_roboflow.py | 2 +- tests/unit/test_annotation.py | 6 ++---- 4 files changed, 5 insertions(+), 7 deletions(-) diff --git a/src/datumaro/components/annotation.py b/src/datumaro/components/annotation.py index a9edd00037..89d543d9d1 100644 --- a/src/datumaro/components/annotation.py +++ b/src/datumaro/components/annotation.py @@ -856,7 +856,7 @@ def __init__(self, cx, cy, w, h, r, *args, **kwargs): self.__attrs_init__([cx, cy, w, h, r], *args, **kwargs) @classmethod - def from_polygon(cls, points: List[Tuple[float, float]], *args, **kwargs): + def from_rectangle(cls, points: List[Tuple[float, float]], *args, **kwargs): assert len(points) == 4, "polygon for a rotated bbox should have only 4 coordinates." # Calculate rotation angle diff --git a/src/datumaro/plugins/data_formats/roboflow/base.py b/src/datumaro/plugins/data_formats/roboflow/base.py index 5070821dc7..cb5c420897 100644 --- a/src/datumaro/plugins/data_formats/roboflow/base.py +++ b/src/datumaro/plugins/data_formats/roboflow/base.py @@ -171,7 +171,7 @@ def _parse_annotations( x4 = self._parse_field(x4, float, "x4") y4 = self._parse_field(y4, float, "y4") annotations.append( - RotatedBbox.from_polygon( + RotatedBbox.from_rectangle( points=[(x1, y1), (x2, y2), (x3, y3), (x4, y4)], label=label_id, id=idx, diff --git a/tests/unit/data_formats/test_roboflow.py b/tests/unit/data_formats/test_roboflow.py index 4db669c01a..8391f1a293 100644 --- a/tests/unit/data_formats/test_roboflow.py +++ b/tests/unit/data_formats/test_roboflow.py @@ -7,7 +7,7 @@ import numpy as np import pytest -from datumaro.components.annotation import Bbox, Label, Polygon, RotatedBbox +from datumaro.components.annotation import Bbox, Label, RotatedBbox from datumaro.components.dataset import Dataset from datumaro.components.dataset_base import DatasetItem from datumaro.components.environment import DEFAULT_ENVIRONMENT diff --git a/tests/unit/test_annotation.py b/tests/unit/test_annotation.py index 38d193d1ca..08508d172b 100644 --- a/tests/unit/test_annotation.py +++ b/tests/unit/test_annotation.py @@ -60,10 +60,8 @@ def fxt_rot_bbox(self): coords = np.random.randint(0, 180, size=(5,), dtype=np.uint8) return RotatedBbox(coords[0], coords[1], coords[2], coords[3], coords[4]) - @pytest.mark.parametrize("fxt_ann", ["fxt_rot_bbox"]) - def test_create_polygon(self, fxt_ann, request): - fxt_rot_bbox = request.getfixturevalue(fxt_ann) + def test_create_polygon(self, fxt_rot_bbox): polygon = fxt_rot_bbox.as_polygon() - expected = RotatedBbox.from_polygon(polygon) + expected = RotatedBbox.from_rectangle(polygon) assert fxt_rot_bbox == expected From 6f79643378388b1388e29bdf2f35fedb045dbf44 Mon Sep 17 00:00:00 2001 From: Wonju Lee Date: Mon, 22 Apr 2024 22:33:05 +0900 Subject: [PATCH 8/8] update changelog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index b875950731..212354a3b7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### New features - Add task_type property for dataset () +- Add AnnotationType.rotated_bbox for oriented object detection + () ### Enhancements - Fix ambiguous COCO format detector