Fix ORTModel MRO for whisper #919

Merged
2 changes: 2 additions & 0 deletions optimum/exporters/onnx/__main__.py
@@ -405,6 +405,8 @@ def main_export(

     # Saving the model config and preprocessor as this is needed sometimes.
     model.config.save_pretrained(output)
+    if hasattr(model, "generation_config") and model.generation_config is not None:
+        model.generation_config.save_pretrained(output)
     maybe_save_preprocessors(model_name_or_path, output)

     if task == "stable-diffusion":
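With this change, the exporter also writes `generation_config.json` into the output directory whenever the model carries a generation config. A minimal sketch of reading it back, assuming a local export directory named `whisper_onnx` (the directory name is illustrative):

```python
from transformers import GenerationConfig

# Hypothetical directory previously passed as `output` to main_export.
gen_config = GenerationConfig.from_pretrained("whisper_onnx")

# Generation defaults (e.g. max_length, or forced decoder ids for Whisper)
# now survive the export alongside config.json.
print(gen_config)
```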
59 changes: 57 additions & 2 deletions optimum/onnxruntime/modeling_seq2seq.py
@@ -18,15 +18,21 @@

 import logging
 import shutil
-from abc import ABC, abstractmethod
+from abc import ABC, ABCMeta, abstractmethod
 from pathlib import Path
 from tempfile import TemporaryDirectory
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

 import numpy as np
 import torch
 from huggingface_hub import hf_hub_download
-from transformers import AutoModelForSeq2SeqLM, AutoModelForSpeechSeq2Seq, AutoModelForVision2Seq, GenerationConfig
+from transformers import (
+    AutoModelForSeq2SeqLM,
+    AutoModelForSpeechSeq2Seq,
+    AutoModelForVision2Seq,
+    GenerationConfig,
+    WhisperForConditionalGeneration,
+)
 from transformers.file_utils import add_start_docstrings_to_model_forward
 from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
@@ -569,6 +575,8 @@ def _save_pretrained(self, save_directory: Union[str, Path]):
             dst_path.parent.mkdir(parents=True, exist_ok=True)
             shutil.copyfile(src_path, dst_path)

+        self.generation_config.save_pretrained(save_directory)
+
     @classmethod
     def _from_pretrained(
         cls,
@@ -1046,6 +1054,53 @@ def can_generate(self):
         """Returns True to validate the check that the model using `GenerationMixin.generate()` can indeed generate."""
         return True

+    @classmethod
+    def _from_pretrained(
+        cls,
+        model_id: Union[str, Path],
+        config: "PretrainedConfig",
+        **kwargs,
+    ):
+        if "WhisperForConditionalGeneration" in config.architectures:
+            return _ORTModelForWhisper._from_pretrained(model_id, config, **kwargs)
+        else:
+            return super()._from_pretrained(model_id, config, **kwargs)
+
+
+class MetaClassRemoveParentsAndReorder(ABCMeta):
+    def mro(cls):
+        """
+        Avoids inheriting from PreTrainedModel, nn.Module, ModuleUtilsMixin, PushToHubMixin,
+        and puts GenerationMixin at the end of the MRO.
+        """
+        top_inheritance_index = ORTModelForSpeechSeq2Seq.__mro__.index(GenerationMixin)
+        return (
+            (cls,)
+            + ORTModelForSpeechSeq2Seq.__mro__[:top_inheritance_index]
+            + (WhisperForConditionalGeneration,)
+            + ORTModelForSpeechSeq2Seq.__mro__[top_inheritance_index:]
+        )
+
+
+class _ORTModelForWhisper(
+    ORTModelForSpeechSeq2Seq, WhisperForConditionalGeneration, metaclass=MetaClassRemoveParentsAndReorder
+):
+    """
+    Whisper implements its own generate() method; the custom MRO ensures it is found before GenerationMixin.generate().
+    """
+
+    def generate(self, **kwargs):
+        return super().generate(**kwargs)
+
+    @classmethod
+    def _from_pretrained(
+        cls,
+        model_id: Union[str, Path],
+        config: "PretrainedConfig",
+        **kwargs,
+    ):
+        return super(ORTModelForSpeechSeq2Seq, cls)._from_pretrained(model_id, config, **kwargs)
+
+
 class ORTModelForVision2Seq(ORTModelForConditionalGeneration, GenerationMixin):
     """
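`ORTModelForSpeechSeq2Seq._from_pretrained` now routes Whisper checkpoints to the dedicated `_ORTModelForWhisper` subclass by inspecting the architecture recorded in the checkpoint's `config.json`. A small illustration of the signal it keys on (the checkpoint name is just an example):

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("openai/whisper-tiny.en")
print(config.architectures)  # ['WhisperForConditionalGeneration']

# The same membership test the override performs before delegating:
print("WhisperForConditionalGeneration" in config.architectures)  # True
```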
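The metaclass hand-builds the MRO so that `WhisperForConditionalGeneration.generate()` is found before `GenerationMixin.generate()`, while `_ORTModelForWhisper._from_pretrained` uses two-argument `super()` to restart attribute lookup past `ORTModelForSpeechSeq2Seq`, avoiding infinite recursion through the dispatch above. A self-contained sketch of both techniques with toy classes (`LoaderA`, `LoaderB`, `Combined`, and `ReorderMeta` are illustrative names, not optimum classes):

```python
from abc import ABCMeta


class LoaderA:
    @classmethod
    def load(cls):
        return f"LoaderA.load via {cls.__name__}"


class LoaderB:
    @classmethod
    def load(cls):
        return f"LoaderB.load via {cls.__name__}"


class ReorderMeta(ABCMeta):
    def mro(cls):
        # Hand-build the linearization instead of using C3: LoaderA is
        # promoted ahead of LoaderB, the reverse of the declared base order.
        return [cls, LoaderA, LoaderB, object]


class Combined(LoaderB, LoaderA, metaclass=ReorderMeta):
    @classmethod
    def load(cls):
        # Two-argument super(): lookup starts *after* LoaderA in
        # Combined.__mro__, so this resolves to LoaderB.load.
        return super(LoaderA, cls).load()


print(Combined.__mro__)  # (Combined, LoaderA, LoaderB, object)
print(Combined.load())   # LoaderB.load via Combined
```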
12 changes: 10 additions & 2 deletions tests/onnxruntime/test_modeling.py
@@ -3386,9 +3386,13 @@ def test_generate_utils(self, test_name: str, model_arch: str, use_cache: str):
         data = self._generate_random_audio_data()
         features = processor.feature_extractor(data, return_tensors="pt")

-        outputs = model.generate(inputs=features["input_features"])
+        outputs = model.generate(inputs=features["input_features"], return_timestamps=True)
         res = processor.tokenizer.batch_decode(outputs, skip_special_tokens=True)
         self.assertIsInstance(res[0], str)
+        self.assertTrue("chunks" in outputs)
+
+        outputs = model.generate(inputs=features["input_features"], return_timestamps=False)
+        self.assertTrue("chunks" not in outputs)

         gc.collect()

@@ -3453,9 +3457,13 @@ def test_pipeline_speech_recognition(self, test_name: str, model_arch: str, use_cache: str):
             feature_extractor=processor.feature_extractor,
         )
         data = self._generate_random_audio_data()
-        outputs = pipe(data)
+        outputs = pipe(data, return_timestamps=True)
         self.assertEqual(pipe.device, onnx_model.device)
         self.assertIsInstance(outputs["text"], str)
+        self.assertTrue("chunks" in outputs)
+
+        outputs = pipe(data, return_timestamps=False)
+        self.assertTrue("chunks" not in outputs)

         gc.collect()

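For reference, the behavior these tests pin down can be exercised end to end with a pipeline. A minimal sketch, assuming a Whisper checkpoint exported to ONNX on the fly and a local audio file `sample.wav` (checkpoint and filename are illustrative; the `export=True` loading kwarg depends on the installed optimum version):

```python
from transformers import AutoProcessor, pipeline
from optimum.onnxruntime import ORTModelForSpeechSeq2Seq

model_id = "openai/whisper-tiny.en"  # any Whisper checkpoint should behave the same
model = ORTModelForSpeechSeq2Seq.from_pretrained(model_id, export=True)
processor = AutoProcessor.from_pretrained(model_id)

pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
)

result = pipe("sample.wav", return_timestamps=True)
print(result["text"])             # full transcription
for chunk in result["chunks"]:    # "chunks" is present only with return_timestamps=True
    print(chunk["timestamp"], chunk["text"])
```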