From 86cd61ef78eea01fa9b0754d5ad45138c3634aae Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Mon, 30 Sep 2024 12:12:00 +0200 Subject: [PATCH] deal with no pred no target case --- kraken/lib/train/segmentation.py | 35 ++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/kraken/lib/train/segmentation.py b/kraken/lib/train/segmentation.py index b68fd79a..6f071ab4 100644 --- a/kraken/lib/train/segmentation.py +++ b/kraken/lib/train/segmentation.py @@ -269,21 +269,26 @@ def validation_step(self, batch, batch_idx): for line_cls, line_idx in self.nn.user_metadata['class_mapping']['baselines'].items(): pred_bl = vectorize_lines(pred[0, [st_sep, end_sep, line_idx], ...].cpu().numpy(), text_direction='horizontal') pred_curves = [to_curve(bl, pred.shape[2:][::-1]) for bl in pred_bl] - if pred_curves: - pred_curves = torch.stack(pred_curves) - cost_curves = torch.cdist(pred_curves, y_curves[line_cls][0], p=1).cpu() - costs = cost_curves = [linear_sum_assignment(cost_curves)] - # num of predictions differs from target -> take n best - # predictions and add error penalty term for the rest. - if diff := abs(len(pred_curves) - len(y_curves[line_cls][0])): - costs = np.sort(costs)[:len(y_curves[line_cls][0])] - penalty = np.full(diff, 8.0) - costs = np.concatenate([costs, penalty]) - costs = costs/8.0 - # no line output - else: - costs = torch.ones(len(y_curves[line_cls][0])) - self.val_line_mean_dist.update(costs.to(self.val_line_mean_dist.device)) + if line_cls in y_curves: + target_curves = y_curves[line_cls][0] + if pred_curves: + pred_curves = torch.stack(pred_curves) + cost_curves = torch.cdist(pred_curves, target_curves, p=1).cpu() + costs = cost_curves = [linear_sum_assignment(cost_curves)] + # num of predictions differs from target -> take n best + # predictions and add error penalty term for the rest. + if diff := abs(len(pred_curves) - len(target_curves)): + costs = np.sort(costs)[:len(target_curves)] + penalty = np.full(diff, 8.0) + costs = np.concatenate([costs, penalty]) + costs = costs/8.0 + # no line output + else: + costs = torch.ones(len(target_curves)) + self.val_line_mean_dist.update(costs.to(self.val_line_mean_dist.device)) + elif pred_curves: + costs = torch.ones(len(pred_curves)) + self.val_line_mean_dist.update(costs.to(self.val_line_mean_dist.device)) def on_validation_epoch_end(self): if not self.trainer.sanity_checking: