diff --git a/kraken/lib/train/segmentation.py b/kraken/lib/train/segmentation.py index a6b5d2a0..b68fd79a 100644 --- a/kraken/lib/train/segmentation.py +++ b/kraken/lib/train/segmentation.py @@ -279,10 +279,11 @@ def validation_step(self, batch, batch_idx): costs = np.sort(costs)[:len(y_curves[line_cls][0])] penalty = np.full(diff, 8.0) costs = np.concatenate([costs, penalty]) - self.val_line_mean_dist.update(costs/8.0) + costs = costs/8.0 # no line output else: - self.val_line_mean_dist.update(torch.ones(len(y_curves[line_cls][0]))) + costs = torch.ones(len(y_curves[line_cls][0])) + self.val_line_mean_dist.update(costs.to(self.val_line_mean_dist.device)) def on_validation_epoch_end(self): if not self.trainer.sanity_checking: