Skip to content

Commit

Permalink
Cast output to float64 in inference
Browse files Browse the repository at this point in the history
Otherwise numpy conversion fails when using 16 bit precision
  • Loading branch information
mittagessen committed Mar 29, 2024
1 parent 857fb9c commit 4500256
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion kraken/lib/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def forward(self, line: torch.Tensor, lens: torch.Tensor = None) -> Union[np.nda
o, olens = self.nn.nn(line, lens)
if o.size(2) != 1:
raise KrakenInputException('Expected dimension 3 to be 1, actual {}'.format(o.size()))
self.outputs = o.detach().squeeze(2).cpu().numpy()
self.outputs = o.detach().squeeze(2).float().cpu().numpy()
if olens is not None:
olens = olens.cpu().numpy()
return self.outputs, olens
Expand Down

0 comments on commit 4500256

Please sign in to comment.