From 7b28857efe9c72f71a8cdc0c02bad713b8cd9f20 Mon Sep 17 00:00:00 2001 From: ydshieh Date: Thu, 22 Jun 2023 17:27:28 +0200 Subject: [PATCH] fix --- .../whisper/test_modeling_tf_whisper.py | 20 +++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/tests/models/whisper/test_modeling_tf_whisper.py b/tests/models/whisper/test_modeling_tf_whisper.py index 0783bd67bf43..e4ae3e7eeeca 100644 --- a/tests/models/whisper/test_modeling_tf_whisper.py +++ b/tests/models/whisper/test_modeling_tf_whisper.py @@ -704,7 +704,10 @@ def _test_large_generation(in_queue, out_queue, timeout): input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features generated_ids = model.generate( - input_features, do_sample=False, max_length=20, language="<|en|>", task="transcribe" + input_features, + do_sample=False, + max_length=20, + generation_kwargs={"language": "<|en|>", "task": "transcribe"}, ) transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] @@ -733,7 +736,10 @@ def _test_large_generation_multilingual(in_queue, out_queue, timeout): input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features generated_ids = model.generate( - input_features, do_sample=False, max_length=20, language="<|ja|>", task="transcribe" + input_features, + do_sample=False, + max_length=20, + generation_kwargs={"language": "<|ja|>", "task": "transcribe"}, ) transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] @@ -741,7 +747,10 @@ def _test_large_generation_multilingual(in_queue, out_queue, timeout): unittest.TestCase().assertEqual(transcript, EXPECTED_TRANSCRIPT) generated_ids = model.generate( - input_features, do_sample=False, max_length=20, language="<|en|>", task="transcribe" + input_features, + do_sample=False, + max_length=20, + generation_kwargs={"language": "<|en|>", "task": "transcribe"}, ) transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] @@ -749,7 +758,10 @@ def _test_large_generation_multilingual(in_queue, out_queue, timeout): unittest.TestCase().assertEqual(transcript, EXPECTED_TRANSCRIPT) generated_ids = model.generate( - input_features, do_sample=False, max_length=20, language="<|ja|>", task="translate" + input_features, + do_sample=False, + max_length=20, + generation_kwargs={"language": "<|ja|>", "task": "translate"}, ) transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]