Skip to content

Commit

Permalink
put costs on correct device
Browse files Browse the repository at this point in the history
  • Loading branch information
mittagessen committed Sep 30, 2024
1 parent 1b577c2 commit cda9b18
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions kraken/lib/train/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,10 +279,11 @@ def validation_step(self, batch, batch_idx):
costs = np.sort(costs)[:len(y_curves[line_cls][0])]
penalty = np.full(diff, 8.0)
costs = np.concatenate([costs, penalty])
self.val_line_mean_dist.update(costs/8.0)
costs = costs/8.0
# no line output
else:
self.val_line_mean_dist.update(torch.ones(len(y_curves[line_cls][0])))
costs = torch.ones(len(y_curves[line_cls][0]))
self.val_line_mean_dist.update(costs.to(self.val_line_mean_dist.device))

def on_validation_epoch_end(self):
if not self.trainer.sanity_checking:
Expand Down

0 comments on commit cda9b18

Please sign in to comment.