Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
ydshieh committed Jun 22, 2023
1 parent 6ce6d62 commit 7b28857
Showing 1 changed file with 16 additions and 4 deletions.
20 changes: 16 additions & 4 deletions tests/models/whisper/test_modeling_tf_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,7 +704,10 @@ def _test_large_generation(in_queue, out_queue, timeout):
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features

generated_ids = model.generate(
input_features, do_sample=False, max_length=20, language="<|en|>", task="transcribe"
input_features,
do_sample=False,
max_length=20,
generation_kwargs={"language": "<|en|>", "task": "transcribe"},
)
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

Expand Down Expand Up @@ -733,23 +736,32 @@ def _test_large_generation_multilingual(in_queue, out_queue, timeout):
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features

generated_ids = model.generate(
input_features, do_sample=False, max_length=20, language="<|ja|>", task="transcribe"
input_features,
do_sample=False,
max_length=20,
generation_kwargs={"language": "<|ja|>", "task": "transcribe"},
)
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

EXPECTED_TRANSCRIPT = "木村さんに電話を貸してもらいました"
unittest.TestCase().assertEqual(transcript, EXPECTED_TRANSCRIPT)

generated_ids = model.generate(
input_features, do_sample=False, max_length=20, language="<|en|>", task="transcribe"
input_features,
do_sample=False,
max_length=20,
generation_kwargs={"language": "<|en|>", "task": "transcribe"},
)
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

EXPECTED_TRANSCRIPT = " Kimura-san called me."
unittest.TestCase().assertEqual(transcript, EXPECTED_TRANSCRIPT)

generated_ids = model.generate(
input_features, do_sample=False, max_length=20, language="<|ja|>", task="translate"
input_features,
do_sample=False,
max_length=20,
generation_kwargs={"language": "<|ja|>", "task": "translate"},
)
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

Expand Down

0 comments on commit 7b28857

Please sign in to comment.