fix overflow when training mDeberta in fp16 #24116
Conversation
Hey! Thanks a lot for opening the PR and fixing a very old issue!
We might have problems with this under quantization, as the errors are probably going to be different. Could you check whether using torch.float16 or load_in_8bit gives the same outputs? (This might fix training but break inference, so let's just check that everything runs correctly!)
See #22444
The documentation is not available anymore as the PR was closed or merged.
I used this code block to check results.
import torch
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
from transformers.pipelines import QuestionAnsweringPipeline
tokenizer = AutoTokenizer.from_pretrained("sjrhuschlee/mdeberta-v3-base-squad2")
model = AutoModelForQuestionAnswering.from_pretrained(
"sjrhuschlee/mdeberta-v3-base-squad2",
# torch_dtype=torch.float16,
# torch_dtype=torch.bfloat16,
# load_in_8bit=True,
)
pipe = QuestionAnsweringPipeline(model, tokenizer, device=torch.device("cuda:0"))  # device=... was removed for 8bit

Running on Main Branch

Running the above code using:

# with torch.float16
pipe = QuestionAnsweringPipeline(model, tokenizer, device=torch.device("cuda:0"))
# []

# with torch.bfloat16
pipe = QuestionAnsweringPipeline(model, tokenizer, device=torch.device("cuda:0"))
# {'score': 0.98369300365448, 'start': 33, 'end': 41, 'answer': ' Berlin.'}

# with torch.float32
pipe = QuestionAnsweringPipeline(model, tokenizer, device=torch.device("cuda:0"))
# {'score': 0.9850791096687317, 'start': 33, 'end': 41, 'answer': ' Berlin.'}

# with load_in_8bit=True
pipe = QuestionAnsweringPipeline(model, tokenizer)
# {'score': 0.9868391752243042, 'start': 33, 'end': 41, 'answer': ' Berlin.'}

Running on the PR

# with torch.float16
pipe = QuestionAnsweringPipeline(model, tokenizer, device=torch.device("cuda:0"))
# {'score': 0.9848804473876953, 'start': 33, 'end': 41, 'answer': ' Berlin.'}

# with torch.bfloat16
pipe = QuestionAnsweringPipeline(model, tokenizer, device=torch.device("cuda:0"))
# {'score': 0.9841369986534119, 'start': 33, 'end': 41, 'answer': ' Berlin.'}

# with torch.float32
pipe = QuestionAnsweringPipeline(model, tokenizer, device=torch.device("cuda:0"))
# {'score': 0.9850791096687317, 'start': 33, 'end': 41, 'answer': ' Berlin.'}

# with load_in_8bit=True
pipe = QuestionAnsweringPipeline(model, tokenizer)
# {'score': 0.9870386719703674, 'start': 33, 'end': 41, 'answer': ' Berlin.'}
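Note that the snippets above only construct the pipeline; the reported dictionaries come from calling it on a question/context pair. A minimal sketch of such a call (the inputs below are made-up placeholders, not the exact ones that produced the numbers above):

result = pipe(
    question="Where does she live?",   # hypothetical example input
    context="Sarah has been living and working in Berlin.",
)
print(result)  # -> {'score': ..., 'start': ..., 'end': ..., 'answer': ...}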
Cool! Seems like a very subtle but effective change! Pinging @amyeroberts for a second pair of eyes.
I also noticed that the TF implementation of DebertaV2 has the same line (transformers/src/transformers/models/deberta_v2/modeling_tf_deberta_v2.py, lines 678 to 679 at 2e2088f).
I'm not too familiar with TF though, so I'm not sure if this change should be made there as well.
@sjrl To the best of my knowledge, we don't support training in fp16 in TF, so there's less of a risk here. I'd be pro updating TF as well, so that the implementations are aligned and it's potentially safer. cc @Rocketknight1 for his thoughts.
Yes, we support mixed-precision float16/bfloat16 training in TensorFlow, but in general we still expect a 'master' copy of the weights to remain in float32. We're planning some exploration to see if we can get Keras to accept full (b)float16 training, but it might require some refactoring!
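(For context, a minimal sketch of the Keras mixed-precision setup being described, where compute runs in float16 while trainable variables stay in float32. The checkpoint name and compile call here are only illustrative, not part of this PR.)

import tensorflow as tf
from transformers import TFAutoModelForQuestionAnswering

# Keras mixed precision: matmuls/activations run in float16 while the
# trainable variables (the "master" weights) are kept in float32.
tf.keras.mixed_precision.set_global_policy("mixed_float16")

# may need from_pt=True if the checkpoint only ships PyTorch weights
model = TFAutoModelForQuestionAnswering.from_pretrained("microsoft/mdeberta-v3-base")
model.compile(optimizer=tf.keras.optimizers.Adam(3e-5))  # transformers TF models can compute their loss internally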
LGTM
@Rocketknight1 should I go ahead and update the TF implementation as well then?
@sjrl Yes please! Better numerical stability will be nice to have once we've enabled full float16 training.
Thanks for this contribution and making our models more stable ❤️
@sjrl - Are there any other changes to add? Otherwise I think we're good to merge :)
@amyeroberts You're welcome, and that's it for the changes!
* Porting changes from https://github.com/microsoft/DeBERTa/ that hopefully allows for fp16 training of mdeberta
* Updates to deberta modeling from microsoft repo
* Performing some cleanup
* Undoing changes that weren't necessary
* Undoing float calls
* Minimally change the p2c block
* Fix error
* Minimally changing the c2p block
* Switch to torch sqrt
* Remove math
* Adding back the to calls to scale
* Undoing attention_scores change
* Removing commented out code
* Updating modeling_sew_d.py to satisfy utils/check_copies.py
* Missed changed
* Further reduce changes needed to get fp16 working
* Reverting changes to modeling_sew_d.py
* Make same change in TF
What does this PR do?
Fixes microsoft/DeBERTa#77 (issue about transformers opened in Microsoft repo)
This issue was originally raised in the https://github.com/microsoft/DeBERTa repo and concerns mDeberta not being trainable in fp16. A fix was implemented in the Microsoft repo by @BigBird01 but had not yet made it into the Hugging Face implementation. I was interested in training mDeberta models on small hardware (e.g. a 3070 or T4), so I updated the HF implementation with the changes from the Microsoft repo. I tried to bring over only the minimal changes needed to get fp16 training to work.
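As a rough illustration of the kind of change involved, based on the commit messages above rather than the literal diff in modeling_deberta_v2.py: the attention scale is built as a tensor with torch.sqrt and applied inside the matmul, so the fp16 intermediates stay small enough not to overflow.

import torch

def scaled_attention_scores(query_layer, key_layer, scale_factor):
    # Sketch only: previously the scores were computed roughly as
    #   scale = math.sqrt(query_layer.size(-1) * scale_factor)
    #   scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale
    # which can overflow in fp16 before the division is applied.
    scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
    # Dividing one operand by the scale *before* the matmul keeps the
    # intermediate products within the fp16 range.
    return torch.bmm(query_layer, key_layer.transpose(-1, -2) / scale.to(dtype=key_layer.dtype))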
I checked that the existing tests pass and also used this code to successfully train an mDeberta model in fp16 on SQuAD 2.0 (the resulting model can be found here), which is not currently possible with the main branch of transformers. I'm unsure whether there is a good way to add a CI test that makes sure mDeberta-V3 training works in fp16.
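For anyone wanting to reproduce the run, fp16 training goes through the usual fp16=True flag in TrainingArguments; the following is only a sketch, with placeholder hyperparameters and the SQuAD 2.0 preprocessing left out.

from transformers import (
    AutoModelForQuestionAnswering,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

model = AutoModelForQuestionAnswering.from_pretrained("microsoft/mdeberta-v3-base")
tokenizer = AutoTokenizer.from_pretrained("microsoft/mdeberta-v3-base")

args = TrainingArguments(
    output_dir="mdeberta-v3-base-squad2",
    fp16=True,                      # the option this PR makes usable for mDeberta
    per_device_train_batch_size=8,  # placeholder hyperparameters
    learning_rate=3e-5,
    num_train_epochs=2,
)

train_dataset = ...  # a tokenized SQuAD 2.0 dataset; the QA preprocessing is omitted for brevity
trainer = Trainer(model=model, args=args, train_dataset=train_dataset, tokenizer=tokenizer)
trainer.train()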
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
Hey, based on the recommendations from the PR template (and git blame) I decided to tag @ArthurZucker and @sgugger in case you may be interested.