Add MBart support for BetterTransformer #516

Merged: 10 commits into huggingface:main on Nov 30, 2022

Conversation

@ravenouse (Contributor) commented Nov 26, 2022

This PR adds the MBartEncoderLayerBetterTransformer class to support MBart.

During testing, the code is runnable but it yields different results compared with the original transformers model, as shown in the pictures below. The specific models I tested are facebook/mbart-large-50 and facebook/mbart-large-cc25. The downstream task I tested is mask filling.

[Screenshots (2022-11-26): fill-mask outputs from the original and BetterTransformer-converted pipelines, showing differing predictions]
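
(For context, here is a minimal sketch of the kind of comparison described above. It is not the PR's actual test script: the checkpoint handling and top-k decoding follow the standard transformers mask-filling recipe, and the conversion goes through optimum's BetterTransformer.transform entry point.)

```python
# Minimal sketch (not this PR's test script): compare mask-filling predictions
# from the original MBart model and its BetterTransformer-converted copy.
import torch
from transformers import MBartForConditionalGeneration, MBartTokenizer
from optimum.bettertransformer import BetterTransformer

name = "facebook/mbart-large-cc25"
tokenizer = MBartTokenizer.from_pretrained(name)
model = MBartForConditionalGeneration.from_pretrained(name).eval()
# keep_original_model=True returns a converted copy and leaves `model` usable
bt_model = BetterTransformer.transform(model, keep_original_model=True)

input_ids = tokenizer("UN Chief Says There Is No <mask> in Syria", return_tensors="pt").input_ids
masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()

with torch.no_grad():
    for label, m in [("original", model), ("bettertransformer", bt_model)]:
        logits = m(input_ids).logits  # MBart derives decoder inputs from input_ids
        probs = logits[0, masked_index].softmax(dim=-1)
        _, top5 = probs.topk(5)
        print(label, tokenizer.decode(top5).split())  # top-5 candidates should match
```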

Could you tell me what I can do to solve this problem and what other tests I need to run?

Thank you so much!

@younesbelkada (Contributor) left a comment


Hi @ravenouse !
Thanks a lot for your PR! Glad that the conversion worked! We're almost there!
It seems that MBart uses a pre-attention layer norm; could you try setting the attribute to True (as shown in the suggestion below)?
Let me know how it goes!

Suggestion on optimum/bettertransformer/models/encoder_models.py (outdated, resolved)
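
(Side note for context: BERT-style encoders apply LayerNorm after the attention block, i.e. post-LN, while MBart applies it before, i.e. pre-LN. In optimum's encoder layer classes this is signalled by a boolean attribute, norm_first if I recall correctly, which mirrors the flag of the same name on PyTorch's fused encoder layer. A small illustration of that flag, not code from this PR:)

```python
# Illustration only: PyTorch's fused encoder layer exposes the pre-/post-LayerNorm
# choice via `norm_first`. MBart's encoder is pre-LN, so the converted layer must
# indicate norm_first=True. Dimensions below roughly follow MBart-large.
import torch
import torch.nn as nn

layer = nn.TransformerEncoderLayer(
    d_model=1024,
    nhead=16,
    dim_feedforward=4096,
    activation="gelu",
    batch_first=True,
    norm_first=True,  # LayerNorm before attention, as in MBart
).eval()

with torch.no_grad():
    out = layer(torch.randn(2, 8, 1024))
print(out.shape)  # torch.Size([2, 8, 1024])
```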
@ravenouse (Contributor, Author)

Hi Younes, thank you so much for the advice! It worked! This time the BetterTransformer pipeline yields exactly the same results as the original transformers pipeline!

[Screenshot (2022-11-26): matching fill-mask outputs from both pipelines]

Please let me know what else needs to be done!

Thanks again!

@younesbelkada (Contributor) commented Nov 27, 2022

Very glad it worked @ravenouse ! Can't wait to see this PR merged!
Let me guide you step by step regarding the immediate next steps:

Step 1: Finish the integration for MBart

  • Since the refactoring of the tests for encoder-decoder based models was recently addressed, adding the tests for MBart should be very easy. Just add a line with "hf-internal-testing/tiny-random-MBartModel" to the list ALL_ENCODER_DECODER_MODELS_TESTS here, then push the change to see if the tests pass (they should if everything works as expected), or run the command pytest tests/bettertransformer/test_bettertransformer_encoder.py::BetterTransformersEncoderDecoderTest locally. See the sketch right after this list.
  • Add a line stating that MBart is supported by BetterTransformer in the documentation. The addition should go here (make sure to respect the alphabetical order).
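
(A hedged sketch of what the Step 1 test change might look like; the surrounding list entries are placeholders, not the file's actual contents:)

```python
# Hypothetical excerpt of tests/bettertransformer/test_bettertransformer_encoder.py.
# Only the MBart entry is prescribed here; the other entries are placeholders.
ALL_ENCODER_DECODER_MODELS_TESTS = [
    # ... existing tiny encoder-decoder checkpoints ...
    "hf-internal-testing/tiny-random-MBartModel",  # new entry for MBart
]

# Run the targeted tests locally with:
#   pytest tests/bettertransformer/test_bettertransformer_encoder.py::BetterTransformersEncoderDecoderTest
```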

Step 2: Add the integration for NLLB

  • NLLB is a recent encoder-decoder model from Meta AI that is integrated in transformers. Since NLLB uses the same encoder layer as MBart (see this line), the addition is very easy: just add the line "M2M100EncoderLayer": MBartEncoderLayerBetterTransformer here. See the registration sketch after this list.
  • Once the model is registered, follow the same procedure as Step 1; you can use the model hf-internal-testing/tiny-random-nllb for the integration tests ;)

Let me know if anything is unclear! Again, thanks a lot for your great contribution @ravenouse 💪
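
(A hedged sketch of the Step 2 registration; the mapping-dict name below is illustrative, and the import path is inferred from the encoder_models.py file referenced earlier in this review:)

```python
# Sketch of the layer-mapping addition (dict name illustrative). NLLB reuses
# transformers' M2M100EncoderLayer, which matches MBart's encoder layer, so both
# can point to the same BetterTransformer class.
from optimum.bettertransformer.models.encoder_models import MBartEncoderLayerBetterTransformer

BETTER_TRANSFORMER_LAYERS_MAPPING_DICT = {
    # ... existing entries ...
    "MBartEncoderLayer": MBartEncoderLayerBetterTransformer,
    "M2M100EncoderLayer": MBartEncoderLayerBetterTransformer,  # covers NLLB
}
```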

@ravenouse (Contributor, Author)
Hi Younes, thank you so much for your detailed instructions! I have made all the changes you mentioned above and have run pytest twice locally.

[Screenshot (2022-11-28): local pytest run results]

Please let me know what else needs to be done.

@younesbelkada (Contributor) left a comment


Thanks so much @ravenouse !
I left a small comment: could you please update ALL_ENCODER_DECODER_MODELS instead, so that the tests will run on the new models? Let me know how it goes!

Comment on lines 47 to 48
"hf-internal-testing/tiny-random-MBartModel",
"hf-internal-testing/tiny-random-nllb",
@younesbelkada (Contributor):

I think you have to move them to ALL_ENCODER_DECODER_MODELS: the test pytest tests/bettertransformer/test_bettertransformer_encoder.py::BetterTransformersEncoderDecoderTest will only run on the models listed in ALL_ENCODER_DECODER_MODELS ;)
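
(In other words, a hedged sketch of the fix; the other entries are placeholders:)

```python
# The encoder-decoder test class only iterates over ALL_ENCODER_DECODER_MODELS,
# so the two new checkpoints belong in that list (other entries are placeholders).
ALL_ENCODER_DECODER_MODELS = [
    # ... existing encoder-decoder checkpoints ...
    "hf-internal-testing/tiny-random-MBartModel",
    "hf-internal-testing/tiny-random-nllb",
]
```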

@ravenouse (Contributor, Author):

Hi @younesbelkada. Thank you so much for the explanations! Now I have a better understanding of what's going on in the test files. I have moved the two test models to the right list, ran pytest again, and it passed.
[Screenshot (2022-11-29): passing pytest run]
Please let me know what else I can do!

@HuggingFaceDocBuilderDev commented Nov 30, 2022

The documentation is not available anymore as the PR was closed or merged.

@younesbelkada (Contributor)
Thanks a lot!
Could you just run the styling (run "make style") and push the changes, and we should be good to merge! 💪🏻

@ravenouse (Contributor, Author)
Hi @younesbelkada. I have run the make style command and reformatted the code. Please let me know what you think about it!

@younesbelkada (Contributor) left a comment


Looking great!
Thanks a lot for your clean implementation of M2M100 support for BetterTransformer 💪
Looking forward to your next contributions ;)

@michaelbenayoun (Member) left a comment


LGTM!

@younesbelkada younesbelkada merged commit 31eb67c into huggingface:main Nov 30, 2022
@ravenouse (Contributor, Author)
Hi @younesbelkada and @michaelbenayoun !
Thank you so much for all the help provided! This is my first PR, and it has been an amazing and extremely beneficial experience for me!
I am wondering if I can choose another model to work on, namely the ASTLayer. I think progress will be smoother.

@younesbelkada (Contributor)
Thanks so much @ravenouse !
Sure yes, just make sure to notify other contributors in huggingface/transformers#20372 too ;)
