diff --git a/kraken/lib/train/segmentation.py b/kraken/lib/train/segmentation.py index 6f071ab4..39b86b13 100644 --- a/kraken/lib/train/segmentation.py +++ b/kraken/lib/train/segmentation.py @@ -265,9 +265,13 @@ def validation_step(self, batch, batch_idx): st_sep = self.nn.user_metadata['class_mapping']['aux']['_start_separator'] end_sep = self.nn.user_metadata['class_mapping']['aux']['_end_separator'] + # cast pred/targets to float32 and move to CPU + pred = pred.cpu().float() + y_curves = y_curves.cpu() + # vectorize and match lines for line_cls, line_idx in self.nn.user_metadata['class_mapping']['baselines'].items(): - pred_bl = vectorize_lines(pred[0, [st_sep, end_sep, line_idx], ...].cpu().numpy(), text_direction='horizontal') + pred_bl = vectorize_lines(pred[0, [st_sep, end_sep, line_idx], ...].numpy(), text_direction='horizontal') pred_curves = [to_curve(bl, pred.shape[2:][::-1]) for bl in pred_bl] if line_cls in y_curves: target_curves = y_curves[line_cls][0]