diff --git a/src/otx/algorithms/classification/adapters/mmcls/models/heads/custom_cls_head.py b/src/otx/algorithms/classification/adapters/mmcls/models/heads/custom_cls_head.py index fcf2008e795..ec760df7f50 100644 --- a/src/otx/algorithms/classification/adapters/mmcls/models/heads/custom_cls_head.py +++ b/src/otx/algorithms/classification/adapters/mmcls/models/heads/custom_cls_head.py @@ -42,6 +42,10 @@ def forward(self, x): def forward_train(self, cls_score, gt_label): """Forward_train fuction of CustomNonLinearHead class.""" + bs = cls_score.shape[0] + if bs == 1: + cls_score = torch.cat([cls_score, cls_score], dim=0) + gt_label = torch.cat([gt_label, gt_label], dim=0) logit = self.classifier(cls_score) losses = self.loss(logit, gt_label, feature=cls_score) return losses