From ea5ead2f1f037020725a1a622984e32f00b441ed Mon Sep 17 00:00:00 2001 From: Eugene Liu Date: Wed, 4 Dec 2024 15:32:17 +0000 Subject: [PATCH 1/2] Ensure target class indices are of type long in loss calculations --- src/otx/algo/detection/losses/rtdetr_loss.py | 2 +- src/otx/algo/detection/utils/matchers/hungarian_matcher.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/otx/algo/detection/losses/rtdetr_loss.py b/src/otx/algo/detection/losses/rtdetr_loss.py index ef7cda4b2ad..f5783023cad 100644 --- a/src/otx/algo/detection/losses/rtdetr_loss.py +++ b/src/otx/algo/detection/losses/rtdetr_loss.py @@ -77,7 +77,7 @@ def loss_labels_vfl( src_logits = outputs["pred_logits"] target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)]) target_classes = torch.full(src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device) - target_classes[idx] = target_classes_o + target_classes[idx] = target_classes_o.long() target = nn.functional.one_hot(target_classes, num_classes=self.num_classes + 1)[..., :-1] target_score_o = torch.zeros_like(target_classes, dtype=src_logits.dtype) diff --git a/src/otx/algo/detection/utils/matchers/hungarian_matcher.py b/src/otx/algo/detection/utils/matchers/hungarian_matcher.py index f4a366234be..d8489838081 100644 --- a/src/otx/algo/detection/utils/matchers/hungarian_matcher.py +++ b/src/otx/algo/detection/utils/matchers/hungarian_matcher.py @@ -71,7 +71,7 @@ def forward( out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4] # Also concat the target labels and boxes - tgt_ids = torch.cat([v["labels"] for v in targets]) + tgt_ids = torch.cat([v["labels"] for v in targets]).long() tgt_bbox = torch.cat([v["boxes"] for v in targets]) # Compute the classification cost. Contrary to the loss, we don't use the NLL, From f5cd1ec77c7381689ed58c2aecc0bd7cc600088d Mon Sep 17 00:00:00 2001 From: Eugene Liu Date: Wed, 4 Dec 2024 15:36:27 +0000 Subject: [PATCH 2/2] update changelog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1ddf602a0f5..de59809eb13 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -126,6 +126,8 @@ All notable changes to this project will be documented in this file. () - Fix empty annotation in tiling () +- Fix DETR target class indices are of type long in loss calculations + () ## \[v2.1.0\]