From b19f7c879e910cabbb678b8b64cbc6ad75fa6315 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?F=C3=A9lix=20Marty?= <9808326+fxmarty@users.noreply.github.com> Date: Fri, 24 Mar 2023 14:16:29 +0100 Subject: [PATCH 1/6] fix whisper mro --- optimum/exporters/onnx/__main__.py | 2 + optimum/onnxruntime/modeling_seq2seq.py | 57 ++++++++++++++++++++++++- 2 files changed, 57 insertions(+), 2 deletions(-) diff --git a/optimum/exporters/onnx/__main__.py b/optimum/exporters/onnx/__main__.py index f5b856a7cf5..5340988c019 100644 --- a/optimum/exporters/onnx/__main__.py +++ b/optimum/exporters/onnx/__main__.py @@ -405,6 +405,8 @@ def main_export( # Saving the model config and preprocessor as this is needed sometimes. model.config.save_pretrained(output) + if hasattr(model, "generation_config"): + model.generation_config.save_pretrained(output) maybe_save_preprocessors(model_name_or_path, output) if task == "stable-diffusion": diff --git a/optimum/onnxruntime/modeling_seq2seq.py b/optimum/onnxruntime/modeling_seq2seq.py index 26333153aed..8b9c607b224 100644 --- a/optimum/onnxruntime/modeling_seq2seq.py +++ b/optimum/onnxruntime/modeling_seq2seq.py @@ -18,7 +18,7 @@ import logging import shutil -from abc import ABC, abstractmethod +from abc import ABC, ABCMeta, abstractmethod from pathlib import Path from tempfile import TemporaryDirectory from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union @@ -26,7 +26,13 @@ import numpy as np import torch from huggingface_hub import hf_hub_download -from transformers import AutoModelForSeq2SeqLM, AutoModelForSpeechSeq2Seq, AutoModelForVision2Seq, GenerationConfig +from transformers import ( + AutoModelForSeq2SeqLM, + AutoModelForSpeechSeq2Seq, + AutoModelForVision2Seq, + GenerationConfig, + WhisperForConditionalGeneration, +) from transformers.file_utils import add_start_docstrings_to_model_forward from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput @@ -1046,6 +1052,53 @@ def can_generate(self): """Returns True to validate the check that the model using `GenerationMixin.generate()` can indeed generate.""" return True + @classmethod + def _from_pretrained( + cls, + model_id: Union[str, Path], + config: "PretrainedConfig", + **kwargs, + ): + if "WhisperForConditionalGeneration" in config.architectures: + return _ORTModelForWhisper._from_pretrained(model_id, config, **kwargs) + else: + return super()._from_pretrained(model_id, config, **kwargs) + + +class MetaClassRemoveParentsAndReorder(ABCMeta): + def mro(cls): + """ + Avoids inheritting from PreTrainedModel, nn.Module, ModuleUtilsMixin, PushToHubMixin, + and put GenerationMixin at the end of the MRO + """ + top_inheritance_index = ORTModelForSpeechSeq2Seq.__mro__.index(GenerationMixin) + return ( + (cls,) + + ORTModelForSpeechSeq2Seq.__mro__[:top_inheritance_index] + + (WhisperForConditionalGeneration,) + + ORTModelForSpeechSeq2Seq.__mro__[top_inheritance_index:] + ) + + +class _ORTModelForWhisper( + ORTModelForSpeechSeq2Seq, WhisperForConditionalGeneration, metaclass=MetaClassRemoveParentsAndReorder +): + """ + Whisper implements its own generate() method + """ + + def generate(self, **kwargs): + return super().generate(**kwargs) + + @classmethod + def _from_pretrained( + cls, + model_id: Union[str, Path], + config: "PretrainedConfig", + **kwargs, + ): + return super(ORTModelForSpeechSeq2Seq, cls)._from_pretrained(model_id, config, **kwargs) + class ORTModelForVision2Seq(ORTModelForConditionalGeneration, GenerationMixin): """ From 2ab20fb13ddf0630c1a57f9b1da2b942afedeacb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?F=C3=A9lix=20Marty?= <9808326+fxmarty@users.noreply.github.com> Date: Fri, 24 Mar 2023 14:23:40 +0100 Subject: [PATCH 2/6] typo --- optimum/exporters/onnx/__main__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum/exporters/onnx/__main__.py b/optimum/exporters/onnx/__main__.py index 5340988c019..de1dbe48f36 100644 --- a/optimum/exporters/onnx/__main__.py +++ b/optimum/exporters/onnx/__main__.py @@ -405,7 +405,7 @@ def main_export( # Saving the model config and preprocessor as this is needed sometimes. model.config.save_pretrained(output) - if hasattr(model, "generation_config"): + if hasattr(model, "generation_config") and model.generation_config is not None: model.generation_config.save_pretrained(output) maybe_save_preprocessors(model_name_or_path, output) From 0324fa9bb2bc20ef6edec7ed59b22398b3338f5d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?F=C3=A9lix=20Marty?= <9808326+fxmarty@users.noreply.github.com> Date: Fri, 24 Mar 2023 17:07:21 +0100 Subject: [PATCH 3/6] add tests --- optimum/onnxruntime/modeling_seq2seq.py | 2 ++ tests/onnxruntime/test_modeling.py | 12 ++++++++++-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/optimum/onnxruntime/modeling_seq2seq.py b/optimum/onnxruntime/modeling_seq2seq.py index 8b9c607b224..351b5dd7e3d 100644 --- a/optimum/onnxruntime/modeling_seq2seq.py +++ b/optimum/onnxruntime/modeling_seq2seq.py @@ -575,6 +575,8 @@ def _save_pretrained(self, save_directory: Union[str, Path]): dst_path.parent.mkdir(parents=True, exist_ok=True) shutil.copyfile(src_path, dst_path) + self.generation_config.save_pretrained(save_directory) + @classmethod def _from_pretrained( cls, diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index 0c00c7e6649..4124c217717 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -3386,9 +3386,13 @@ def test_generate_utils(self, test_name: str, model_arch: str, use_cache: str): data = self._generate_random_audio_data() features = processor.feature_extractor(data, return_tensors="pt") - outputs = model.generate(inputs=features["input_features"]) + outputs = model.generate(inputs=features["input_features"], return_timestamps=True) res = processor.tokenizer.batch_decode(outputs, skip_special_tokens=True) self.assertIsInstance(res[0], str) + self.assertTrue("chunks" in outputs) + + outputs = model.generate(inputs=features["input_features"], return_timestamps=False) + self.assertTrue("chunks" in outputs) gc.collect() @@ -3453,9 +3457,13 @@ def test_pipeline_speech_recognition(self, test_name: str, model_arch: str, use_ feature_extractor=processor.feature_extractor, ) data = self._generate_random_audio_data() - outputs = pipe(data) + outputs = pipe(data, return_timestamps=True) self.assertEqual(pipe.device, onnx_model.device) self.assertIsInstance(outputs["text"], str) + self.assertTrue("chunks" in outputs) + + outputs = pipe(data, return_timestamps=False) + self.assertTrue("chunks" not in outputs) gc.collect() From 795726ab91388a5e807c6b1a50b25b31222a816c Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Fri, 24 Mar 2023 18:14:34 +0100 Subject: [PATCH 4/6] Update optimum/onnxruntime/modeling_seq2seq.py Co-authored-by: Michael Benayoun --- optimum/onnxruntime/modeling_seq2seq.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum/onnxruntime/modeling_seq2seq.py b/optimum/onnxruntime/modeling_seq2seq.py index 351b5dd7e3d..ae0359d97d4 100644 --- a/optimum/onnxruntime/modeling_seq2seq.py +++ b/optimum/onnxruntime/modeling_seq2seq.py @@ -1086,7 +1086,7 @@ class _ORTModelForWhisper( ORTModelForSpeechSeq2Seq, WhisperForConditionalGeneration, metaclass=MetaClassRemoveParentsAndReorder ): """ - Whisper implements its own generate() method + Whisper implements its own generate() method. """ def generate(self, **kwargs): From 2c29ab365d27b93974a537d3e59ad8b3b60c904a Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Fri, 24 Mar 2023 18:14:52 +0100 Subject: [PATCH 5/6] Update optimum/exporters/onnx/__main__.py Co-authored-by: Michael Benayoun --- optimum/exporters/onnx/__main__.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/optimum/exporters/onnx/__main__.py b/optimum/exporters/onnx/__main__.py index de1dbe48f36..6c837399f6d 100644 --- a/optimum/exporters/onnx/__main__.py +++ b/optimum/exporters/onnx/__main__.py @@ -405,8 +405,9 @@ def main_export( # Saving the model config and preprocessor as this is needed sometimes. model.config.save_pretrained(output) - if hasattr(model, "generation_config") and model.generation_config is not None: - model.generation_config.save_pretrained(output) + generation_config = getattr(model, "generation_config") + if generation_config is not None: + generation_config.save_pretrained(output) maybe_save_preprocessors(model_name_or_path, output) if task == "stable-diffusion": From 0eb2ba676fb18a03665c776388c2828ca505a5dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?F=C3=A9lix=20Marty?= <9808326+fxmarty@users.noreply.github.com> Date: Fri, 24 Mar 2023 18:28:12 +0100 Subject: [PATCH 6/6] fix test --- optimum/exporters/onnx/__main__.py | 2 +- optimum/onnxruntime/modeling_seq2seq.py | 3 --- tests/onnxruntime/test_modeling.py | 17 ++++++++--------- 3 files changed, 9 insertions(+), 13 deletions(-) diff --git a/optimum/exporters/onnx/__main__.py b/optimum/exporters/onnx/__main__.py index 6c837399f6d..859cbc0af03 100644 --- a/optimum/exporters/onnx/__main__.py +++ b/optimum/exporters/onnx/__main__.py @@ -405,7 +405,7 @@ def main_export( # Saving the model config and preprocessor as this is needed sometimes. model.config.save_pretrained(output) - generation_config = getattr(model, "generation_config") + generation_config = getattr(model, "generation_config", None) if generation_config is not None: generation_config.save_pretrained(output) maybe_save_preprocessors(model_name_or_path, output) diff --git a/optimum/onnxruntime/modeling_seq2seq.py b/optimum/onnxruntime/modeling_seq2seq.py index ae0359d97d4..bee6ec1c018 100644 --- a/optimum/onnxruntime/modeling_seq2seq.py +++ b/optimum/onnxruntime/modeling_seq2seq.py @@ -1089,9 +1089,6 @@ class _ORTModelForWhisper( Whisper implements its own generate() method. """ - def generate(self, **kwargs): - return super().generate(**kwargs) - @classmethod def _from_pretrained( cls, diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index 4124c217717..080b3487533 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -3386,13 +3386,9 @@ def test_generate_utils(self, test_name: str, model_arch: str, use_cache: str): data = self._generate_random_audio_data() features = processor.feature_extractor(data, return_tensors="pt") - outputs = model.generate(inputs=features["input_features"], return_timestamps=True) + outputs = model.generate(inputs=features["input_features"]) res = processor.tokenizer.batch_decode(outputs, skip_special_tokens=True) self.assertIsInstance(res[0], str) - self.assertTrue("chunks" in outputs) - - outputs = model.generate(inputs=features["input_features"], return_timestamps=False) - self.assertTrue("chunks" in outputs) gc.collect() @@ -3457,13 +3453,16 @@ def test_pipeline_speech_recognition(self, test_name: str, model_arch: str, use_ feature_extractor=processor.feature_extractor, ) data = self._generate_random_audio_data() - outputs = pipe(data, return_timestamps=True) + outputs = pipe(data) self.assertEqual(pipe.device, onnx_model.device) self.assertIsInstance(outputs["text"], str) - self.assertTrue("chunks" in outputs) - outputs = pipe(data, return_timestamps=False) - self.assertTrue("chunks" not in outputs) + if model_arch == "whisper": + outputs = pipe(data, return_timestamps=True) + self.assertTrue("chunks" in outputs) + + outputs = pipe(data, return_timestamps=False) + self.assertTrue("chunks" not in outputs) gc.collect()