
Fix llava half precision and autocast issues #29721

Merged

Conversation

@frasermince (Contributor) commented Mar 19, 2024:

What does this PR do?

Currently, half-precision training on LLaVA is broken because the output of the language embedding is always full precision. This PR matches the behavior of the original LLaVA implementation by casting the result of the language embedding to the image_features dtype. See here for the relevant line in the original repo.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@amyeroberts (Collaborator) left a comment:

Thanks for adding this!

Could you add a test e.g. like create_and_check_reformer_model_fp16_forward that we have for other models?

@frasermince (Contributor Author) replied:

Apologies for the delay on this. After implementing the test, it looks like the problem is not quite what I thought it was. I will follow up soon with either a fix for the actual problem or a decision to close this PR.

@frasermince frasermince force-pushed the frasermince/half-precision-llava-fix branch 2 times, most recently from 1f8208b to b4188e3 Compare April 9, 2024 15:06
@frasermince frasermince marked this pull request as draft April 9, 2024 18:05
@frasermince (Contributor Author) commented:

After some further investigation, I tracked down exactly what was happening. .half() or .to(torch.bfloat16) work fine on the model. However, the issue arises when the model is run within a torch.autocast context. This also applies when using the Trainer with the bf16 or fp16 argument set to True, since accelerate uses an autocast context behind the scenes. When run in an autocast bf16 or fp16 block, the language model embedding output is full precision while the image_features are half precision. Because the two are merged, a dtype mismatch error occurs. To fix this we need to cast either the embeddings to half precision or the image features to full precision. Casting the embeddings to half precision made the most sense to me; let me know if you would prefer the other direction, but this seemed the most efficient.
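A minimal sketch of the failure mode described above (hypothetical stand-in modules, not the actual transformers code; CPU autocast with bfloat16 is used for portability, and CUDA autocast with fp16 behaves analogously):

```python
import torch
import torch.nn as nn

# Under autocast, nn.Embedding keeps its weight dtype (fp32), while
# autocast-eligible ops such as nn.Linear return the autocast dtype.
embed = nn.Embedding(10, 8)   # stand-in for the language embedding
proj = nn.Linear(8, 8)        # stand-in for the vision projection

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    text_embeds = embed(torch.arange(4))       # stays torch.float32
    image_features = proj(torch.randn(4, 8))   # bfloat16 under autocast
    # In the model, merging these mismatched tensors raised a dtype error;
    # the fix casts the embedding output to the image features' dtype first:
    merged = torch.cat([text_embeds.to(image_features.dtype), image_features])
```

Outside of autocast (e.g. after plain `.half()`), both tensors already share a dtype, which is why the bug only appeared in autocast contexts.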

I also fixed the same issue in llava_next. In the process I discovered that NaNs were occurring when running llava_next in half precision. To fix this I used randn instead of empty, similar to the original LLaVA initialization.

@frasermince frasermince changed the title Match original llava implentation by casting to pixel_values dtype. Fix llava half precision and autocast issues Apr 9, 2024

    - self.image_newline = nn.Parameter(torch.empty(config.text_config.hidden_size, dtype=self.dtype))
    + self.image_newline = nn.Parameter(
    +     torch.randn(config.text_config.hidden_size, dtype=config.torch_dtype or self.dtype)
    + )
@frasermince (Contributor Author) commented Apr 9, 2024:

We needed some way to set this to half precision in autocast environments. Let me know if this is incorrect. I am still unclear on when to use torch_dtype and when not to.

@frasermince (Contributor Author) commented:

Also, I should note I am not quite certain randn is correct here, but torch.empty was causing NaNs.
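A small illustration of why torch.empty is risky here: it returns uninitialized memory, so unless the parameter is filled in later (e.g. in _init_weights) it can contain arbitrary, even non-finite, values, whereas torch.randn always draws finite samples from a standard normal.

```python
import torch

uninitialized = torch.empty(1024)  # contents are undefined until written
initialized = torch.randn(1024)    # always finite, drawn from N(0, 1)

print(torch.isfinite(initialized).all().item())  # True
```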

A collaborator commented:

Just from looking at this, I suspect the issue is that the param self.image_newline is being created as an empty tensor, but it's not then being initialized in _init_weights.

@frasermince (Contributor Author) replied:

Right, I will move this initialization to _init_weights.

@frasermince (Contributor Author) commented:

At a second glance, it looks sufficient to do it at initialization. The original repo just had a separate function to initialize the vision tower.

@frasermince frasermince requested a review from amyeroberts April 9, 2024 19:00
@frasermince frasermince marked this pull request as ready for review April 9, 2024 19:00
    model = LlavaForConditionalGeneration(config=config)
    model.to(torch_device)
    model.eval()
    with torch.autocast(device_type="cuda", dtype=torch.float16):
@frasermince (Contributor Author) commented:

I had some concerns about putting this here, given I haven't seen autocasting tested in unit tests elsewhere in the repo. Let me know if you would prefer this be tested through an integration test with accelerate.

A collaborator replied:

Yes, your intuition is correct :) We shouldn't need to do this autocasting here. Without it, I'm assuming it fails?

Could you specify what you're thinking re an accelerate integration test and the need for it? Other fp16 tests don't specify a specific accelerate version, so I'm not sure what it would be addressing here.

@frasermince (Contributor Author) replied:

It's kind of the opposite: it works fine with .half() and no autocast. The issue is that when using the Trainer with the fp16 or bf16 flag, the model raises a type error. The Trainer uses autocasting behind the scenes through accelerate, so I wrote these tests as the simplest case that captures the failure. I was not quite sure how to handle this here, given we are not testing autocast behavior elsewhere.

A collaborator replied:

OK, just to make sure I've understood, the issue arises when passing the model to Trainer? Does the following test pass?

    def create_and_check_llava_model_fp16_forward(self, config, input_ids, pixel_values, attention_mask):
        model = LlavaForConditionalGeneration(config=config).to(torch_device).half().eval()
        output = model(input_ids, attention_mask=attention_mask)["last_hidden_state"]
        self.parent.assertFalse(torch.isnan(output).any().item())

@frasermince (Contributor Author) replied:

Correct, the issue arises when passing the model to the Trainer with the fp16 or bf16 flags. The test you included would pass just fine. As far as I can tell, this is due to how the model behaves when autocast is used within the Trainer (indirectly, through accelerate). Using the autocast block in this test, I was able to replicate the same bug I see when running the model through the Trainer.

A collaborator replied:

Ah, OK. I think we should do this in two steps. Adding an integration test for Trainer & autocast for models sounds like a good idea. I suspect it might throw up quite a few things to address, and the design of the test is important to keep it as lightweight as possible. By splitting up, we can add this fix quickly and then iterate on the test design / fix the errors it throws.

What I would suggest is:

  • Keep the change to llava next and the create_and_check_llava_next_model_fp16_forward test in this PR
  • Open a new PR for adding an integration test. This would possibly sit under tests/trainer/test_trainer.py - @muellerzr will be able to advise here re where and design :)

A contributor replied:

Fully agree with this way forward @frasermince!

A contributor added:

Let's also document the test with some comments explaining why it's needed

@frasermince (Contributor Author) replied:

Sounds like a plan! I'll get to work on that. Would we want this new integration test still namespaced under the model, something like tests/llava/trainer/test_trainer.py, or are you suggesting we add testing more generally to the trainer?

A contributor replied:

Maybe test/modeling/llava/test_trainer_llava.py, since for now it's model-specific?

@amyeroberts (Collaborator) left a comment:

Thanks for iterating on this!

A collaborator commented on the image_newline change:

This should just be taken from the instance.

Suggested change:

    - torch.randn(config.text_config.hidden_size, dtype=config.torch_dtype or self.dtype)
    + torch.randn(config.text_config.hidden_size, dtype=self.dtype)

@frasermince frasermince force-pushed the frasermince/half-precision-llava-fix branch from 24b4244 to 0db9a95 Compare April 15, 2024 20:02
@muellerzr (Contributor) commented Apr 16, 2024:

I'll need to dig into this when I'm back at my computer, but @amyeroberts, to quickly describe how the autocast system works in accelerate:

We wrap the model's .forward() with an autocast manager so that the original weights remain in fp32. Not doing so leads to significant precision and final-accuracy issues, which is why we maintain this reference. It uses a pinch more memory (compared to just doing model.half()) but for good reason. We then convert the output back to fp32 at the end.

If certain weights need to be cast for certain models, I think that can be okay, handled on a case-by-case basis, but I want to re-read this discussion carefully to understand why.

(And we should test training for a few epochs/steps to make sure accuracies etc remain as expected when using pretrained weights)
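A simplified sketch of the scheme described above (not accelerate's actual implementation): the weights stay fp32, forward runs under autocast, and the output is converted back to fp32 at the end. CPU + bfloat16 is used here so the sketch runs without a GPU.

```python
import torch
import torch.nn as nn

def wrap_forward_with_autocast(model, device_type="cpu", dtype=torch.bfloat16):
    """Run model.forward under autocast, keeping the fp32 master weights."""
    original_forward = model.forward

    def forward(*args, **kwargs):
        with torch.autocast(device_type=device_type, dtype=dtype):
            output = original_forward(*args, **kwargs)
        return output.float()  # convert the output back to fp32

    model.forward = forward
    return model

model = wrap_forward_with_autocast(nn.Linear(4, 2))
out = model(torch.randn(3, 4))
# The parameters remain fp32 even though the forward pass ran in bf16:
print(model.weight.dtype, out.dtype)
```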

@frasermince (Contributor Author) replied:

Sounds good! Let me know if I can do anything to make the reason for these changes more clear!


@muellerzr (Contributor) left a comment:

Overall I think these tests are good as they are; let's keep a close eye on them and see which other models have issues. We may want to change some things in Accelerate (like keeping certain layers fully in a given precision) if tests show no loss in performance :)

@frasermince (Contributor Author) replied:

Sounds good! I will also add the trainer test in a separate PR as requested.

@frasermince (Contributor Author) commented:

@muellerzr, before merging, do you have any suggestions on ensuring accuracies remain as expected? Is it basically taking the assertions in some of the integration tests after a couple of training steps?

@muellerzr (Contributor) commented Apr 24, 2024:
Yes indeed! (An assertion or torch.allclose, more than likely.)
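A sketch of the suggested assertion style: after a few training steps, compare observed metrics against stored reference values within a tolerance. The numbers below are placeholders, not real LLaVA training losses.

```python
import torch

observed_losses = torch.tensor([2.31, 2.18, 2.06])  # losses from a short run
expected_losses = torch.tensor([2.30, 2.19, 2.05])  # stored reference values

# Passes when every element is within atol (+ rtol * expected) of its reference.
assert torch.allclose(observed_losses, expected_losses, atol=5e-2)
```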

@frasermince frasermince force-pushed the frasermince/half-precision-llava-fix branch 2 times, most recently from 8e9b5b7 to 003c89d Compare April 25, 2024 02:44
@amyeroberts (Collaborator) left a comment:

Thanks for fixing this and iterating with us to find a solution!

Only last thing is to run the slow model tests:

    RUN_SLOW=1 pytest tests/models/llava tests/models/llava_next

Once they're all confirmed to be passing we're good to merge!

@frasermince (Contributor Author) replied:
There was one test that seemed to be failing already before my changes; I took the liberty of changing the assertion slightly. Let me know if this is acceptable! A couple of the llava_next tests are also failing for me on main, but fixing those would require changing the expected logits. Not sure if that is acceptable or not.

@muellerzr (Contributor) commented:

@amyeroberts I'll leave it to your discretion :)

@frasermince (Contributor Author) commented:

Given the slow test was failing on main, I changed the logits in the test for now. If that is wrong I can remove it.

@amyeroberts (Collaborator) commented:

@frasermince Great - thanks for checking! I'm going to be a bit annoying here: as the differences are so small, it's quite possible they're just coming from running on different hardware / slightly different env setups, so we might need to revert.

@ydshieh How can we run the slow tests on this branch in the new github actions workflow to make sure the values are aligned?

@ydshieh (Collaborator) commented Apr 30, 2024:

@amyeroberts I am on it and it should be ready soon (possibly even today 🔥)

@ydshieh (Collaborator) commented Apr 30, 2024:

It is in #30540 :-)

@amyeroberts (Collaborator) commented:

@frasermince Thanks for your patience!

We've just merged #30540 (thanks @ydshieh!), so we should be able to run the slow tests on our runners. To trigger this, could you:

  • Rebase on main to include the recent commit
  • Push an empty commit with the message [run-slow]llava,llava_next

@frasermince frasermince force-pushed the frasermince/half-precision-llava-fix branch from 4839ccf to f0ca873 Compare April 30, 2024 21:00
@amyeroberts (Collaborator) commented:

Thanks for triggering the tests! OK, so we're hitting OOM errors, which are going to be a pain to resolve as they're a wider issue. What I suggest is:

  • We resolve the conflicts
  • We merge this in
  • If there are issues with the slow tests which aren't OOM, I'll handle them in a follow-up PR (cc @ydshieh - I'll be the person to ping)

@frasermince frasermince force-pushed the frasermince/half-precision-llava-fix branch 2 times, most recently from 3ef96e4 to 75b4a76 Compare May 1, 2024 15:46
@frasermince frasermince force-pushed the frasermince/half-precision-llava-fix branch from 75b4a76 to 750244c Compare May 1, 2024 15:49
@amyeroberts (Collaborator) commented:

@frasermince Thanks for all of your patience iterating with us. We're good to merge!

@amyeroberts amyeroberts merged commit 5090ea3 into huggingface:main May 1, 2024
20 checks passed