Add MBart support for BetterTransformer #516

Merged: 10 commits, Nov 30, 2022
2 changes: 2 additions & 0 deletions docs/source/bettertransformer/overview.mdx
@@ -37,6 +37,8 @@ The list of supported models below:
- [HuBERT](https://arxiv.org/pdf/2106.07447.pdf)
- [LayoutLM](https://arxiv.org/abs/1912.13318)
- [MarkupLM](https://arxiv.org/abs/2110.08518)
- [MBart](https://arxiv.org/abs/2001.08210)
- [M2M100](https://arxiv.org/abs/2010.11125)
- [RoBERTa](https://arxiv.org/abs/1907.11692)
- [Splinter](https://arxiv.org/abs/2101.00438)
- [ViLT](https://arxiv.org/abs/2102.03334)
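For reference, a minimal usage sketch of the newly supported architecture, assuming the standard `BetterTransformer.transform` entry point from optimum; the checkpoint name is only an example:

```python
# Minimal sketch: convert an MBart checkpoint to its BetterTransformer version.
# "facebook/mbart-large-50" is illustrative; any MBart checkpoint should work.
from transformers import AutoModel
from optimum.bettertransformer import BetterTransformer

model = AutoModel.from_pretrained("facebook/mbart-large-50")
model = BetterTransformer.transform(model, keep_original_model=False)

# The converted model is intended for inference; run it in eval mode.
model.eval()
```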
3 changes: 3 additions & 0 deletions optimum/bettertransformer/models/__init__.py
@@ -17,6 +17,7 @@
AlbertLayerBetterTransformer,
BartEncoderLayerBetterTransformer,
BertLayerBetterTransformer,
DistilBertLayerBetterTransformer,
FSMTEncoderLayerBetterTransformer,
MBartEncoderLayerBetterTransformer,
ViltLayerBetterTransformer,
@@ -43,6 +44,8 @@
"AlbertLayer": AlbertLayerBetterTransformer,
# Bart family
"BartEncoderLayer": BartEncoderLayerBetterTransformer,
"MBartEncoderLayer": MBartEncoderLayerBetterTransformer,
"M2M100EncoderLayer":MBartEncoderLayerBetterTransformer,
# "PLBartEncoderLayer": bart.BartEncoderLayerBetterTransformer,
# "MarianEncoderLayer": bart.BartEncoderLayerBetterTransformer,
# "TimeSeriesTransformerEncoderLayer": bart.BartEncoderLayerBetterTransformer,
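For context, a simplified sketch (not the actual optimum conversion code) of how a class-name-to-layer mapping like the one above can drive the conversion: walk the module tree and swap every layer whose class name appears as a key. The helper name below is hypothetical.

```python
import torch.nn as nn

def replace_encoder_layers(model: nn.Module, mapping: dict, config) -> nn.Module:
    # Recursively replace any submodule whose class name is a key of `mapping`,
    # e.g. "MBartEncoderLayer" -> MBartEncoderLayerBetterTransformer(module, config).
    for name, module in model.named_children():
        bt_layer_cls = mapping.get(module.__class__.__name__)
        if bt_layer_cls is not None:
            setattr(model, name, bt_layer_cls(module, config))
        else:
            replace_encoder_layers(module, mapping, config)
    return model
```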
109 changes: 108 additions & 1 deletion optimum/bettertransformer/models/encoder_models.py
@@ -343,6 +343,114 @@ def forward(self, hidden_states, attention_mask, position_bias=None, *_, **__):
hidden_states = hidden_states.to_padded_tensor(0.0)
return (hidden_states,)

class MBartEncoderLayerBetterTransformer(BetterTransformerBaseLayer):
def __init__(self, mbart_layer, config):
r"""
A simple conversion of the `MBartEncoderLayer` to its `BetterTransformer` implementation.
Args:
mbart_layer (`torch.nn.Module`):
The original `MBartEncoderLayer` where the weights need to be retrieved.
"""
super().__init__(config)
# In_proj layer
self.in_proj_weight = nn.Parameter(
torch.cat(
[
mbart_layer.self_attn.q_proj.weight,
mbart_layer.self_attn.k_proj.weight,
mbart_layer.self_attn.v_proj.weight,
]
)
)

self.in_proj_bias = nn.Parameter(
torch.cat(
[
mbart_layer.self_attn.q_proj.bias,
mbart_layer.self_attn.k_proj.bias,
mbart_layer.self_attn.v_proj.bias,
]
)
)

# Out proj layer
self.out_proj_weight = mbart_layer.self_attn.out_proj.weight
self.out_proj_bias = mbart_layer.self_attn.out_proj.bias

# Linear layer 1
self.linear1_weight = mbart_layer.fc1.weight
self.linear1_bias = mbart_layer.fc1.bias

# Linear layer 2
self.linear2_weight = mbart_layer.fc2.weight
self.linear2_bias = mbart_layer.fc2.bias

# Layer norm 1
self.norm1_eps = mbart_layer.self_attn_layer_norm.eps
self.norm1_weight = mbart_layer.self_attn_layer_norm.weight
self.norm1_bias = mbart_layer.self_attn_layer_norm.bias

# Layer norm 2
self.norm2_eps = mbart_layer.final_layer_norm.eps
self.norm2_weight = mbart_layer.final_layer_norm.weight
self.norm2_bias = mbart_layer.final_layer_norm.bias

# Model hyper parameters
self.num_heads = mbart_layer.self_attn.num_heads
self.embed_dim = mbart_layer.self_attn.embed_dim

# Last step: set the last layer to `False` -> this will be set to `True` when converting the model
self.is_last_layer = False
self.norm_first = True
self.validate_bettertransformer()

def forward(self, hidden_states, attention_mask, position_bias=None, *_, **__):
r"""
This is just a wrapper around the forward function proposed in:
https://github.com/huggingface/transformers/pull/19553
"""
super().forward_checker()

if hidden_states.is_nested:
attention_mask = None

if attention_mask is not None:
# The attention mask arrives with values 0 and -inf; convert it to the bool mask
# expected by torch.nn.TransformerEncoder: 0 -> False (keep token), -inf -> True (mask token).
if len(attention_mask.shape) == 4:
attention_mask = attention_mask.squeeze(1)[:, 0]
attention_mask = attention_mask.bool()
attention_mask = torch.reshape(attention_mask, (attention_mask.shape[0], attention_mask.shape[-1]))
seqlen = attention_mask.shape[1]
lengths = torch.sum(~attention_mask, 1)
if not all(l == seqlen for l in lengths):
hidden_states = torch._nested_tensor_from_mask(hidden_states, ~attention_mask)
attention_mask = None

hidden_states = torch._transformer_encoder_layer_fwd(
hidden_states,
self.embed_dim,
self.num_heads,
self.in_proj_weight,
self.in_proj_bias,
self.out_proj_weight,
self.out_proj_bias,
self.use_gelu,
self.norm_first,
self.norm1_eps,
self.norm1_weight,
self.norm1_bias,
self.norm2_weight,
self.norm2_bias,
self.linear1_weight,
self.linear1_bias,
self.linear2_weight,
self.linear2_bias,
attention_mask,
)
if hidden_states.is_nested and self.is_last_layer:
hidden_states = hidden_states.to_padded_tensor(0.0)
return (hidden_states,)

class DistilBertLayerBetterTransformer(BetterTransformerBaseLayer):
def __init__(self, bert_layer, config):
@@ -402,7 +510,6 @@ def __init__(self, bert_layer, config):

# Last step: set the last layer to `False` -> this will be set to `True` when converting the model
self.is_last_layer = False

self.validate_bettertransformer()

def forward(self, x, attn_mask, head_mask=None, output_attentions=None, *_):
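The `__init__` above fuses the separate q/k/v projections of `MBartEncoderLayer.self_attn` into a single `in_proj`, as the fused `torch._transformer_encoder_layer_fwd` call expects. A small standalone check (not part of the PR) showing why the concatenation is equivalent to applying the three projections separately:

```python
import torch

embed_dim = 16
q = torch.nn.Linear(embed_dim, embed_dim)
k = torch.nn.Linear(embed_dim, embed_dim)
v = torch.nn.Linear(embed_dim, embed_dim)

# Fuse weights and biases the same way MBartEncoderLayerBetterTransformer does.
in_proj_weight = torch.cat([q.weight, k.weight, v.weight])
in_proj_bias = torch.cat([q.bias, k.bias, v.bias])

x = torch.randn(2, 5, embed_dim)
fused = torch.nn.functional.linear(x, in_proj_weight, in_proj_bias)
q_out, k_out, v_out = fused.chunk(3, dim=-1)

assert torch.allclose(q_out, q(x), atol=1e-6)
assert torch.allclose(k_out, k(x), atol=1e-6)
assert torch.allclose(v_out, v(x), atol=1e-6)
```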
2 changes: 2 additions & 0 deletions tests/bettertransformer/test_bettertransformer_encoder.py
@@ -44,6 +44,8 @@
"hf-internal-testing/tiny-random-MarkupLMModel",
"hf-internal-testing/tiny-random-BertModel",
"ybelkada/random-tiny-BertGenerationModel",
"hf-internal-testing/tiny-random-MBartModel",
"hf-internal-testing/tiny-random-nllb",
Contributor (review comment):

I think that you have to move them to ALL_ENCODER_DECODER_MODELS: the test `pytest tests/bettertransformer/test_bettertransformer_encoder.py::BetterTransformersEncoderDecoderTest` will only run on the models listed in ALL_ENCODER_DECODER_MODELS ;)
Contributor Author (reply):

Hi @younesbelkada, thank you so much for the explanation! Now I have a better understanding of what is going on in the test files. I have moved the two test models to the right list, ran pytest again, and it passed.

[Screenshot 2022-11-29, 7:21 PM]

Please let me know what else I can do!

]

ALL_ENCODER_DECODER_MODELS_TO_TEST = [
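Finally, a hedged sketch of the kind of numerical parity check these tests run on the tiny checkpoints listed above; the dummy token ids and the tolerance are illustrative, and the actual test logic in the repo may differ.

```python
import torch
from transformers import AutoModel
from optimum.bettertransformer import BetterTransformer

name = "hf-internal-testing/tiny-random-MBartModel"
model = AutoModel.from_pretrained(name).eval()
bt_model = BetterTransformer.transform(model, keep_original_model=True)

# A padded batch so the nested-tensor path of the converted encoder is exercised.
input_ids = torch.tensor([[2, 10, 11, 12, 4], [2, 20, 21, 4, 1]])
attention_mask = torch.tensor([[1, 1, 1, 1, 1], [1, 1, 1, 1, 0]])

with torch.no_grad():
    ref = model(input_ids, attention_mask=attention_mask).encoder_last_hidden_state
    out = bt_model(input_ids, attention_mask=attention_mask).encoder_last_hidden_state

# Compare only non-padded positions; padded positions may be zero-filled after conversion.
mask = attention_mask.unsqueeze(-1).bool()
assert torch.allclose(ref.masked_select(mask), out.masked_select(mask), atol=1e-4)
```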