-
Notifications
You must be signed in to change notification settings - Fork 27.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Blip2 model in VQA pipeline #25532
Add Blip2 model in VQA pipeline #25532
Conversation
cc @amyeroberts and @younesbelkada |
Hi @amyeroberts and @younesbelkada, this PR is ready for review. Could you please take a look? Thanks :) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looking great to me, I left one comment to make sure the slow test will not blow up GPU memory in our daily CI runners!
Let's also wait for amy and @Narsil 's review before merging
@slow | ||
@require_torch | ||
def test_large_model_pt_blip2(self): | ||
vqa_pipeline = pipeline("visual-question-answering", model="Salesforce/blip2-opt-2.7b") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
vqa_pipeline = pipeline("visual-question-answering", model="Salesforce/blip2-opt-2.7b") | |
vqa_pipeline = pipeline("visual-question-answering", model="Salesforce/blip2-opt-2.7b", model_kwargs={"torch_dtype": torch.float16}, device=0) |
Make also sure to cast the input in torch.float16
inside _forward
if needed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @younesbelkada , thanks for your feedback :)
model_kwargs were updated in test_large_model_pt_blip2, as recommended by you, but i am not sure how to check if casting to torch.float16
inside _forward
is needed, could you please give me some hints about what should I check? Thanks
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
I'm not a big fan of using the names of the configs directly to detect if generative or not, I feel like using the ForXXX
should be a better hint.
We also used model.can_generate()
as a hint in other pipelines.
Pinging @ylacombe who used that flag. (Just FYI no need to do anything).
f4848e1
to
1754398
Compare
@Narsil, Thanks a lot for your feedback :) At the moment model.can_generate() return False for Blip2ForConditionalGeneration, that is the reason why I was following the proposal of this non merged PR https://github.com/huggingface/transformers/pull/23348/files#diff-620bada7977c3d0040ed961581379598e53a9ef02fdbb26c570cac738c279c0eR64 Maybe could it be expected that can_generate method returns True for Blip2ForConditionalGeneration? if this is the case, we could use it. (i will take a look on it) can_generate returns True for another model, does it make sense to do this in Blip2ForConditionalGeneration, or it could affect something else? transformers/src/transformers/models/speecht5/modeling_speecht5.py Lines 2782 to 2787 in 50573c6
|
9cce838
to
ff5db0c
Compare
I'm not the best person to comment on how The main thing about pipeline:
But the current code is acceptable. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for adding this!
Before approving let's get @gante's opinion on the best/canonical way to detect if the model should generate the answer or not within the pipeline
self.assertEqual( | ||
outputs, | ||
[{"answer": ANY(str)}], | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: can go on one line
self.assertEqual( | |
outputs, | |
[{"answer": ANY(str)}], | |
) | |
self.assertEqual(outputs, [{"answer": ANY(str)}]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Regarding detection of generative models: can_generate()
was built precisely to confirm whether the model can safely call generate()
(in theory, all models can do it due to the inheritance structure, in practice only a few can use it). This includes pipelines uses 👍
However, I don't think we should overload the function in the class -- see my comment below, going to open a PR with a more general solution :)
Generalizable solution here ☝️ |
ff5db0c
to
d309aa4
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very clean and minimal PR, I like it!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for adding this functionality!
Feel free to tweet/linkedin about it @jpizarrom and we'll amplify :) |
* Add Blip2 model in VQA pipeline * use require_torch_gpu for test_large_model_pt_blip2 * use can_generate in vqa pipeline * test Blip2ForConditionalGeneration using float16 * remove custom can_generate from Blip2ForConditionalGeneration
* Add Blip2 model in VQA pipeline * use require_torch_gpu for test_large_model_pt_blip2 * use can_generate in vqa pipeline * test Blip2ForConditionalGeneration using float16 * remove custom can_generate from Blip2ForConditionalGeneration
* Add Blip2 model in VQA pipeline * use require_torch_gpu for test_large_model_pt_blip2 * use can_generate in vqa pipeline * test Blip2ForConditionalGeneration using float16 * remove custom can_generate from Blip2ForConditionalGeneration
I use the newest library of transformers,but it still reports "The model 'Blip2ForConditionalGeneration' is not supported for vqa. Supported models are ['ViltForQuestionAnswering'].",so how can I use blip2 in pipeline to deal with the vqa task?Are there any test codes of BLIP2 in pipeline? |
Hi @RainyLayx i was able to run this sample with transformers==4.37.1 from transformers import pipeline
import requests
from PIL import Image
vqa_pipeline = pipeline("visual-question-answering", model="Salesforce/blip2-opt-2.7b")
image = Image.open(requests.get("https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png", stream=True).raw)
question = "Question: Is there a parrot? Answer:"
print(vqa_pipeline(image, question, top_k=1)) |
What does this PR do?
Add Blip2ForConditionalGeneration model in VisualQuestionAnsweringPipeline.
Fixes part of #21110 and is based on #23348 #21227 .
Who can review?
Hi @NielsRogge what do you think of this??
Thanks!
TODOs