
Support Wav2vec2 for Transformers optimizer (fusion) #10622

Open · wants to merge 3 commits into base: main

Conversation

philschmid

What does this PR do?

This PR adds an onnx_model_wav2vec2.py file to enable fusion optimization support for Wav2Vec2 models from Hugging Face Transformers. It also updates optimizer.py to register the new model_type.

I made the changes based on the latest PR for turing:

To enable support for Wav2Vec2, I copied onnx_model_bart and made the required changes. I wasn't sure if that's the right way to add support for a new model; please let me know if we should do it differently. onnx_model_wav2vec2 and onnx_model_bart are pretty similar except for some checks in the EncoderAttention fusion. See below the diffs from onnx_model_bart. [Line 28-...]

reshape_qkv_2_path_1 = self.model.match_parent_path(concat_qkv_2, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0])
reshape_qkv_2_path_2 = self.model.match_parent_path(concat_qkv_2, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0])
- reshape_qkv_2_path_3 = self.model.match_parent_path(concat_qkv_2, ["Unsqueeze", "Gather", "Shape"], [2, 0, 0])
- if reshape_qkv_2_path_1 is None or reshape_qkv_2_path_2 is None or reshape_qkv_2_path_3 is None:
+ if reshape_qkv_2_path_1 is None or reshape_qkv_2_path_2 is None:
    return False

_, gather_1, shape_1 = reshape_qkv_2_path_1
_, gather_2, shape_2 = reshape_qkv_2_path_2
- _, _, shape_3 = reshape_qkv_2_path_3

- if shape_1.input[0] != root_input or shape_2.input[0] != root_input or shape_3.input[0] != root_input:
+ if shape_1.input[0] != root_input or shape_2.input[0] != root_input:
    return False
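For readers unfamiliar with match_parent_path: it walks up the graph from a node through a chain of expected op types, following a given input index at each hop, and the diff above drops the third Unsqueeze → Gather → Shape branch that BART's Concat has but Wav2Vec2's lacks. Here is a minimal standalone sketch of that matching idea; the Node class and helper below are illustrative stand-ins, not the real onnxruntime OnnxModel API:

```python
# Simplified sketch of parent-path pattern matching (illustrative only,
# not the actual onnxruntime.transformers OnnxModel implementation).
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class Node:
    op_type: str
    inputs: list = field(default_factory=list)  # producer Nodes, by input index

def match_parent_path(node: Node, op_types: list, input_indices: list) -> Optional[list]:
    """Walk upward from `node`: at each step, follow the producer of the
    given input index and require it to have the expected op type.
    Returns the list of matched parent nodes, or None if any step fails."""
    path = []
    current = node
    for op_type, idx in zip(op_types, input_indices):
        if idx >= len(current.inputs) or current.inputs[idx] is None:
            return None
        parent = current.inputs[idx]
        if parent.op_type != op_type:
            return None
        path.append(parent)
        current = parent
    return path

# Tiny chain resembling the pattern in the diff: Shape -> Gather -> Unsqueeze -> Concat,
# where the Concat has only two Unsqueeze inputs (as in Wav2Vec2, not three as in BART).
shape = Node("Shape")
gather = Node("Gather", [shape])
unsqueeze = Node("Unsqueeze", [gather])
concat = Node("Concat", [unsqueeze, unsqueeze])

print(match_parent_path(concat, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0]) is not None)  # True
# The third branch (input index 2) does not exist, so path 3 fails to match:
print(match_parent_path(concat, ["Unsqueeze", "Gather", "Shape"], [2, 0, 0]))  # None
```

This is why the Wav2Vec2 variant of the fusion only checks the first two paths: requiring the third would make the match fail on every Wav2Vec2 graph.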

Here is how I tested it:

from pathlib import Path

import onnx
from onnxruntime.transformers.onnx_model_wav2vec2 import Wav2vec2OnnxModel


num_attention_heads = 12
hidden_size = 768

export_directory = "exports"
model_name = "wav2vec2"
model_path = Path("speech/exports/wav2vec2-base-960h.onnx")


def main():

    opt_output_path = Path(export_directory).joinpath(f"{model_name}_opt.onnx")

    onnx_model = onnx.load_model(model_path.as_posix())

    optimized_model = Wav2vec2OnnxModel(
        model=onnx_model,
        num_heads=num_attention_heads,
        hidden_size=hidden_size,
    )
    optimized_model.optimize()
    print(optimized_model.get_fused_operator_statistics())

    # save model
    optimized_model.save_model_to_file(opt_output_path.as_posix())
    print(f"optimized model saved at: {opt_output_path.absolute()}")


if __name__ == "__main__":
    main()

@ghost

ghost commented Feb 22, 2022

CLA assistant check
All CLA requirements met.

@wangyems
Contributor

wangyems commented Mar 2, 2022

@philschmid Thanks for your contribution! Great job!
As you mentioned, there are only a few lines changed compared to the BART attention fusion, so it might be better to modify onnx_model_bart.py directly and add "wav2vec2": (BartOnnxModel, "pytorch", 1) in optimizer.py. In general, we add new classes when the fusion logic changes a lot.
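The registration wangyems describes amounts to one extra entry in the model-type dispatch table in optimizer.py, which maps a model_type string to a tuple of (optimizer class, framework, default optimization level). A hedged sketch of that dispatch, using placeholder classes rather than the real onnxruntime imports:

```python
# Illustrative sketch of the model-type dispatch table in
# onnxruntime/python/tools/transformers/optimizer.py.
# The classes here are stand-in placeholders, not the real imports;
# the tuple shape (class, framework, default opt level) follows the
# "wav2vec2": (BartOnnxModel, "pytorch", 1) entry suggested in this review.
class BertOnnxModel: ...
class BartOnnxModel: ...

MODEL_TYPES = {
    "bert": (BertOnnxModel, "pytorch", 1),
    "bart": (BartOnnxModel, "pytorch", 1),
    # Reuse the BART optimizer for Wav2Vec2, since the fusion logic is shared:
    "wav2vec2": (BartOnnxModel, "pytorch", 1),
}

def get_optimizer_class(model_type: str):
    """Look up the optimizer class registered for a model type."""
    optimizer_class, framework, default_opt_level = MODEL_TYPES[model_type]
    return optimizer_class

print(get_optimizer_class("wav2vec2") is BartOnnxModel)  # True
```

With an entry like this, callers can pass model_type="wav2vec2" to the optimizer entry point and get the shared BART fusion logic without a new class.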

@philschmid
Author

Thanks for the response @wangyems. So if I understand you correctly, you suggest creating a Wav2vec2OnnxModel and FusionWav2vec2EncoderAttention in onnx_model_bart.py? Or would you add some conditional checks in the BART attention class?

@wangyems
Contributor

wangyems commented Mar 3, 2022

Thanks for the response @wangyems. So if i understand you correctly, you suggest creating a Wav2vec2OnnxModel and FusionWav2vec2EncoderAttention in the onnx_model_bart.py? or would you add some complex checks in the bart attention class?

I suggest handling these small differences you mentioned in onnx_model_bart so that the class in onnx_model_bart supports both BART and Wav2Vec2.

@philschmid
Author

@wangyems I added it to onnx_model_bart.py.

wangyems
wangyems previously approved these changes Mar 9, 2022
@philschmid
Author

What do we need to do to get the CI running so this can be merged?

@microsoft microsoft deleted a comment from azure-pipelines bot Mar 10, 2022
@wangyems
Contributor

/azp run Linux CPU CI Pipeline, Linux CPU Minimal Build E2E CI Pipeline, Linux GPU CI Pipeline, Linux GPU TensorRT CI Pipeline, Linux Nuphar CI Pipeline, Linux OpenVINO CI Pipeline, MacOS CI Pipeline, ONNX Runtime Web CI Pipeline, Windows CPU CI Pipeline, Windows GPU CI Pipeline

@azure-pipelines

Azure Pipelines successfully started running 10 pipeline(s).

@wangyems
Contributor

/azp run Windows GPU TensorRT CI Pipeline, onnxruntime-binary-size-checks-ci-pipeline, onnxruntime-python-checks-ci-pipeline, orttraining-linux-ci-pipeline, orttraining-linux-gpu-ci-pipeline, orttraining-ortmodule-distributed

@azure-pipelines

Azure Pipelines successfully started running 6 pipeline(s).

@chausner
Contributor

@philschmid
Author

@philschmid Thanks a lot for this contribution!

https://github.com/microsoft/onnxruntime/blob/8255ecbfb4511d5d998a9edac814dfbadf3bb13f/onnxruntime/python/tools/transformers/README.md should probably be updated before merging.

I added Wav2Vec2 to the list of supported models.

@chausner-audeering
Contributor

chausner-audeering commented Mar 24, 2022

I can confirm attention fusion works for e.g. wav2vec2-base-960h. I did some performance measurements on CPU and did not see any notable performance difference between the unoptimized and optimized model variants. I do see a much smaller node graph in Netron when comparing the two models, though.

Is the purpose of attention fusion purely to reduce the complexity of the graph, or should there be a measurable performance improvement during inference?

@chausner-audeering
Contributor

I can confirm attention fusion works for e.g. wav2vec2-base-960h.

Attention fusion fails when enabling extended optimizations, though (--opt_level 2):

is_fully_optimized: Attention not fused

Not sure if that is to be expected.

@JeffreyWardman

Hi all, is there any update on this?

5 participants