Skip to content

Commit

Permalink
fix some bug. (#2825)
Browse files Browse the repository at this point in the history
  • Loading branch information
zxcd authored Jan 12, 2023
1 parent faa2f86 commit ad40daf
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 7 deletions.
4 changes: 4 additions & 0 deletions paddlespeech/s2t/models/whisper/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,10 @@ def decode(self,
if ids < len(self.tokenizer):
ids_list.append(ids)
token_ids = ids_list
elif len(token_ids) == 1:
token_ids = token_ids[0]
else:
raise ValueError(f"token_ids {token_ids} load error.")

return self.tokenizer.decode(token_ids, **kwargs)

Expand Down
14 changes: 7 additions & 7 deletions paddlespeech/s2t/models/whisper/whipser.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,11 @@
import numpy as np
import paddle
import paddle.nn.functional as F
import paddlespeech.s2t.modules.align as paddlespeech_nn
import soundfile
import tqdm
from paddle import nn
from paddle.distribution import Categorical

import paddlespeech.s2t.modules.align as paddlespeech_nn
from paddlespeech.s2t.models.whisper import utils
from paddlespeech.s2t.models.whisper.tokenizer import get_tokenizer
from paddlespeech.s2t.models.whisper.tokenizer import LANGUAGES
Expand Down Expand Up @@ -771,8 +770,10 @@ def update(self,
if temperature == 0:
next_tokens = paddle.argmax(logits, axis=-1)
else:
next_tokens = Categorical(logits=logits / temperature).sample(
shape=logits.shape)
next_tokens = Categorical(logits=logits / temperature).sample([1])
next_tokens = paddle.reshape(next_tokens, [
next_tokens.shape[0] * next_tokens.shape[1],
])

logprobs = F.log_softmax(logits, axis=-1, dtype=paddle.float32)
current_logprobs = logprobs[paddle.arange(logprobs.shape[0]),
Expand Down Expand Up @@ -1205,9 +1206,8 @@ def run(self, mel: paddle.Tensor) -> List[DecodingResult]:
DecodingResult(
audio_features=features,
language=language,
language_probs=probs)
for features, language, probs in zip(audio_features, languages,
language_probs)
language_probs=probs) for features, language, probs in
zip(audio_features, languages, language_probs)
]

# repeat the audio & text tensors by the group size, for beam search or best-of-n sampling
Expand Down

0 comments on commit ad40daf

Please sign in to comment.