Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add custom RMSNorm to ALL_LAYERNORM_LAYERS #26227

Merged
merged 3 commits into from
Sep 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/transformers/models/idefics/modeling_idefics.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from ...activations import ACT2FN
from ...modeling_outputs import ModelOutput
from ...modeling_utils import PretrainedConfig
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
Expand Down Expand Up @@ -261,7 +262,7 @@ def freeze_model(model, module_exceptions=[]):
}
module_exceptions_mapped = [mapping[m] for m in module_exceptions]
for module in model.modules():
if module_exceptions and any([isinstance(module, t) for t in module_exceptions_mapped]):
if module_exceptions and any(isinstance(module, t) for t in module_exceptions_mapped):
module.requires_grad_(True) # Explicitely setting it to true to avoid any mistakes
else:
module.requires_grad_(False)
Expand Down Expand Up @@ -496,6 +497,9 @@ def forward(self, hidden_states):
return self.weight * hidden_states


ALL_LAYERNORM_LAYERS.append(IdeficsRMSNorm)


# this was adapted from LlamaRotaryEmbedding
class IdeficsEmbedding(torch.nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
Expand Down
4 changes: 4 additions & 0 deletions src/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_llama import LlamaConfig

Expand Down Expand Up @@ -89,6 +90,9 @@ def forward(self, hidden_states):
return self.weight * hidden_states.to(input_dtype)


ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm)


class LlamaRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/persimmon/modeling_persimmon.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]


# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Persimmon
class PersimmonRotaryEmbedding(torch.nn.Module):
class PersimmonRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()

Expand Down