Add Whisper export with beam search test cases (microsoft#17228)
### Description
This PR adds test cases for the custom export of [Whisper with beam
search](https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/python/tools/transformers/models/whisper).



### Motivation and Context
This PR checks that Whisper can be exported and runs with parity.
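For reference, the new tests drive the converter through its Python entry point rather than the command line. A minimal standalone sketch mirroring the FP32 CPU configuration from the test file (every flag below is taken from the tests; `openai/whisper-tiny` is the checkpoint they use, and torch/transformers plus the `datasets` package are assumed to be installed):

```python
# Sketch only: mirrors the fp32 CPU configuration exercised by the new tests.
from onnxruntime.transformers.models.whisper.convert_to_onnx import main as run_whisper

# main() accepts an argv-style list and now returns the max ORT/PyTorch diff.
max_diff = run_whisper(
    [
        "-m", "openai/whisper-tiny",
        "--output", "onnx_models",
        "--use_external_data_format",
        "--precision", "fp32",
        "--optimize_onnx",
    ]
)
assert max_diff == 0, f"ORT and PyTorch have a parity mismatch of {max_diff}"
```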
kunal-vaishnavi authored and kleiti committed Mar 22, 2024

1 parent 0b4d99d commit 7424120
Showing 3 changed files with 112 additions and 2 deletions.
3 changes: 2 additions & 1 deletion onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py
@@ -316,7 +316,6 @@ def export_onnx_models(
                 use_external_data_format=use_external_data_format,
                 per_channel=quantize_per_channel,
                 reduce_range=quantize_reduce_range,
-                optimize_model=False,
                 extra_options={"MatMulConstBOnly": True},
             )
         else:
@@ -374,6 +373,7 @@ def main(argv=None):
         args.provider,
     )

+    max_diff = 0
     if args.chain_model:
         logger.info("Chaining model ... :")
         args.beam_model_output_dir = WhisperHelper.get_onnx_path(
@@ -418,6 +418,7 @@ def main(argv=None):
         output_paths = [args.beam_model_output_dir]

     logger.info(f"Done! Outputs: {output_paths}")
+    return max_diff


 if __name__ == "__main__":
13 changes: 12 additions & 1 deletion onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py
@@ -12,7 +12,6 @@

 import numpy as np
 import torch
-from datasets import load_dataset
 from transformers import WhisperConfig, WhisperForConditionalGeneration, WhisperProcessor
 from whisper_decoder import WhisperDecoder, WhisperDecoderHelper, WhisperDecoderInit
 from whisper_encoder import WhisperEncoder, WhisperEncoderHelper
@@ -270,6 +269,18 @@ def verify_onnx(
         pt_model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path).to(device)
         processor = WhisperProcessor.from_pretrained(model_name_or_path)
         config = WhisperConfig.from_pretrained(model_name_or_path)
+
+        # Try to import `datasets` pip package
+        try:
+            from datasets import load_dataset
+        except Exception as e:
+            logger.error(f"An error occurred while importing `datasets`: {e}", exc_info=True)
+            install_cmd = "pip install datasets"
+            logger.warning(f"Could not import `datasets`. Attempting to install `datasets` via `{install_cmd}`.")
+            os.system(install_cmd)
+
+        from datasets import load_dataset  # noqa: F811
+
         ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
         input_features = processor([ds[0]["audio"]["array"]], return_tensors="pt").input_features
98 changes: 98 additions & 0 deletions onnxruntime/test/python/transformers/test_generation.py
@@ -6,6 +6,7 @@
 # --------------------------------------------------------------------------

 import os
+import shutil
 import unittest

 import onnx
@@ -19,10 +20,12 @@
     from benchmark_helper import Precision
     from convert_generation import main as run
     from models.t5.convert_to_onnx import export_onnx_models as export_t5_onnx_models
+    from models.whisper.convert_to_onnx import main as run_whisper
 else:
     from onnxruntime.transformers.benchmark_helper import Precision
     from onnxruntime.transformers.convert_generation import main as run
     from onnxruntime.transformers.models.t5.convert_to_onnx import export_onnx_models as export_t5_onnx_models
+    from onnxruntime.transformers.models.whisper.convert_to_onnx import main as run_whisper


 class TestBeamSearchGpt(unittest.TestCase):
@@ -281,5 +284,100 @@ def test_external_data(self):
         )


+class TestBeamSearchWhisper(unittest.TestCase):
+    """Test BeamSearch for Whisper"""
+
+    def setUp(self):
+        self.model_name = "openai/whisper-tiny"
+        self.pytorch_folder = "cache_models"
+        self.onnx_folder = "onnx_models"
+        self.decoder_onnx_path = os.path.join(".", self.onnx_folder, "whisper-tiny_decoder.onnx")
+        self.encoder_onnx_path = os.path.join(".", self.onnx_folder, "whisper-tiny_encoder_decoder_init.onnx")
+        self.beam_search_onnx_path = os.path.join(".", self.onnx_folder, "whisper-tiny_beamsearch.onnx")
+        self.enable_cuda = torch.cuda.is_available() and "CUDAExecutionProvider" in get_available_providers()
+
+        self.base_arguments = [
+            "-m",
+            self.model_name,
+            "--output",
+            self.onnx_folder,
+            "--use_external_data_format",
+        ]
+        self.fp32_cpu_arguments = [
+            "--precision",
+            "fp32",
+            "--optimize_onnx",
+        ]
+        self.fp16_cuda_arguments = [
+            "--precision",
+            "fp16",
+            "--provider",
+            "cuda",
+            "--optimize_onnx",
+            "--use_gpu",
+        ]
+        self.int8_cpu_arguments = [
+            "--precision",
+            "int8",
+            "--quantize_embedding_layer",
+        ]
+
+    def tearDown(self):
+        pytorch_dir = os.path.join(".", self.pytorch_folder)
+        if os.path.exists(pytorch_dir):
+            shutil.rmtree(pytorch_dir)
+        onnx_dir = os.path.join(".", self.onnx_folder)
+        if os.path.exists(onnx_dir):
+            shutil.rmtree(onnx_dir)
+
+    def remove_onnx_files(self):
+        if os.path.exists(self.beam_search_onnx_path):
+            os.remove(self.beam_search_onnx_path)
+            os.remove(self.beam_search_onnx_path + ".data")
+
+        if os.path.exists(self.decoder_onnx_path):
+            os.remove(self.decoder_onnx_path)
+            os.remove(self.decoder_onnx_path + ".data")
+
+        if os.path.exists(self.encoder_onnx_path):
+            os.remove(self.encoder_onnx_path)
+            os.remove(self.encoder_onnx_path + ".data")
+
+    def run_export(self, arguments):
+        max_diff = run_whisper(arguments)
+        self.assertTrue(os.path.exists(self.beam_search_onnx_path), "Whisper model was not exported")
+        self.remove_onnx_files()
+        self.assertTrue(max_diff == 0, f"ORT and PyTorch have a parity mismatch of {max_diff}")
+
+    def run_configs(self, optional_arguments):
+        # FP32 CPU
+        arguments = self.base_arguments + self.fp32_cpu_arguments + optional_arguments
+        self.run_export(arguments)
+
+        if self.enable_cuda:
+            # FP16 CUDA
+            arguments = self.base_arguments + self.fp16_cuda_arguments + optional_arguments
+            self.run_export(arguments)
+
+        # INT8 CPU
+        arguments = self.base_arguments + self.int8_cpu_arguments + optional_arguments
+        self.run_export(arguments)
+
+    @pytest.mark.slow
+    def test_required_args(self):
+        optional_args = []
+        self.run_configs(optional_args)
+
+    @pytest.mark.slow
+    def test_forced_decoder_ids(self):
+        decoder_input_ids = ["--use_forced_decoder_ids"]
+        self.run_configs(decoder_input_ids)
+
+    @pytest.mark.slow
+    def test_logits_processor(self):
+        logits_processor = ["--use_logits_processor"]
+        self.run_configs(logits_processor)
+
+
 if __name__ == "__main__":
     unittest.main()
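The new test methods are gated behind `pytest.mark.slow`. One way to run only the Whisper cases programmatically, sketched under the assumption that the test file's directory (and the transformers tool scripts it imports) are on `sys.path`:

```python
# Sketch: load and run just TestBeamSearchWhisper with plain unittest.
import unittest

from test_generation import TestBeamSearchWhisper  # assumes CWD is the test directory

suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestBeamSearchWhisper)
unittest.TextTestRunner(verbosity=2).run(suite)
```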