Added support for Tapas Model #520
Changes from 17 commits
@@ -17,6 +17,114 @@
from .base import BetterTransformerBaseLayer

class TapasLayerBetterTransformer(BetterTransformerBaseLayer):

fxmarty marked this conversation as resolved.

    def __init__(self, tapas_layer, config):

Review comments on this line:

- I meant to remove the class
- I think I removed
- On your branch, with optimum dev installed:

  ```python
  from transformers import AutoModel
  from optimum.bettertransformer import BetterTransformer

  model_id = "hf-internal-testing/tiny-random-TapasModel"
  model = AutoModel.from_pretrained(model_id)
  bt_model = BetterTransformer.transform(model)
  ```
        r"""
        A simple conversion of the TAPAS layer to its `BetterTransformer` implementation.

        Args:
            tapas_layer (`torch.nn.Module`):
                The original TAPAS layer where the weights need to be retrieved.
        """
        super().__init__(config)
        # In_proj layer
        self.in_proj_weight = nn.Parameter(
            torch.cat(
                [
                    tapas_layer.attention.query.weight,

Review comments on this line:

- Looking at the Tapas implementation, https://github.com/huggingface/transformers/blob/28247e78819ab9756b81f8df39611c333d099400/src/transformers/models/tapas/modeling_tapas.py#L442, I think we need here and in the rest
- Oh... thank you for providing me with your insights!

                    tapas_layer.attention.key.weight,
                    tapas_layer.attention.value.weight,
                ]
            )
        )
        self.in_proj_bias = nn.Parameter(
            torch.cat(
                [
                    tapas_layer.attention.query.bias,
                    tapas_layer.attention.key.bias,
                    tapas_layer.attention.value.bias,
                ]
            )
        )

        # Out proj layer
        self.out_proj_weight = tapas_layer.attention.dense.weight
        self.out_proj_bias = tapas_layer.attention.dense.bias

        # Linear layer 1
        self.linear1_weight = tapas_layer.ffn.weight
        self.linear1_bias = tapas_layer.ffn.bias

        # Linear layer 2
        self.linear2_weight = tapas_layer.ffn_output.weight
        self.linear2_bias = tapas_layer.ffn_output.bias

        # Layer norm 1
        self.norm1_eps = tapas_layer.attention.LayerNorm.eps
        self.norm1_weight = tapas_layer.attention.LayerNorm.weight
        self.norm1_bias = tapas_layer.attention.LayerNorm.bias

        # Layer norm 2
        self.norm2_eps = tapas_layer.full_layer_layer_norm.eps
        self.norm2_weight = tapas_layer.full_layer_layer_norm.weight
        self.norm2_bias = tapas_layer.full_layer_layer_norm.bias

        # Model hyper parameters
        self.num_heads = tapas_layer.attention.num_attention_heads
        self.embed_dim = tapas_layer.attention.all_head_size

        # Last step: set the last layer to `False` -> this will be set to `True` when converting the model
        self.is_last_layer = False

        self.validate_bettertransformer()
    def forward(self, hidden_states, attention_mask, *_, **__):
        r"""
        This is just a wrapper around the forward function proposed in:
        https://github.com/huggingface/transformers/pull/19553
        """
        super().forward_checker()

        if hidden_states.is_nested:
            attention_mask = None

        if attention_mask is not None:
            # attention mask comes in with values 0 and -inf. we convert to torch.nn.TransformerEncoder style bool mask
            # 0 -> False -> keep this token; -inf -> True -> mask this token
            attention_mask = attention_mask.bool()
            attention_mask = torch.reshape(attention_mask, (attention_mask.shape[0], attention_mask.shape[-1]))
            seqlen = attention_mask.shape[1]
            lengths = torch.sum(~attention_mask, 1)
            if not all([l == seqlen for l in lengths]):
                hidden_states = torch._nested_tensor_from_mask(hidden_states, ~attention_mask)
                attention_mask = None

        hidden_states = torch._transformer_encoder_layer_fwd(
            hidden_states,
            self.embed_dim,
            self.num_heads,
            self.in_proj_weight,
            self.in_proj_bias,
            self.out_proj_weight,
            self.out_proj_bias,
            self.use_gelu,
            self.norm_first,
            self.norm1_eps,
            self.norm1_weight,
            self.norm1_bias,
            self.norm2_weight,
            self.norm2_bias,
            self.linear1_weight,
            self.linear1_bias,
            self.linear2_weight,
            self.linear2_bias,
            attention_mask,
        )
        if hidden_states.is_nested and self.is_last_layer:
            hidden_states = hidden_states.to_padded_tensor(0.0)
        return (hidden_states,)


class AlbertLayerBetterTransformer(BetterTransformerBaseLayer):
    def __init__(self, albert_layer, config):
        r"""

Review comment:

- To be removed
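
For reference, a minimal sketch of the attribute paths that the earlier review comment pointing at `modeling_tapas.py` seems to suggest. It assumes Tapas follows the BERT-style module layout in transformers (self-attention nested under `attention.self`, output projection under `attention.output`); the helper name and structure below are illustrative, not part of the PR:

```python
# Editorial sketch (assumption): if TapasAttention nests its projections under
# `attention.self` and `attention.output` (BERT-style layout), the weights would
# be fetched like this instead of `tapas_layer.attention.query` etc.
import torch
import torch.nn as nn


def gather_tapas_attention_params(tapas_layer):
    """Illustrative helper, not from the PR: collect attention weights from a BERT-style Tapas layer."""
    in_proj_weight = nn.Parameter(
        torch.cat(
            [
                tapas_layer.attention.self.query.weight,
                tapas_layer.attention.self.key.weight,
                tapas_layer.attention.self.value.weight,
            ]
        )
    )
    in_proj_bias = nn.Parameter(
        torch.cat(
            [
                tapas_layer.attention.self.query.bias,
                tapas_layer.attention.self.key.bias,
                tapas_layer.attention.self.value.bias,
            ]
        )
    )
    # In this layout, the output projection and the first layer norm live under `attention.output`.
    out_proj_weight = tapas_layer.attention.output.dense.weight
    out_proj_bias = tapas_layer.attention.output.dense.bias
    norm1 = tapas_layer.attention.output.LayerNorm
    return in_proj_weight, in_proj_bias, out_proj_weight, out_proj_bias, norm1
```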