diff --git a/kraken/lib/train/segmentation.py b/kraken/lib/train/segmentation.py index 08c1692c3..6325e7b7c 100644 --- a/kraken/lib/train/segmentation.py +++ b/kraken/lib/train/segmentation.py @@ -271,8 +271,15 @@ def validation_step(self, batch, batch_idx): if pred_curves: pred_curves = torch.stack(pred_curves) cost_curves = torch.cdist(pred_curves, y_curves[line_cls][0], p=1).cpu() - row_ind, col_ind = linear_sum_assignment(cost_curves) - self.val_line_dist.update(cost_curves[row_ind, col_ind]/8.0) + costs = cost_curves = [linear_sum_assignment(cost_curves)] + # num of predictions differs from target -> take n best + # predictions and add error penalty term for the rest. + if diff := abs(len(pred_curves) - len(y_curves[line_cls][0])): + costs = np.sort(costs)[:len(y_curves[line_cls][0])] + penalty = np.full(diff, 8.0) + costs = np.concatenate([costs, penalty]) + self.val_line_dist.update(costs/8.0) + # no line output else: self.val_line_dist.update(torch.ones(len(y_curves[line_cls][0])))