From 255251f553ca80b5f6118b26f6a588e9057b2f7f Mon Sep 17 00:00:00 2001
From: ravenouse
Date: Sat, 26 Nov 2022 00:51:11 -0700
Subject: [PATCH 1/9] added MBartEncoderLayerBetterTransformer

---
 optimum/bettertransformer/models/__init__.py |   1 +
 .../models/encoder_models.py                 | 108 ++++++++++++++++++
 2 files changed, 109 insertions(+)

diff --git a/optimum/bettertransformer/models/__init__.py b/optimum/bettertransformer/models/__init__.py
index fd9c7a1330c..a5b780cbe6c 100644
--- a/optimum/bettertransformer/models/__init__.py
+++ b/optimum/bettertransformer/models/__init__.py
@@ -43,6 +43,7 @@
     "AlbertLayer": AlbertLayerBetterTransformer,
     # Bart family
     "BartEncoderLayer": BartEncoderLayerBetterTransformer,
+    "MBartEncoderLayer": MBartEncoderLayerBetterTransformer,
     # "PLBartEncoderLayer": bart.BartEncoderLayerBetterTransformer,
     # "MarianEncoderLayer": bart.BartEncoderLayerBetterTransformer,
     # "TimeSeriesTransformerEncoderLayer": bart.BartEncoderLayerBetterTransformer,
diff --git a/optimum/bettertransformer/models/encoder_models.py b/optimum/bettertransformer/models/encoder_models.py
index 1acea21f4e3..2f09671aa06 100644
--- a/optimum/bettertransformer/models/encoder_models.py
+++ b/optimum/bettertransformer/models/encoder_models.py
@@ -343,6 +343,114 @@ def forward(self, hidden_states, attention_mask, position_bias=None, *_, **__):
             hidden_states = hidden_states.to_padded_tensor(0.0)
         return (hidden_states,)
 
+class MBartEncoderLayerBetterTransformer(BetterTransformerBaseLayer):
+    def __init__(self, mbart_layer, config):
+        r"""
+        A simple conversion of the `MBartEncoderLayer` to its `BetterTransformer` implementation.
+        Args:
+            mbart_layer (`torch.nn.Module`):
+                The original `MBartEncoderLayer` where the weights needs to be retrieved.
+        """
+        super().__init__(config)
+        # In_proj layer
+        self.in_proj_weight = nn.Parameter(
+            torch.cat(
+                [
+                    mbart_layer.self_attn.q_proj.weight,
+                    mbart_layer.self_attn.k_proj.weight,
+                    mbart_layer.self_attn.v_proj.weight,
+                ]
+            )
+        )
+
+        self.in_proj_bias = nn.Parameter(
+            torch.cat(
+                [
+                    mbart_layer.self_attn.q_proj.bias,
+                    mbart_layer.self_attn.k_proj.bias,
+                    mbart_layer.self_attn.v_proj.bias,
+                ]
+            )
+        )
+
+        # Out proj layer
+        self.out_proj_weight = mbart_layer.self_attn.out_proj.weight
+        self.out_proj_bias = mbart_layer.self_attn.out_proj.bias
+
+        # Linear layer 1
+        self.linear1_weight = mbart_layer.fc1.weight
+        self.linear1_bias = mbart_layer.fc1.bias
+
+        # Linear layer 2
+        self.linear2_weight = mbart_layer.fc2.weight
+        self.linear2_bias = bart_layer.fc2.bias
+
+        # Layer norm 1
+        self.norm1_eps = mbart_layer.self_attn_layer_norm.eps
+        self.norm1_weight = mbart_layer.self_attn_layer_norm.weight
+        self.norm1_bias = mbart_layer.self_attn_layer_norm.bias
+
+        # Layer norm 2
+        self.norm2_eps = mbart_layer.final_layer_norm.eps
+        self.norm2_weight = mbart_layer.final_layer_norm.weight
+        self.norm2_bias = mbart_layer.final_layer_norm.bias
+
+        # Model hyper parameters
+        self.num_heads = mbart_layer.self_attn.num_heads
+        self.embed_dim = mbart_layer.self_attn.embed_dim
+
+        # Last step: set the last layer to `False` -> this will be set to `True` when converting the model
+        self.is_last_layer = False
+
+        self.validate_bettertransformer()
+
+    def forward(self, hidden_states, attention_mask, position_bias=None, *_, **__):
+        r"""
+        This is just a wrapper around the forward function proposed in:
+        https://github.com/huggingface/transformers/pull/19553
+        """
+        super().forward_checker()
+
+        if hidden_states.is_nested:
+            attention_mask = None
+
+        if attention_mask is not None:
+            # attention mask comes in with values 0 and -inf. we convert to torch.nn.TransformerEncoder style bool mask
+            # 0->false->keep this token -inf->true->mask this token
+            if len(attention_mask.shape) == 4:
+                attention_mask = attention_mask.squeeze(1)[:, 0]
+            attention_mask = attention_mask.bool()
+            attention_mask = torch.reshape(attention_mask, (attention_mask.shape[0], attention_mask.shape[-1]))
+            seqlen = attention_mask.shape[1]
+            lengths = torch.sum(~attention_mask, 1)
+            if not all([l == seqlen for l in lengths]):
+                hidden_states = torch._nested_tensor_from_mask(hidden_states, ~attention_mask)
+            attention_mask = None
+
+        hidden_states = torch._transformer_encoder_layer_fwd(
+            hidden_states,
+            self.embed_dim,
+            self.num_heads,
+            self.in_proj_weight,
+            self.in_proj_bias,
+            self.out_proj_weight,
+            self.out_proj_bias,
+            self.use_gelu,
+            self.norm_first,
+            self.norm1_eps,
+            self.norm1_weight,
+            self.norm1_bias,
+            self.norm2_weight,
+            self.norm2_bias,
+            self.linear1_weight,
+            self.linear1_bias,
+            self.linear2_weight,
+            self.linear2_bias,
+            attention_mask,
+        )
+        if hidden_states.is_nested and self.is_last_layer:
+            hidden_states = hidden_states.to_padded_tensor(0.0)
+        return (hidden_states,)
 
 class DistilBertLayerBetterTransformer(BetterTransformerBaseLayer):
     def __init__(self, bert_layer, config):

From d68a8f4b4fa20d409fef13c849be3c1c56959b4c Mon Sep 17 00:00:00 2001
From: ravenouse
Date: Sat, 26 Nov 2022 01:35:32 -0700
Subject: [PATCH 2/9] fixed some bugs

---
 optimum/bettertransformer/models/__init__.py       | 1 +
 optimum/bettertransformer/models/encoder_models.py | 2 +-
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/optimum/bettertransformer/models/__init__.py b/optimum/bettertransformer/models/__init__.py
index a5b780cbe6c..871a32f0e6f 100644
--- a/optimum/bettertransformer/models/__init__.py
+++ b/optimum/bettertransformer/models/__init__.py
@@ -17,6 +17,7 @@
     AlbertLayerBetterTransformer,
     BartEncoderLayerBetterTransformer,
     BertLayerBetterTransformer,
+    MBartEncoderLayerBetterTransformer,
     DistilBertLayerBetterTransformer,
     FSMTEncoderLayerBetterTransformer,
     ViltLayerBetterTransformer,
diff --git a/optimum/bettertransformer/models/encoder_models.py b/optimum/bettertransformer/models/encoder_models.py
index 2f09671aa06..11b77c7d59a 100644
--- a/optimum/bettertransformer/models/encoder_models.py
+++ b/optimum/bettertransformer/models/encoder_models.py
@@ -383,7 +383,7 @@ def __init__(self, mbart_layer, config):
 
         # Linear layer 2
         self.linear2_weight = mbart_layer.fc2.weight
-        self.linear2_bias = bart_layer.fc2.bias
+        self.linear2_bias = mbart_layer.fc2.bias
 
         # Layer norm 1
         self.norm1_eps = mbart_layer.self_attn_layer_norm.eps
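The core trick of the conversion above is worth spelling out: concatenating the q/k/v projection weights into a single `in_proj_weight`/`in_proj_bias` lets the fused kernel compute all three projections with one matmul. A minimal standalone sketch (not part of the patch; the dimensions and names below are illustrative only):

import torch
import torch.nn as nn

embed_dim = 8  # illustrative size
q_proj = nn.Linear(embed_dim, embed_dim)
k_proj = nn.Linear(embed_dim, embed_dim)
v_proj = nn.Linear(embed_dim, embed_dim)

# Same construction as in MBartEncoderLayerBetterTransformer.__init__
in_proj_weight = torch.cat([q_proj.weight, k_proj.weight, v_proj.weight])
in_proj_bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias])

x = torch.randn(2, 5, embed_dim)
q, k, v = torch.nn.functional.linear(x, in_proj_weight, in_proj_bias).chunk(3, dim=-1)

# The fused projection reproduces the three separate projections (up to fp32 noise)
assert torch.allclose(q, q_proj(x), atol=1e-6)
assert torch.allclose(k, k_proj(x), atol=1e-6)
assert torch.allclose(v, v_proj(x), atol=1e-6)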
From 478473f4c7950a15a8a6214ec6b17755561224eb Mon Sep 17 00:00:00 2001
From: ravenouse <85110830+ravenouse@users.noreply.github.com>
Date: Sat, 26 Nov 2022 17:01:31 -0700
Subject: [PATCH 3/9] Update optimum/bettertransformer/models/encoder_models.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
---
 optimum/bettertransformer/models/encoder_models.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/optimum/bettertransformer/models/encoder_models.py b/optimum/bettertransformer/models/encoder_models.py
index 11b77c7d59a..67e057b5815 100644
--- a/optimum/bettertransformer/models/encoder_models.py
+++ b/optimum/bettertransformer/models/encoder_models.py
@@ -401,7 +401,7 @@ def __init__(self, mbart_layer, config):
 
         # Last step: set the last layer to `False` -> this will be set to `True` when converting the model
         self.is_last_layer = False
-
+        self.norm_first = True
         self.validate_bettertransformer()
 
     def forward(self, hidden_states, attention_mask, position_bias=None, *_, **__):

From 89274c6829c711d21065febefb56cce7647fcb62 Mon Sep 17 00:00:00 2001
From: ravenouse
Date: Mon, 28 Nov 2022 18:24:14 -0700
Subject: [PATCH 4/9] added test and doc for the mbart implementation

---
 docs/source/bettertransformer/overview.mdx                 | 1 +
 tests/bettertransformer/test_bettertransformer_encoder.py | 1 +
 2 files changed, 2 insertions(+)

diff --git a/docs/source/bettertransformer/overview.mdx b/docs/source/bettertransformer/overview.mdx
index 441075e57ec..d3854a76644 100644
--- a/docs/source/bettertransformer/overview.mdx
+++ b/docs/source/bettertransformer/overview.mdx
@@ -37,6 +37,7 @@ The list of supported model below:
 - [HuBERT](https://arxiv.org/pdf/2106.07447.pdf)
 - [LayoutLM](https://arxiv.org/abs/1912.13318)
 - [MarkupLM](https://arxiv.org/abs/2110.08518)
+- [MBart](https://arxiv.org/abs/2001.08210)
 - [RoBERTa](https://arxiv.org/abs/1907.11692)
 - [Splinter](https://arxiv.org/abs/2101.00438)
 - [ViLT](https://arxiv.org/abs/2102.03334)
diff --git a/tests/bettertransformer/test_bettertransformer_encoder.py b/tests/bettertransformer/test_bettertransformer_encoder.py
index 4fe12b45849..4b3c30fc7ac 100644
--- a/tests/bettertransformer/test_bettertransformer_encoder.py
+++ b/tests/bettertransformer/test_bettertransformer_encoder.py
@@ -44,6 +44,7 @@
     "hf-internal-testing/tiny-random-MarkupLMModel",
     "hf-internal-testing/tiny-random-BertModel",
     "ybelkada/random-tiny-BertGenerationModel",
+    "hf-internal-testing/tiny-random-MBartModel",
 ]
 
 ALL_ENCODER_DECODER_MODELS_TO_TEST = [

From b93a3ea2503ce3dcf7c68de451cabb1ac10c52f6 Mon Sep 17 00:00:00 2001
From: ravenouse
Date: Mon, 28 Nov 2022 18:32:40 -0700
Subject: [PATCH 5/9] fixed the mistake made in the last push

---
 optimum/bettertransformer/models/encoder_models.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/optimum/bettertransformer/models/encoder_models.py b/optimum/bettertransformer/models/encoder_models.py
index 67e057b5815..550f8959bdf 100644
--- a/optimum/bettertransformer/models/encoder_models.py
+++ b/optimum/bettertransformer/models/encoder_models.py
@@ -510,7 +510,7 @@ def __init__(self, bert_layer, config):
 
         # Last step: set the last layer to `False` -> this will be set to `True` when converting the model
         self.is_last_layer = False
-
+        self.norm_first = True
         self.validate_bettertransformer()
 
     def forward(self, x, attn_mask, head_mask=None, output_attentions=None, *_):

From 58377cc53b8f98892fdb482fe7b699699c6c2590 Mon Sep 17 00:00:00 2001
From: ravenouse
Date: Mon, 28 Nov 2022 18:41:15 -0700
Subject: [PATCH 6/9] fixed the mistakes in the previous steps

---
 optimum/bettertransformer/models/encoder_models.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/optimum/bettertransformer/models/encoder_models.py b/optimum/bettertransformer/models/encoder_models.py
index 550f8959bdf..bb3cd4226b4 100644
--- a/optimum/bettertransformer/models/encoder_models.py
+++ b/optimum/bettertransformer/models/encoder_models.py
@@ -510,7 +510,6 @@ def __init__(self, bert_layer, config):
 
         # Last step: set the last layer to `False` -> this will be set to `True` when converting the model
         self.is_last_layer = False
-        self.norm_first = True
         self.validate_bettertransformer()
 
     def forward(self, x, attn_mask, head_mask=None, output_attentions=None, *_):
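For context on the `self.norm_first = True` line introduced in patch 3 (and removed again from the DistilBERT layer in patch 6, where it never belonged): MBART encoder layers apply LayerNorm before attention and feed-forward (pre-LN), while BART and DistilBERT apply it afterwards (post-LN). The flag corresponds to the standard `norm_first` argument of PyTorch's encoder layer; the snippet below is only an illustration with arbitrary sizes:

import torch
import torch.nn as nn

# norm_first=True mirrors MBART's pre-LN ordering: x + attn(norm(x));
# norm_first=False (the default, matching BART/DistilBERT) computes norm(x + attn(x)).
layer = nn.TransformerEncoderLayer(d_model=16, nhead=2, norm_first=True, batch_first=True)
print(layer(torch.randn(2, 4, 16)).shape)  # torch.Size([2, 4, 16])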
From 77b34676caa653da7ddf922717191fad60cea889 Mon Sep 17 00:00:00 2001
From: ravenouse
Date: Mon, 28 Nov 2022 18:58:06 -0700
Subject: [PATCH 7/9] added test and doc for M2M100

---
 docs/source/bettertransformer/overview.mdx                 | 1 +
 optimum/bettertransformer/models/__init__.py               | 1 +
 tests/bettertransformer/test_bettertransformer_encoder.py | 1 +
 3 files changed, 3 insertions(+)

diff --git a/docs/source/bettertransformer/overview.mdx b/docs/source/bettertransformer/overview.mdx
index d3854a76644..228fb812063 100644
--- a/docs/source/bettertransformer/overview.mdx
+++ b/docs/source/bettertransformer/overview.mdx
@@ -38,6 +38,7 @@ The list of supported model below:
 - [LayoutLM](https://arxiv.org/abs/1912.13318)
 - [MarkupLM](https://arxiv.org/abs/2110.08518)
 - [MBart](https://arxiv.org/abs/2001.08210)
+- [M2M100](https://arxiv.org/abs/2010.11125)
 - [RoBERTa](https://arxiv.org/abs/1907.11692)
 - [Splinter](https://arxiv.org/abs/2101.00438)
 - [ViLT](https://arxiv.org/abs/2102.03334)
diff --git a/optimum/bettertransformer/models/__init__.py b/optimum/bettertransformer/models/__init__.py
index 871a32f0e6f..90fd9d0c953 100644
--- a/optimum/bettertransformer/models/__init__.py
+++ b/optimum/bettertransformer/models/__init__.py
@@ -45,6 +45,7 @@
     # Bart family
     "BartEncoderLayer": BartEncoderLayerBetterTransformer,
     "MBartEncoderLayer": MBartEncoderLayerBetterTransformer,
+    "M2M100EncoderLayer":MBartEncoderLayerBetterTransformer,
     # "PLBartEncoderLayer": bart.BartEncoderLayerBetterTransformer,
     # "MarianEncoderLayer": bart.BartEncoderLayerBetterTransformer,
     # "TimeSeriesTransformerEncoderLayer": bart.BartEncoderLayerBetterTransformer,
diff --git a/tests/bettertransformer/test_bettertransformer_encoder.py b/tests/bettertransformer/test_bettertransformer_encoder.py
index 4b3c30fc7ac..1501a80dcdf 100644
--- a/tests/bettertransformer/test_bettertransformer_encoder.py
+++ b/tests/bettertransformer/test_bettertransformer_encoder.py
@@ -45,6 +45,7 @@
     "hf-internal-testing/tiny-random-BertModel",
     "ybelkada/random-tiny-BertGenerationModel",
     "hf-internal-testing/tiny-random-MBartModel",
+    "hf-internal-testing/tiny-random-nllb",
 ]
 
 ALL_ENCODER_DECODER_MODELS_TO_TEST = [

From 3c1b345240553aaf440007e430480395dcdd9607 Mon Sep 17 00:00:00 2001
From: ravenouse <85110830+ravenouse@users.noreply.github.com>
Date: Wed, 30 Nov 2022 02:18:57 +0000
Subject: [PATCH 8/9] moved the test cases/models for the right list

---
 tests/bettertransformer/test_bettertransformer_encoder.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/bettertransformer/test_bettertransformer_encoder.py b/tests/bettertransformer/test_bettertransformer_encoder.py
index 1501a80dcdf..f8f5f024474 100644
--- a/tests/bettertransformer/test_bettertransformer_encoder.py
+++ b/tests/bettertransformer/test_bettertransformer_encoder.py
@@ -44,13 +44,13 @@
     "hf-internal-testing/tiny-random-MarkupLMModel",
     "hf-internal-testing/tiny-random-BertModel",
     "ybelkada/random-tiny-BertGenerationModel",
-    "hf-internal-testing/tiny-random-MBartModel",
-    "hf-internal-testing/tiny-random-nllb",
 ]
 
 ALL_ENCODER_DECODER_MODELS_TO_TEST = [
     "hf-internal-testing/tiny-random-FSMTModel",
     "hf-internal-testing/tiny-random-BartModel",
+    "hf-internal-testing/tiny-random-MBartModel",
+    "hf-internal-testing/tiny-random-nllb",
 ]
 
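A hedged sketch of the kind of equivalence check the test suite performs on the tiny checkpoints listed above; the actual harness lives in `test_bettertransformer_encoder.py`, and the token ids, sequence length and comparison below are illustrative only:

import torch
from transformers import AutoModel
from optimum.bettertransformer import BetterTransformer

model = AutoModel.from_pretrained("hf-internal-testing/tiny-random-MBartModel").eval()
converted = BetterTransformer.transform(model, keep_original_model=True)

input_ids = torch.tensor([[4, 5, 6, 7, 1, 1]])       # last two positions stand in for padding
attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0]])
with torch.no_grad():
    ref = model.get_encoder()(input_ids=input_ids, attention_mask=attention_mask)[0]
    out = converted.get_encoder()(input_ids=input_ids, attention_mask=attention_mask)[0]

# BetterTransformer zero-fills the padded positions, so only compare the attended ones.
print(torch.max(torch.abs(ref[:, :4] - out[:, :4])))  # expected to be tiny (fp32 noise)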
From d0720ccfe45e7401eaf61e107a6ab3f9c489fac5 Mon Sep 17 00:00:00 2001
From: ravenouse <85110830+ravenouse@users.noreply.github.com>
Date: Wed, 30 Nov 2022 07:31:50 +0000
Subject: [PATCH 9/9] reformatted the code

---
 optimum/bettertransformer/models/__init__.py       | 4 ++--
 optimum/bettertransformer/models/encoder_models.py | 2 ++
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/optimum/bettertransformer/models/__init__.py b/optimum/bettertransformer/models/__init__.py
index 90fd9d0c953..0106fae977e 100644
--- a/optimum/bettertransformer/models/__init__.py
+++ b/optimum/bettertransformer/models/__init__.py
@@ -17,9 +17,9 @@
     AlbertLayerBetterTransformer,
     BartEncoderLayerBetterTransformer,
     BertLayerBetterTransformer,
-    MBartEncoderLayerBetterTransformer,
     DistilBertLayerBetterTransformer,
     FSMTEncoderLayerBetterTransformer,
+    MBartEncoderLayerBetterTransformer,
     ViltLayerBetterTransformer,
     ViTLayerBetterTransformer,
     Wav2Vec2EncoderLayerBetterTransformer,
@@ -45,7 +45,7 @@
     # Bart family
     "BartEncoderLayer": BartEncoderLayerBetterTransformer,
     "MBartEncoderLayer": MBartEncoderLayerBetterTransformer,
-    "M2M100EncoderLayer":MBartEncoderLayerBetterTransformer,
+    "M2M100EncoderLayer": MBartEncoderLayerBetterTransformer,
     # "PLBartEncoderLayer": bart.BartEncoderLayerBetterTransformer,
     # "MarianEncoderLayer": bart.BartEncoderLayerBetterTransformer,
     # "TimeSeriesTransformerEncoderLayer": bart.BartEncoderLayerBetterTransformer,
diff --git a/optimum/bettertransformer/models/encoder_models.py b/optimum/bettertransformer/models/encoder_models.py
index bb3cd4226b4..8550d427f14 100644
--- a/optimum/bettertransformer/models/encoder_models.py
+++ b/optimum/bettertransformer/models/encoder_models.py
@@ -343,6 +343,7 @@ def forward(self, hidden_states, attention_mask, position_bias=None, *_, **__):
             hidden_states = hidden_states.to_padded_tensor(0.0)
         return (hidden_states,)
 
+
 class MBartEncoderLayerBetterTransformer(BetterTransformerBaseLayer):
     def __init__(self, mbart_layer, config):
         r"""
@@ -452,6 +453,7 @@ def forward(self, hidden_states, attention_mask, position_bias=None, *_, **__):
             hidden_states = hidden_states.to_padded_tensor(0.0)
         return (hidden_states,)
 
+
 class DistilBertLayerBetterTransformer(BetterTransformerBaseLayer):
     def __init__(self, bert_layer, config):
         r"""
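Since patch 7 maps `"M2M100EncoderLayer"` to the same `MBartEncoderLayerBetterTransformer` class, M2M100-family checkpoints (including NLLB) follow the identical conversion path as MBART. A hedged usage sketch with an example checkpoint name and toy inputs, intended for inference:

import torch
from transformers import AutoModel
from optimum.bettertransformer import BetterTransformer

model = AutoModel.from_pretrained("facebook/m2m100_418M").eval()  # example checkpoint only
model = BetterTransformer.transform(model)

input_ids = torch.tensor([[9, 10, 11, 12, 2, 1]])     # toy token ids; 1 is the padding id
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 0]])
with torch.no_grad():
    encoder_output = model.get_encoder()(input_ids=input_ids, attention_mask=attention_mask)
print(encoder_output.last_hidden_state.shape)          # (1, 6, hidden_size)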