Skip to content

Commit

Permalink
deal with no pred no target case
Browse files Browse the repository at this point in the history
  • Loading branch information
mittagessen committed Sep 30, 2024
1 parent cda9b18 commit 86cd61e
Showing 1 changed file with 20 additions and 15 deletions.
35 changes: 20 additions & 15 deletions kraken/lib/train/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,21 +269,26 @@ def validation_step(self, batch, batch_idx):
for line_cls, line_idx in self.nn.user_metadata['class_mapping']['baselines'].items():
pred_bl = vectorize_lines(pred[0, [st_sep, end_sep, line_idx], ...].cpu().numpy(), text_direction='horizontal')
pred_curves = [to_curve(bl, pred.shape[2:][::-1]) for bl in pred_bl]
if pred_curves:
pred_curves = torch.stack(pred_curves)
cost_curves = torch.cdist(pred_curves, y_curves[line_cls][0], p=1).cpu()
costs = cost_curves = [linear_sum_assignment(cost_curves)]
# num of predictions differs from target -> take n best
# predictions and add error penalty term for the rest.
if diff := abs(len(pred_curves) - len(y_curves[line_cls][0])):
costs = np.sort(costs)[:len(y_curves[line_cls][0])]
penalty = np.full(diff, 8.0)
costs = np.concatenate([costs, penalty])
costs = costs/8.0
# no line output
else:
costs = torch.ones(len(y_curves[line_cls][0]))
self.val_line_mean_dist.update(costs.to(self.val_line_mean_dist.device))
if line_cls in y_curves:
target_curves = y_curves[line_cls][0]
if pred_curves:
pred_curves = torch.stack(pred_curves)
cost_curves = torch.cdist(pred_curves, target_curves, p=1).cpu()
costs = cost_curves = [linear_sum_assignment(cost_curves)]
# num of predictions differs from target -> take n best
# predictions and add error penalty term for the rest.
if diff := abs(len(pred_curves) - len(target_curves)):
costs = np.sort(costs)[:len(target_curves)]
penalty = np.full(diff, 8.0)
costs = np.concatenate([costs, penalty])
costs = costs/8.0
# no line output
else:
costs = torch.ones(len(target_curves))
self.val_line_mean_dist.update(costs.to(self.val_line_mean_dist.device))
elif pred_curves:
costs = torch.ones(len(pred_curves))
self.val_line_mean_dist.update(costs.to(self.val_line_mean_dist.device))

def on_validation_epoch_end(self):
if not self.trainer.sanity_checking:
Expand Down

0 comments on commit 86cd61e

Please sign in to comment.