Skip to content

Commit

Permalink
torch.sort works differently than np.sort
Browse files Browse the repository at this point in the history
  • Loading branch information
mittagessen committed Sep 30, 2024
1 parent 4d370d0 commit 4845ac0
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion kraken/lib/train/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,8 @@ def validation_step(self, batch, batch_idx):
# num of predictions differs from target -> take n best
# predictions and add error penalty term for the rest.
if diff := abs(len(pred_curves) - len(target_curves)):
costs = torch.sort(costs)[:len(target_curves)]
costs, _ = torch.sort(costs)
costs = costs[:len(target_curves)]
penalty = torch.full((diff,), 8.0)
costs = torch.cat([costs, penalty])
costs = costs/8.0
Expand Down

0 comments on commit 4845ac0

Please sign in to comment.