Skip to content

Commit

Permalink
add cardinality penalty term to baseline val metric
Browse files Browse the repository at this point in the history
  • Loading branch information
mittagessen committed Sep 29, 2024
1 parent 30c41ec commit fee7f13
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions kraken/lib/train/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,8 +271,15 @@ def validation_step(self, batch, batch_idx):
if pred_curves:
pred_curves = torch.stack(pred_curves)
cost_curves = torch.cdist(pred_curves, y_curves[line_cls][0], p=1).cpu()
row_ind, col_ind = linear_sum_assignment(cost_curves)
self.val_line_dist.update(cost_curves[row_ind, col_ind]/8.0)
costs = cost_curves = [linear_sum_assignment(cost_curves)]
# num of predictions differs from target -> take n best
# predictions and add error penalty term for the rest.
if diff := abs(len(pred_curves) - len(y_curves[line_cls][0])):
costs = np.sort(costs)[:len(y_curves[line_cls][0])]
penalty = np.full(diff, 8.0)
costs = np.concatenate([costs, penalty])
self.val_line_dist.update(costs/8.0)
# no line output
else:
self.val_line_dist.update(torch.ones(len(y_curves[line_cls][0])))

Expand Down

0 comments on commit fee7f13

Please sign in to comment.