Skip to content

Commit

Permalink
test Blip2ForConditionalGeneration using float16
Browse files Browse the repository at this point in the history
  • Loading branch information
jpizarrom committed Aug 28, 2023
1 parent 7532d7d commit ff5db0c
Showing 1 changed file with 33 additions and 1 deletion.
34 changes: 33 additions & 1 deletion tests/pipelines/test_pipelines_visual_question_answering.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from transformers.pipelines import pipeline
from transformers.testing_utils import (
is_pipeline_test,
is_torch_available,
nested_simplify,
require_tf,
require_torch,
Expand All @@ -29,6 +30,10 @@
from .test_pipelines_common import ANY


if is_torch_available():
import torch


if is_vision_available():
from PIL import Image
else:
Expand Down Expand Up @@ -86,6 +91,7 @@ def test_small_model_pt(self):
)

@require_torch
@require_torch_gpu
def test_small_model_pt_blip2(self):
vqa_pipeline = pipeline(
"visual-question-answering", model="hf-internal-testing/tiny-random-Blip2ForConditionalGeneration"
Expand All @@ -112,6 +118,23 @@ def test_small_model_pt_blip2(self):
[[{"answer": ANY(str)}]] * 2,
)

vqa_pipeline = pipeline(
"visual-question-answering",
model="hf-internal-testing/tiny-random-Blip2ForConditionalGeneration",
model_kwargs={"torch_dtype": torch.float16},
device=0,
)
self.assertEqual(vqa_pipeline.model.device, torch.device(0))
self.assertEqual(vqa_pipeline.model.language_model.dtype, torch.float16)
self.assertEqual(vqa_pipeline.model.vision_model.dtype, torch.float16)

outputs = vqa_pipeline(image=image, question=question)

self.assertEqual(
outputs,
[{"answer": ANY(str)}],
)

@slow
@require_torch
def test_large_model_pt(self):
Expand All @@ -138,9 +161,18 @@ def test_large_model_pt(self):
)

@slow
@require_torch
@require_torch_gpu
def test_large_model_pt_blip2(self):
vqa_pipeline = pipeline("visual-question-answering", model="Salesforce/blip2-opt-2.7b")
vqa_pipeline = pipeline(
"visual-question-answering",
model="Salesforce/blip2-opt-2.7b",
model_kwargs={"torch_dtype": torch.float16},
device=0,
)
self.assertEqual(vqa_pipeline.model.device, torch.device(0))
self.assertEqual(vqa_pipeline.model.language_model.dtype, torch.float16)

image = "./tests/fixtures/tests_samples/COCO/000000039769.png"
question = "Question: how many cats are there? Answer:"

Expand Down

0 comments on commit ff5db0c

Please sign in to comment.