diff --git a/tests/pipelines/test_pipelines_visual_question_answering.py b/tests/pipelines/test_pipelines_visual_question_answering.py
index 67872445abda51..9f17657edeb57f 100644
--- a/tests/pipelines/test_pipelines_visual_question_answering.py
+++ b/tests/pipelines/test_pipelines_visual_question_answering.py
@@ -18,6 +18,7 @@
 from transformers.pipelines import pipeline
 from transformers.testing_utils import (
     is_pipeline_test,
+    is_torch_available,
     nested_simplify,
     require_tf,
     require_torch,
@@ -29,6 +30,10 @@
 from .test_pipelines_common import ANY
 
 
+if is_torch_available():
+    import torch
+
+
 if is_vision_available():
     from PIL import Image
 else:
@@ -86,6 +91,7 @@ def test_small_model_pt(self):
         )
 
     @require_torch
+    @require_torch_gpu
     def test_small_model_pt_blip2(self):
         vqa_pipeline = pipeline(
             "visual-question-answering", model="hf-internal-testing/tiny-random-Blip2ForConditionalGeneration"
@@ -112,6 +118,23 @@ def test_small_model_pt_blip2(self):
             [[{"answer": ANY(str)}]] * 2,
         )
 
+        vqa_pipeline = pipeline(
+            "visual-question-answering",
+            model="hf-internal-testing/tiny-random-Blip2ForConditionalGeneration",
+            model_kwargs={"torch_dtype": torch.float16},
+            device=0,
+        )
+        self.assertEqual(vqa_pipeline.model.device, torch.device(0))
+        self.assertEqual(vqa_pipeline.model.language_model.dtype, torch.float16)
+        self.assertEqual(vqa_pipeline.model.vision_model.dtype, torch.float16)
+
+        outputs = vqa_pipeline(image=image, question=question)
+
+        self.assertEqual(
+            outputs,
+            [{"answer": ANY(str)}],
+        )
+
     @slow
     @require_torch
     def test_large_model_pt(self):
@@ -138,9 +161,18 @@ def test_large_model_pt(self):
         )
 
     @slow
+    @require_torch
     @require_torch_gpu
     def test_large_model_pt_blip2(self):
-        vqa_pipeline = pipeline("visual-question-answering", model="Salesforce/blip2-opt-2.7b")
+        vqa_pipeline = pipeline(
+            "visual-question-answering",
+            model="Salesforce/blip2-opt-2.7b",
+            model_kwargs={"torch_dtype": torch.float16},
+            device=0,
+        )
+        self.assertEqual(vqa_pipeline.model.device, torch.device(0))
+        self.assertEqual(vqa_pipeline.model.language_model.dtype, torch.float16)
+
         image = "./tests/fixtures/tests_samples/COCO/000000039769.png"
         question = "Question: how many cats are there? Answer:"