From 96eec3acfdd221dff66aa8c5e98443530face0b9 Mon Sep 17 00:00:00 2001 From: Eugene Liu Date: Mon, 9 Sep 2024 15:33:07 +0100 Subject: [PATCH] update rtdetr --- src/otx/algo/detection/rtdetr.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/otx/algo/detection/rtdetr.py b/src/otx/algo/detection/rtdetr.py index 9f487c04be3..d28bf526719 100644 --- a/src/otx/algo/detection/rtdetr.py +++ b/src/otx/algo/detection/rtdetr.py @@ -75,13 +75,15 @@ def _customize_inputs( # prepare bboxes for the model for bb, ll in zip(entity.bboxes, entity.labels): # convert to cxcywh if needed - converted_bboxes = ( - box_convert(bb, in_fmt="xyxy", out_fmt="cxcywh") if bb.format == BoundingBoxFormat.XYXY else bb - ) - # normalize the bboxes - scaled_bboxes = converted_bboxes / torch.tensor(bb.canvas_size[::-1]).tile(2)[None].to( - converted_bboxes.device, - ) + scaled_bboxes = bb + if len(bb): + converted_bboxes = ( + box_convert(bb, in_fmt="xyxy", out_fmt="cxcywh") if bb.format == BoundingBoxFormat.XYXY else bb + ) + # normalize the bboxes + scaled_bboxes = converted_bboxes / torch.tensor(bb.canvas_size[::-1]).tile(2)[None].to( + converted_bboxes.device, + ) targets.append({"boxes": scaled_bboxes, "labels": ll}) return {