From fee7f13062ced5bec3c8e61dd4e70581ee1c438b Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Sun, 29 Sep 2024 14:54:27 +0200 Subject: [PATCH] add cardinality penalty term to baseline val metric --- kraken/lib/train/segmentation.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/kraken/lib/train/segmentation.py b/kraken/lib/train/segmentation.py index 08c1692c..6325e7b7 100644 --- a/kraken/lib/train/segmentation.py +++ b/kraken/lib/train/segmentation.py @@ -271,8 +271,15 @@ def validation_step(self, batch, batch_idx): if pred_curves: pred_curves = torch.stack(pred_curves) cost_curves = torch.cdist(pred_curves, y_curves[line_cls][0], p=1).cpu() - row_ind, col_ind = linear_sum_assignment(cost_curves) - self.val_line_dist.update(cost_curves[row_ind, col_ind]/8.0) + costs = cost_curves = [linear_sum_assignment(cost_curves)] + # num of predictions differs from target -> take n best + # predictions and add error penalty term for the rest. + if diff := abs(len(pred_curves) - len(y_curves[line_cls][0])): + costs = np.sort(costs)[:len(y_curves[line_cls][0])] + penalty = np.full(diff, 8.0) + costs = np.concatenate([costs, penalty]) + self.val_line_dist.update(costs/8.0) + # no line output else: self.val_line_dist.update(torch.ones(len(y_curves[line_cls][0])))