diff --git a/kraken/lib/train/segmentation.py b/kraken/lib/train/segmentation.py index 39b86b136..4dfa610ec 100644 --- a/kraken/lib/train/segmentation.py +++ b/kraken/lib/train/segmentation.py @@ -267,7 +267,8 @@ def validation_step(self, batch, batch_idx): # cast pred/targets to float32 and move to CPU pred = pred.cpu().float() - y_curves = y_curves.cpu() + for k, v in y_curves.items(): + y_curves[k] = v.cpu() # vectorize and match lines for line_cls, line_idx in self.nn.user_metadata['class_mapping']['baselines'].items():