Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add blip-2 to bettertransformer #1125

Merged
merged 7 commits into from
Jun 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions optimum/bettertransformer/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class BetterTransformerManager:
"bert-generation": {"BertGenerationLayer": BertLayerBetterTransformer},
"blenderbot": {"BlenderbotAttention": BlenderbotAttentionLayerBetterTransformer},
"camembert": {"CamembertLayer": BertLayerBetterTransformer},
"blip-2": {"T5Attention": T5AttentionLayerBetterTransformer},
"clip": {"CLIPEncoderLayer": CLIPLayerBetterTransformer},
"codegen": {"CodeGenAttention": CodegenAttentionLayerBetterTransformer},
"data2vec-text": {"Data2VecTextLayer": BertLayerBetterTransformer},
Expand Down Expand Up @@ -111,6 +112,8 @@ class BetterTransformerManager:
EXCLUDE_FROM_TRANSFORM = {
# clip's text model uses causal attention, that is most likely not supported in BetterTransformer
"clip": ["text_model"],
# blip-2's Q-former and vision model should not be identified as the last layers of the model
"blip-2": ["qformer.encoder.layer", "vision_model.encoder.layers"],
}

CAN_NOT_BE_SUPPORTED = {
Expand All @@ -133,6 +136,7 @@ class BetterTransformerManager:

NOT_REQUIRES_STRICT_VALIDATION = {
"blenderbot",
"blip-2",
"codegen",
"gpt2",
"gptj",
Expand Down
2 changes: 2 additions & 0 deletions optimum/bettertransformer/models/decoder_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,8 @@ def forward(self, *args, **kwargs):

class T5AttentionLayerBetterTransformer(BetterTransformerBaseLayer, T5Attention, torch.nn.Module):
def __init__(self, layer: "nn.Module", config: "PretrainedConfig"):
if hasattr(config, "text_config"):
config = config.text_config
super().__init__(config)

with torch.device("meta"):
Expand Down
4 changes: 2 additions & 2 deletions optimum/bettertransformer/transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,9 +252,9 @@ def transform(
model_fast = deepcopy(model)
except RuntimeError:
raise ValueError(
f"The model {model.__class__.__name__} does not support `deepcopy` operation that is "
f"The model {model.__class__.__name__} does not support `deepcopy` operation that is"
" internally used to create a copy of the original model when using"
" `keep_original_model=True`. Please run the conversion with "
" `keep_original_model=True`. Please run the conversion with"
" `keep_original_model=False` and create a new copy of the original"
" model somewhere else."
)
Expand Down
14 changes: 11 additions & 3 deletions tests/bettertransformer/test_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class BetterTransformersVisionTest(BetterTransformersTestMixin, unittest.TestCas
r"""
Testing suite for Vision Models - tests all the tests defined in `BetterTransformersTestMixin`
"""
SUPPORTED_ARCH = ["clip", "clip_text_model", "deit", "vilt", "vit", "vit_mae", "vit_msn", "yolos"]
SUPPORTED_ARCH = ["blip-2", "clip", "clip_text_model", "deit", "vilt", "vit", "vit_mae", "vit_msn", "yolos"]

def prepare_inputs_for_class(self, model_id, model_type, batch_size=3, **preprocessor_kwargs):
if model_type == "vilt":
Expand All @@ -38,11 +38,13 @@ def prepare_inputs_for_class(self, model_id, model_type, batch_size=3, **preproc
# Model takes image and text as input
processor = AutoProcessor.from_pretrained(model_id)
inputs = processor(images=image, text=text, return_tensors="pt")
elif model_type in ["clip", "clip_text_model"]:
elif model_type in ["blip-2", "clip", "clip_text_model"]:
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

if batch_size == 1:
if (
batch_size == 1 or model_type == "blip-2"
): # TODO setup preprocessor_kwargs with batch_size=1 for blip-2
text = ["a photo"]
else:
text = ["a photo"] + ["a photo of two big cats"] * (batch_size - 1)
Expand All @@ -51,6 +53,10 @@ def prepare_inputs_for_class(self, model_id, model_type, batch_size=3, **preproc
# Model takes image and text as input
processor = AutoProcessor.from_pretrained(model_id)
inputs = processor(images=image, text=text, padding=padding, return_tensors="pt", **preprocessor_kwargs)

if model_type == "blip-2":
inputs["decoder_input_ids"] = inputs["input_ids"]

else:
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
Expand All @@ -74,6 +80,8 @@ def test_raise_autocast(self, model_type: str):

@parameterized.expand(SUPPORTED_ARCH)
def test_raise_train(self, model_type: str):
if model_type in ["blip-2"]:
self.skipTest("can be trained")
model_id = MODELS_DICT[model_type]
self._test_raise_train(model_id, model_type=model_type)

Expand Down
16 changes: 12 additions & 4 deletions tests/bettertransformer/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
"bert": "hf-internal-testing/tiny-random-BertModel",
"bert-generation": "ybelkada/random-tiny-BertGenerationModel",
"blenderbot": "hf-internal-testing/tiny-random-BlenderbotModel",
"blip-2": "hf-internal-testing/tiny-random-Blip2Model",
"camembert": "hf-internal-testing/tiny-random-camembert",
"clip_text_model": "hf-internal-testing/tiny-random-clip-zero-shot-image-classification", # with quick_gelu
"clip": "laion/CLIP-ViT-B-32-laion2B-s34B-b79K", # with gelu
Expand Down Expand Up @@ -351,10 +352,17 @@ def _test_invert_model_logits(
flattened_output_bt = [out for j in range(len(output_bt[i])) for out in output_bt[i][j]]
flattened_output_hf = [out for j in range(len(output_hf[i])) for out in output_hf[i][j]]
for j in range(len(flattened_output_bt)):
self.assertTrue(
torch.allclose(flattened_output_bt[j], flattened_output_hf[j], atol=1e-4),
f" Maxdiff: {(flattened_output_bt[j] - flattened_output_hf[j]).abs().max()}",
)
if isinstance(flattened_output_bt[j], torch.Tensor):
self.assertTrue(
torch.allclose(flattened_output_bt[j], flattened_output_hf[j], atol=1e-4),
f" Maxdiff: {(flattened_output_bt[j] - flattened_output_hf[j]).abs().max()}",
)
elif isinstance(flattened_output_bt[j], tuple):
for k in range(len(flattened_output_bt[j])):
self.assertTrue(
torch.allclose(flattened_output_bt[j][k], flattened_output_hf[j][k], atol=1e-4),
f" Maxdiff: {(flattened_output_bt[j][k] - flattened_output_hf[j][k]).abs().max()}",
)


def get_batch(batch_size, avg_seqlen, max_sequence_length, seqlen_stdev, vocab_size, pad_idx=0):
Expand Down