Skip to content

Commit

Permalink
Make sure preds are compatible with vectorization
Browse files Browse the repository at this point in the history
  • Loading branch information
mittagessen committed Sep 30, 2024
1 parent 86cd61e commit f2199c8
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion kraken/lib/train/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,9 +265,13 @@ def validation_step(self, batch, batch_idx):
st_sep = self.nn.user_metadata['class_mapping']['aux']['_start_separator']
end_sep = self.nn.user_metadata['class_mapping']['aux']['_end_separator']

# cast pred/targets to float32 and move to CPU
pred = pred.cpu().float()
y_curves = y_curves.cpu()

# vectorize and match lines
for line_cls, line_idx in self.nn.user_metadata['class_mapping']['baselines'].items():
pred_bl = vectorize_lines(pred[0, [st_sep, end_sep, line_idx], ...].cpu().numpy(), text_direction='horizontal')
pred_bl = vectorize_lines(pred[0, [st_sep, end_sep, line_idx], ...].numpy(), text_direction='horizontal')
pred_curves = [to_curve(bl, pred.shape[2:][::-1]) for bl in pred_bl]
if line_cls in y_curves:
target_curves = y_curves[line_cls][0]
Expand Down

0 comments on commit f2199c8

Please sign in to comment.