From e3a4bd2bee212a2d0fd9f03b27fe7bfc1debe42d Mon Sep 17 00:00:00 2001
From: Shijie Wu
Date: Wed, 20 Sep 2023 12:51:56 -0400
Subject: [PATCH] add custom RMSNorm to `ALL_LAYERNORM_LAYERS` (#26227)

* add LlamaRMSNorm to ALL_LAYERNORM_LAYERS

* fixup

* add IdeficsRMSNorm to ALL_LAYERNORM_LAYERS and fixup
---
 src/transformers/models/idefics/modeling_idefics.py | 6 +++++-
 src/transformers/models/llama/modeling_llama.py     | 4 ++++
 2 files changed, 9 insertions(+), 1 deletion(-)

diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py
index b52b7d5f93b..847170143c8 100644
--- a/src/transformers/models/idefics/modeling_idefics.py
+++ b/src/transformers/models/idefics/modeling_idefics.py
@@ -31,6 +31,7 @@
 from ...activations import ACT2FN
 from ...modeling_outputs import ModelOutput
 from ...modeling_utils import PretrainedConfig
+from ...pytorch_utils import ALL_LAYERNORM_LAYERS
 from ...utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
@@ -261,7 +262,7 @@ def freeze_model(model, module_exceptions=[]):
     }
     module_exceptions_mapped = [mapping[m] for m in module_exceptions]
     for module in model.modules():
-        if module_exceptions and any([isinstance(module, t) for t in module_exceptions_mapped]):
+        if module_exceptions and any(isinstance(module, t) for t in module_exceptions_mapped):
             module.requires_grad_(True)  # Explicitely setting it to true to avoid any mistakes
         else:
             module.requires_grad_(False)
@@ -496,6 +497,9 @@ def forward(self, hidden_states):
         return self.weight * hidden_states
 
 
+ALL_LAYERNORM_LAYERS.append(IdeficsRMSNorm)
+
+
 # this was adapted from LlamaRotaryEmbedding
 class IdeficsEmbedding(torch.nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index f2fef00f7c1..317a788869e 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -30,6 +30,7 @@
 from ...activations import ACT2FN
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import ALL_LAYERNORM_LAYERS
 from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from .configuration_llama import LlamaConfig
 
@@ -89,6 +90,9 @@ def forward(self, hidden_states):
         return self.weight * hidden_states.to(input_dtype)
 
 
+ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm)
+
+
 class LlamaRotaryEmbedding(nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         super().__init__()
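
Context for the change above: `ALL_LAYERNORM_LAYERS` is the registry the Trainer consults when grouping optimizer parameters, so norm classes on that list have their weights exempted from weight decay. Because `LlamaRMSNorm` and `IdeficsRMSNorm` do not subclass `nn.LayerNorm`, they were previously missed by that check; appending them to the list fixes it. The snippet below is a minimal, self-contained sketch of the mechanism, not the library's actual Trainer code; `DemoRMSNorm` and `build_param_groups` are illustrative names invented for this example.

# A minimal sketch (assumed names, not the library's exact code) of why
# registering a custom norm class in ALL_LAYERNORM_LAYERS matters: parameter
# grouping walks the module tree and puts parameters owned by any listed norm
# class into the zero-weight-decay group.
import torch
import torch.nn as nn

ALL_LAYERNORM_LAYERS = [nn.LayerNorm]  # the library populates this; models append to it


class DemoRMSNorm(nn.Module):
    """Toy RMSNorm standing in for LlamaRMSNorm / IdeficsRMSNorm (hypothetical)."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        return self.weight * hidden_states * torch.rsqrt(variance + self.variance_epsilon)


ALL_LAYERNORM_LAYERS.append(DemoRMSNorm)  # mirrors the change made in this patch


def build_param_groups(model, weight_decay=0.01):
    # Collect the full names of parameters owned by any registered norm layer.
    no_decay_names = {
        f"{module_name}.{param_name}" if module_name else param_name
        for module_name, module in model.named_modules()
        if isinstance(module, tuple(ALL_LAYERNORM_LAYERS))
        for param_name, _ in module.named_parameters()
    }
    decay_params = [p for n, p in model.named_parameters() if n not in no_decay_names]
    no_decay_params = [p for n, p in model.named_parameters() if n in no_decay_names]
    return [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]


if __name__ == "__main__":
    model = nn.Sequential(nn.Linear(8, 8), DemoRMSNorm(8))
    groups = build_param_groups(model)
    # The RMSNorm weight lands in the zero-weight-decay group only because its
    # class was appended to ALL_LAYERNORM_LAYERS above.
    print([len(g["params"]) for g in groups])  # -> [2, 1]
    optimizer = torch.optim.AdamW(groups, lr=1e-4)

Without the `append` call, `DemoRMSNorm.weight` would fall into the decayed group, which is exactly the behavior this patch corrects for the Llama and Idefics RMSNorm implementations.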