Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flava model better transformers #907

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/bettertransformer/overview.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ The list of supported model below:
- [DeiT](https://arxiv.org/abs/2012.12877)
- [Electra](https://arxiv.org/abs/2003.10555)
- [Ernie](https://arxiv.org/abs/1904.09223)
- [Flava](https://arxiv.org/abs/2112.04482)
- [FSMT](https://arxiv.org/abs/1907.06616)
- [GPT2](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf)
- [GPT-j](https://huggingface.co/EleutherAI/gpt-j-6B)
Expand Down
2 changes: 2 additions & 0 deletions optimum/bettertransformer/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
BertLayerBetterTransformer,
CLIPLayerBetterTransformer,
DistilBertLayerBetterTransformer,
FlavaLayerBetterTransformer,
FSMTEncoderLayerBetterTransformer,
MBartEncoderLayerBetterTransformer,
ProphetNetEncoderLayerBetterTransformer,
Expand Down Expand Up @@ -61,6 +62,7 @@ class BetterTransformerManager:
"distilbert": {"TransformerBlock": DistilBertLayerBetterTransformer},
"electra": {"ElectraLayer": BertLayerBetterTransformer},
"ernie": {"ErnieLayer": BertLayerBetterTransformer},
"flava": {"FlavaLayer": FlavaLayerBetterTransformer},
"fsmt": {"EncoderLayer": FSMTEncoderLayerBetterTransformer},
"gpt2": {"GPT2Attention": GPT2AttentionLayerBetterTransformer},
"gptj": {"GPTJAttention": GPTJAttentionLayerBetterTransformer},
Expand Down
119 changes: 119 additions & 0 deletions optimum/bettertransformer/models/encoder_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1149,6 +1149,125 @@ def forward(self, hidden_states, attention_mask, **__):
return (hidden_states,)


class FlavaLayerBetterTransformer(BetterTransformerBaseLayer, nn.Module):
def __init__(self, flava_layer, config):
r"""
A simple conversion of the FlavaLayer to its `BetterTransformer` implementation.
Args:
flava_layer (`torch.nn.Module`):
The original `FlavaLayer` where the weights needs to be retrieved.
"""
super().__init__(config.image_config)
# In_proj layer
self.in_proj_weight = nn.Parameter(
torch.cat(
[
flava_layer.attention.attention.query.weight,
flava_layer.attention.attention.key.weight,
flava_layer.attention.attention.value.weight,
]
)
)
self.in_proj_bias = nn.Parameter(
torch.cat(
[
flava_layer.attention.attention.query.bias,
flava_layer.attention.attention.key.bias,
flava_layer.attention.attention.value.bias,
]
)
)

# Out proj layer
self.out_proj_weight = flava_layer.attention.output.dense.weight
self.out_proj_bias = flava_layer.attention.output.dense.bias

# Linear layer 1
self.linear1_weight = flava_layer.intermediate.dense.weight
self.linear1_bias = flava_layer.intermediate.dense.bias

# Linear layer 2
self.linear2_weight = flava_layer.output.dense.weight
self.linear2_bias = flava_layer.output.dense.bias

# Layer norm 1
self.norm1_eps = flava_layer.layernorm_before.eps
self.norm1_weight = flava_layer.layernorm_before.weight
self.norm1_bias = flava_layer.layernorm_before.bias

# Layer norm 2
self.norm2_eps = flava_layer.layernorm_after.eps
self.norm2_weight = flava_layer.layernorm_after.weight
self.norm2_bias = flava_layer.layernorm_after.bias

# Model hyper parameters
self.num_heads = flava_layer.attention.attention.num_attention_heads
self.embed_dim = int(flava_layer.attention.attention.attention_head_size * self.num_heads)

# Last step: set the last layer to `False` -> this will be set to `True` when converting the model
self.is_last_layer = False
self.norm_first = True

self.original_layers_mapping = {
"in_proj_weight": [
"attention.attention.query.weight",
"attention.attention.key.weight",
"attention.attention.value.weight",
],
"in_proj_bias": [
"attention.attention.query.bias",
"attention.attention.key.bias",
"attention.attention.value.bias",
],
"out_proj_weight": "attention.output.dense.weight",
"out_proj_bias": "attention.output.dense.bias",
"linear1_weight": "intermediate.dense.weight",
"linear1_bias": "intermediate.dense.bias",
"linear2_weight": "output.dense.weight",
"linear2_bias": "output.dense.bias",
"norm1_eps": "layernorm_before.eps",
"norm1_weight": "layernorm_before.weight",
"norm1_bias": "layernorm_before.bias",
"norm2_eps": "layernorm_after.eps",
"norm2_weight": "layernorm_after.weight",
"norm2_bias": "layernorm_after.bias",
}

self.validate_bettertransformer()

def forward(self, hidden_states, *_, **__):
r"""
This is just a wrapper around the forward function proposed in:
https://github.com/huggingface/transformers/pull/19553
"""
super().forward_checker()
attention_mask = None

hidden_states = torch._transformer_encoder_layer_fwd(
hidden_states,
self.embed_dim,
self.num_heads,
self.in_proj_weight,
self.in_proj_bias,
self.out_proj_weight,
self.out_proj_bias,
self.use_gelu,
self.norm_first,
self.norm1_eps,
self.norm1_weight,
self.norm1_bias,
self.norm2_weight,
self.norm2_bias,
self.linear1_weight,
self.linear1_bias,
self.linear2_weight,
self.linear2_bias,
attention_mask,
)
if hidden_states.is_nested and self.is_last_layer:
hidden_states = hidden_states.to_padded_tensor(0.0)
return (hidden_states,)

class FSMTEncoderLayerBetterTransformer(BetterTransformerBaseLayer, nn.Module):
def __init__(self, fsmt_layer, config):
r"""
Expand Down
2 changes: 1 addition & 1 deletion tests/bettertransformer/test_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class BetterTransformersVisionTest(BetterTransformersTestMixin, unittest.TestCas
r"""
Testing suite for Vision Models - tests all the tests defined in `BetterTransformersTestMixin`
"""
SUPPORTED_ARCH = ["clip", "clip_text_model", "deit", "vilt", "vit", "vit_mae", "vit_msn", "yolos"]
SUPPORTED_ARCH = ["clip", "clip_text_model", "deit", "vilt", "vit", "vit_mae", "vit_msn", "yolos", "flava"]

def prepare_inputs_for_class(self, model_id, model_type, batch_size=3, **preprocessor_kwargs):
if model_type == "vilt":
Expand Down
1 change: 1 addition & 0 deletions tests/bettertransformer/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
"distilbert": "hf-internal-testing/tiny-random-DistilBertModel",
"electra": "hf-internal-testing/tiny-random-ElectraModel",
"ernie": "hf-internal-testing/tiny-random-ErnieModel",
"flava": "ybelkada/tiny-random-flava",
"fsmt": "hf-internal-testing/tiny-random-FSMTModel",
"gpt2": "hf-internal-testing/tiny-random-GPT2Model",
"gpt_neo": "hf-internal-testing/tiny-random-GPTNeoModel",
Expand Down