BetterTransformer integration for Detr #1022

Open
wants to merge 1 commit into base: main
32 changes: 24 additions & 8 deletions docs/source/bettertransformer/overview.mdx
@@ -23,7 +23,6 @@ You can now use this feature in 🤗 Optimum together with Transformers and use
In version 2.0, PyTorch includes a scaled dot-product attention function (SDPA) as part of `torch.nn.functional`. This function encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the [official documentation](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention) for more information.
We provide an integration with the `BetterTransformer` API to use this function in 🤗 Optimum, so that you can convert any supported 🤗 Transformers model to call the `scaled_dot_product_attention` function when relevant.

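As a quick illustration, a minimal sketch of calling the SDPA primitive directly; the tensor shapes are arbitrary and only meant to show the `(batch, num_heads, seq_len, head_dim)` layout that `scaled_dot_product_attention` expects:

```python
import torch
import torch.nn.functional as F

# Arbitrary shapes: (batch, num_heads, sequence_length, head_dim)
query = torch.rand(2, 8, 16, 64)
key = torch.rand(2, 8, 16, 64)
value = torch.rand(2, 8, 16, 64)

# PyTorch picks a fused implementation (FlashAttention, memory-efficient, or the
# math fallback) depending on the inputs and the available hardware.
out = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
print(out.shape)  # torch.Size([2, 8, 16, 64])
```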

### Supported models

The list of supported models is below:
@@ -38,6 +37,7 @@ The list of supported models is below:
- [Data2VecText](https://arxiv.org/abs/2202.03555)
- [DistilBert](https://arxiv.org/abs/1910.01108)
- [DeiT](https://arxiv.org/abs/2012.12877)
- [Detr](https://arxiv.org/abs/2005.12872)
- [Electra](https://arxiv.org/abs/2003.10555)
- [Ernie](https://arxiv.org/abs/1904.09223)
- [FSMT](https://arxiv.org/abs/1907.06616)
@@ -80,20 +80,36 @@ In order to use the `BetterTransformer` API, just run the following commands:
>>> model_hf = AutoModelForSequenceClassification.from_pretrained("bert-base-cased")
>>> model = BetterTransformer.transform(model_hf, keep_original_model=True)
```

You can set `keep_original_model=False` if you want to overwrite the current model with its `BetterTransformer` version.

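For example, a minimal sketch (reusing the `model_hf` from the snippet above) of converting in place rather than keeping a copy:

```python
>>> # Overwrites model_hf instead of returning a transformed copy of it
>>> model = BetterTransformer.transform(model_hf, keep_original_model=False)
```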
See the `tutorials` section for more details on how to use it, or check out the [Google Colab demo](https://colab.research.google.com/drive/1Lv2RCG_AT6bZNdlL1oDDNNiwBBuirwI-?usp=sharing)!


<div class="mt-10">
<div class="w-full flex flex-col space-y-4 md:space-y-0 md:grid md:grid-cols-2 md:gap-y-4 md:gap-x-5">
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="./tutorials/convert"
><div class="w-full text-center bg-gradient-to-br from-blue-400 to-blue-500 rounded-lg py-1.5 font-semibold mb-5 text-white text-lg leading-relaxed">Tutorials</div>
<p class="text-gray-700">Learn the basics and become familiar with 🤗 and `BetterTransformer` integration. Start here if you are using 🤗 Optimum for the first time!</p>
<a
class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg"
href="./tutorials/convert"
>
<div class="w-full text-center bg-gradient-to-br from-blue-400 to-blue-500 rounded-lg py-1.5 font-semibold mb-5 text-white text-lg leading-relaxed">
Tutorials
</div>
<p class="text-gray-700">
Learn the basics and become familiar with 🤗 and `BetterTransformer`
integration. Start here if you are using 🤗 Optimum for the first time!
</p>
</a>
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="./tutorials/contribute"
><div class="w-full text-center bg-gradient-to-br from-indigo-400 to-indigo-500 rounded-lg py-1.5 font-semibold mb-5 text-white text-lg leading-relaxed">How-to guides</div>
<p class="text-gray-700">You want to add your own model for `BetterTransformer` support? Start here to check the contribution guideline!</p>
<a
class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg"
href="./tutorials/contribute"
>
<div class="w-full text-center bg-gradient-to-br from-indigo-400 to-indigo-500 rounded-lg py-1.5 font-semibold mb-5 text-white text-lg leading-relaxed">
How-to guides
</div>
<p class="text-gray-700">
You want to add your own model for `BetterTransformer` support? Start
here to check the contribution guideline!
</p>
</a>
</div>
</div>
2 changes: 2 additions & 0 deletions optimum/bettertransformer/models/__init__.py
@@ -34,6 +34,7 @@
BartEncoderLayerBetterTransformer,
BertLayerBetterTransformer,
CLIPLayerBetterTransformer,
DetrEncoderLayerBetterTransformer,
DistilBertLayerBetterTransformer,
FSMTEncoderLayerBetterTransformer,
MBartEncoderLayerBetterTransformer,
@@ -60,6 +61,7 @@ class BetterTransformerManager:
"codegen": {"CodeGenAttention": CodegenAttentionLayerBetterTransformer},
"data2vec-text": {"Data2VecTextLayer": BertLayerBetterTransformer},
"deit": {"DeiTLayer": ViTLayerBetterTransformer},
"detr": {"DetrEncoderLayer": DetrEncoderLayerBetterTransformer},
"distilbert": {"TransformerBlock": DistilBertLayerBetterTransformer},
"electra": {"ElectraLayer": BertLayerBetterTransformer},
"ernie": {"ErnieLayer": BertLayerBetterTransformer},
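Once the `"detr"` entry is registered in the manager, converting a DETR model follows the same path as the other supported architectures. A minimal sketch, assuming the `facebook/detr-resnet-50` checkpoint as an example:

```python
from transformers import AutoModel
from optimum.bettertransformer import BetterTransformer

model = AutoModel.from_pretrained("facebook/detr-resnet-50")
# The manager maps the "detr" model type to DetrEncoderLayerBetterTransformer,
# so each DetrEncoderLayer in the encoder is swapped for its fast counterpart.
model = BetterTransformer.transform(model, keep_original_model=False)
```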
119 changes: 117 additions & 2 deletions optimum/bettertransformer/models/encoder_models.py
@@ -11,10 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional, Tuple

import torch
import torch.nn as nn
from transformers.models.detr.modeling_detr import DetrEncoderLayer

from .base import BetterTransformerBaseLayer

@@ -407,6 +408,120 @@ def forward(self, hidden_states, attention_mask, position_bias=None, *_, **__):
return (hidden_states,)


class DetrEncoderLayerBetterTransformer(BetterTransformerBaseLayer, nn.Module):
def __init__(self, detr_layer: DetrEncoderLayer, config: "PretrainedConfig"):
r"""
A simple conversion of the `DetrEncoderLayer` to its `BetterTransformer` implementation.

Args:
detr_layer (`torch.nn.Module`):
The original `DetrEncoderLayer` whose weights need to be retrieved.
"""
super().__init__(config)
# In_proj layer
self.in_proj_weight = nn.Parameter(
torch.cat(
[
detr_layer.self_attn.q_proj.weight,
detr_layer.self_attn.k_proj.weight,
detr_layer.self_attn.v_proj.weight,
]
)
)
self.in_proj_bias = nn.Parameter(
torch.cat(
[
detr_layer.self_attn.q_proj.bias,
detr_layer.self_attn.k_proj.bias,
detr_layer.self_attn.v_proj.bias,
]
)
)

# out proj layer
self.out_proj_weight = detr_layer.self_attn.out_proj.weight
self.out_proj_bias = detr_layer.self_attn.out_proj.bias

# Linear layer 1
self.linear1_weight = detr_layer.fc1.weight
self.linear1_bias = detr_layer.fc1.bias

# Linear layer 2
self.linear2_weight = detr_layer.fc2.weight
self.linear2_bias = detr_layer.fc2.bias

# layer norm 1
self.norm1_eps = detr_layer.self_attn_layer_norm.eps
self.norm1_weight = detr_layer.self_attn_layer_norm.weight
self.norm1_bias = detr_layer.self_attn_layer_norm.bias

# layer norm 2
self.norm2_eps = detr_layer.final_layer_norm.eps
self.norm2_weight = detr_layer.final_layer_norm.weight
self.norm2_bias = detr_layer.final_layer_norm.bias

# model hyper parameter
self.num_heads = detr_layer.self_attn.num_heads
self.embed_dim = detr_layer.self_attn.embed_dim

# last step
self.is_last_layer = False
# DetrEncoderLayer is post-norm: layer norms are applied after the residual additions
self.norm_first = False

self.original_layers_mapping = {
"in_proj_weight": ["self_attn.q_proj.weight", "self_attn.k_proj.weight", "self_attn.v_proj.weight"],
"in_proj_bias": ["self_attn.q_proj.bias", "self_attn.k_proj.bias", "self_attn.v_proj.bias"],
"out_proj_weight": "self_attn.out_proj.weight",
"out_proj_bias": "self_attn.out_proj.bias",
"linear1_weight": "fc1.weight",
"linear1_bias": "fc1.bias",
"linear2_weight": "fc2.weight",
"linear2_bias": "fc2.bias",
"norm1_eps": "self_attn_layer_norm.eps",
"norm1_weight": "self_attn_layer_norm.weight",
"norm1_bias": "self_attn_layer_norm.bias",
"norm2_eps": "final_layer_norm.eps",
"norm2_weight": "final_layer_norm.weight",
"norm2_bias": "final_layer_norm.bias",
}

self.validate_bettertransformer()

def forward(
self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, *_, **__
) -> Tuple[torch.Tensor]:
r"""
This is just a wrapper around the forward function proposed in:
https://github.com/huggingface/transformers/pull/19553
"""
super().forward_checker()

hidden_states = torch._transformer_encoder_layer_fwd(
hidden_states,
self.embed_dim,
self.num_heads,
self.in_proj_weight,
self.in_proj_bias,
self.out_proj_weight,
self.out_proj_bias,
self.use_gelu,
self.norm_first,
self.norm1_eps,
self.norm1_weight,
self.norm1_bias,
self.norm2_weight,
self.norm2_bias,
self.linear1_weight,
self.linear1_bias,
self.linear2_weight,
self.linear2_bias,
attention_mask,
)
if hidden_states.is_nested and self.is_last_layer:
hidden_states = hidden_states.to_padded_tensor(0.0)
return (hidden_states,)


class MBartEncoderLayerBetterTransformer(BetterTransformerBaseLayer, nn.Module):
def __init__(self, mbart_layer, config):
r"""
@@ -639,7 +754,7 @@ def forward(self, x, attn_mask, head_mask=None, output_attentions=None, *_):
attn_mask = torch.reshape(attn_mask, (attn_mask.shape[0], attn_mask.shape[-1]))
seqlen = attn_mask.shape[1]
lengths = torch.sum(~attn_mask, 1)
if not all([l == seqlen for l in lengths]):
if not all(l == seqlen for l in lengths):
x = torch._nested_tensor_from_mask(x, attn_mask)
attn_mask = None

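To sanity-check a conversion like the one implemented above, the converted model can be compared against the original on a dummy input, roughly as the test suite does. A rough sketch, where the checkpoint, input size, and tolerance are illustrative assumptions:

```python
import torch
from transformers import AutoModel
from optimum.bettertransformer import BetterTransformer

original = AutoModel.from_pretrained("facebook/detr-resnet-50").eval()
converted = BetterTransformer.transform(original, keep_original_model=True).eval()

pixel_values = torch.rand(1, 3, 800, 800)  # dummy image batch
with torch.no_grad():
    out_original = original(pixel_values=pixel_values).last_hidden_state
    out_converted = converted(pixel_values=pixel_values).last_hidden_state

# Tolerance is a guess; the fused kernels are not expected to be bit-identical.
assert torch.allclose(out_original, out_converted, atol=1e-4)
```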
2 changes: 1 addition & 1 deletion tests/bettertransformer/test_vision.py
@@ -27,7 +27,7 @@ class BetterTransformersVisionTest(BetterTransformersTestMixin, unittest.TestCas
r"""
Testing suite for Vision Models - runs all the tests defined in `BetterTransformersTestMixin`
"""
SUPPORTED_ARCH = ["clip", "clip_text_model", "deit", "vilt", "vit", "vit_mae", "vit_msn", "yolos"]
SUPPORTED_ARCH = ["clip", "clip_text_model", "deit", "detr", "vilt", "vit", "vit_mae", "vit_msn", "yolos"]

def prepare_inputs_for_class(self, model_id, model_type, batch_size=3, **preprocessor_kwargs):
if model_type == "vilt":