RoPE: model-agnostic RoPE refactor #31999

Open
wants to merge 42 commits into base: main
Changes from 13 commits (of 42 total)

Commits
cc9b82e
Add YaRN and Dynamic-YaRN RoPE Scaling Methods
mig-mfreitas May 19, 2024
fc161dd
Merge remote-tracking branch 'upstream/main' into yarn-rope-scaling
miguelm-almeida Jun 12, 2024
1044c7b
Merge remote-tracking branch 'upstream/main' into yarn-rope-scaling
miguelm-almeida Jun 16, 2024
85552b3
Refactor YaRN implementation for LLaMA
miguelm-almeida Jun 16, 2024
d84baa9
Refactor Tensor Building Logic for YaRN
miguelm-almeida Jul 10, 2024
fdea000
Merge remote-tracking branch 'upstream/main' into yarn-rope-scaling
miguelm-almeida Jul 10, 2024
a555034
remove unwanted file
gante Jul 16, 2024
472b168
tmp commit
gante Jul 16, 2024
26fd6e9
mvp?
gante Jul 16, 2024
6ea2d3c
rm yarn class
gante Jul 16, 2024
9df8a43
can set attention_factor
gante Jul 16, 2024
10dc891
a few optims
gante Jul 16, 2024
e446e64
single rope layer
gante Jul 19, 2024
cc6af77
better config
gante Jul 19, 2024
9914572
push
gante Jul 19, 2024
8befb00
push more logic to the rope fns
gante Jul 19, 2024
20962d8
dynamic can scale back
gante Jul 19, 2024
0595968
position_embeddings last
gante Jul 19, 2024
748a318
rename new rope stuff
gante Jul 19, 2024
99305b4
Merge branch 'main' into rope_refactor
gante Jul 19, 2024
2f10261
make fixup
gante Jul 19, 2024
c34ffff
chameleon
gante Jul 19, 2024
6dae958
cohere
gante Jul 19, 2024
0bcd2c1
fix gated imports
gante Jul 20, 2024
0ec8ddb
missing this one
gante Jul 20, 2024
1e41bfc
gemma (and cousins)
gante Jul 20, 2024
dffad0d
nits
gante Jul 20, 2024
5eb821b
gemma 2
gante Jul 20, 2024
61eaf7c
mistral
gante Jul 20, 2024
c720514
Olmo
gante Jul 20, 2024
d28add5
add longrope
gante Jul 20, 2024
032f662
phi3 (but not fully working)
gante Jul 21, 2024
2cba857
last model D: D: D:
gante Jul 21, 2024
205c740
moe out
gante Jul 21, 2024
5d19465
fix missing attributes in gemma2, recurrentgemma, and olmo
gante Jul 21, 2024
909b247
same rope validation and docstring everywhere
gante Jul 21, 2024
441eabb
fix olmo
gante Jul 21, 2024
b10fee4
fix name clash; enable models with partial rope
gante Jul 21, 2024
961e6ad
cohere config
gante Jul 21, 2024
0604c44
lysandre's PR comments
gante Jul 22, 2024
556e140
fast path for dynamic freq reset
gante Jul 22, 2024
9c5a40e
bc using kwargs (instead of a config)
gante Jul 22, 2024
12 changes: 8 additions & 4 deletions docs/source/en/model_doc/llama.md
@@ -70,15 +70,15 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
- [StackLLaMA: A hands-on guide to train LLaMA with RLHF](https://huggingface.co/blog/stackllama#stackllama-a-hands-on-guide-to-train-llama-with-rlhf), a blog post about how to train LLaMA to answer questions on [Stack Exchange](https://stackexchange.com/) with RLHF.

⚗️ Optimization
- A [notebook](https://colab.research.google.com/drive/1SQUXq1AMZPSLD4mk3A3swUIc6Y2dclme?usp=sharing) on how to fine-tune LLaMA model using xturing library on GPU which has limited memory. 🌎
- A [notebook](https://colab.research.google.com/drive/1SQUXq1AMZPSLD4mk3A3swUIc6Y2dclme?usp=sharing) on how to fine-tune LLaMA model using xturing library on GPU which has limited memory. 🌎

⚡️ Inference
- A [notebook](https://colab.research.google.com/github/DominguesM/alpaca-lora-ptbr-7b/blob/main/notebooks/02%20-%20Evaluate.ipynb) on how to run the LLaMA Model using PeftModel from the 🤗 PEFT library. 🌎
- A [notebook](https://colab.research.google.com/github/DominguesM/alpaca-lora-ptbr-7b/blob/main/notebooks/02%20-%20Evaluate.ipynb) on how to run the LLaMA Model using PeftModel from the 🤗 PEFT library. 🌎
- A [notebook](https://colab.research.google.com/drive/1l2GiSSPbajVyp2Nk3CFT4t3uH6-5TiBe?usp=sharing) on how to load a PEFT adapter LLaMA model with LangChain. 🌎

🚀 Deploy
- A [notebook](https://colab.research.google.com/github/lxe/simple-llama-finetuner/blob/master/Simple_LLaMA_FineTuner.ipynb#scrollTo=3PM_DilAZD8T) on how to fine-tune LLaMA model using LoRA method via the 🤗 PEFT library with intuitive UI. 🌎
- A [notebook](https://github.com/aws/amazon-sagemaker-examples/blob/main/introduction_to_amazon_algorithms/jumpstart-foundation-models/text-generation-open-llama.ipynb) on how to deploy Open-LLaMA model for text generation on Amazon SageMaker. 🌎
- A [notebook](https://colab.research.google.com/github/lxe/simple-llama-finetuner/blob/master/Simple_LLaMA_FineTuner.ipynb#scrollTo=3PM_DilAZD8T) on how to fine-tune LLaMA model using LoRA method via the 🤗 PEFT library with intuitive UI. 🌎
- A [notebook](https://github.com/aws/amazon-sagemaker-examples/blob/main/introduction_to_amazon_algorithms/jumpstart-foundation-models/text-generation-open-llama.ipynb) on how to deploy Open-LLaMA model for text generation on Amazon SageMaker. 🌎

## LlamaConfig

@@ -105,11 +105,15 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h

[[autodoc]] LlamaModel
- forward
- get_rope_embeddings
- set_rope_embeddings

## LlamaForCausalLM

[[autodoc]] LlamaForCausalLM
- forward
- get_rope_embeddings
- set_rope_embeddings

## LlamaForSequenceClassification

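
The two accessors added to the autodoc lists above are new in this PR and their signatures are not shown in this hunk, so the snippet below is only a hypothetical usage sketch; the attribute names and the assumption that `set_rope_embeddings` accepts an inverse-frequency tensor are mine, not the PR's.

```python
# Hypothetical sketch -- signatures assumed, not taken from the diff.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

rope = model.get_rope_embeddings()           # inspect the shared RoPE layer
custom_inv_freq = torch.rand(64)             # placeholder custom inverse frequencies
model.set_rope_embeddings(custom_inv_freq)   # assumed to accept a frequency tensor
```
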
201 changes: 201 additions & 0 deletions src/transformers/modeling_rope_utils.py
@@ -0,0 +1,201 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Any, Dict

import torch


ROPE_CONFIG_DOCSTRING = r"""
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports three scaling
strategies: linear, dynamic and yarn. Their scaling factor must be a float greater than 1. The expected format is
`{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
these scaling strategies behave:
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
experimental feature, subject to breaking API changes in future versions.
For the `yarn` strategy, the dictionary may also contain the following fields:
`original_max_position_embeddings` (`int`, *optional*):
The original maximum sequence length. This is used to scale the RoPE embeddings.
`attention_factor` (`float`, *optional*):
The attention scaling factor. If unspecified, it defaults to `0.1 ln(s) + 1`, where `s` is the
`max_position_embeddings/original_max_position_embeddings` ratio.
`beta_fast` (`float`, *optional*):
Parameter to set the boundary for extrapolation (only) in the linear ramp function.
`beta_slow` (`float`, *optional*):
Parameter to set the boundary for interpolation (only) in the linear ramp function.
"""


def rope_config_validation(rope_scaling):
"""
Validate the `rope_scaling` configuration.
"""
if rope_scaling is None:
return

if not isinstance(rope_scaling, dict) or len(rope_scaling) < 2:
raise ValueError(
"`rope_scaling` must be a dictionary with a minimum of two fields, `type` and `factor`, "
f"got {rope_scaling}"
)
rope_scaling_type = rope_scaling.get("type", None)
rope_scaling_factor = rope_scaling.get("factor", None)
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic", "yarn"]:
raise ValueError(
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic', 'yarn'], got {rope_scaling_type}"
)
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")

if rope_scaling_type != "yarn":
return

if not isinstance(rope_scaling, dict) or len(rope_scaling) > 6:
raise ValueError(
"`rope_scaling` with type "
f"{rope_scaling_type}"
" must be a dictionary with a maximum of six fields, `type`, `factor`,"
"`original_max_position_embeddings`, `attention_factor`, `beta_fast`, `beta_slow`, "
f"got {rope_scaling}"
)
original_max_position_embeddings = rope_scaling.get("original_max_position_embeddings", None)
attention_factor = rope_scaling.get("attention_factor", None)
beta_fast = rope_scaling.get("beta_fast", None)
beta_slow = rope_scaling.get("beta_slow", None)

if original_max_position_embeddings is not None and not isinstance(original_max_position_embeddings, int):
raise ValueError(
f"`rope_scaling`'s original_max_position_embeddings field must be an int, got {original_max_position_embeddings}"
)
if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0):
raise ValueError(
f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
)
if beta_fast is not None and not isinstance(beta_fast, float):
raise ValueError(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}")
if beta_slow is not None and not isinstance(beta_slow, float):
raise ValueError(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}")

b_fast = beta_fast if beta_fast is not None else 32
b_slow = beta_slow if beta_slow is not None else 1
if b_fast < b_slow:
raise ValueError(
f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={b_fast} and beta_slow={b_slow}"
)
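
A minimal sketch of how `rope_config_validation` behaves, assuming this branch is installed so the new module is importable:

```python
from transformers.modeling_rope_utils import rope_config_validation

rope_config_validation(None)                               # no scaling: no-op
rope_config_validation({"type": "linear", "factor": 2.0})  # valid: passes silently

try:
    rope_config_validation({"type": "linear", "factor": 0.5})  # factor must be > 1
except ValueError as err:
    print(err)
```
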


def compute_frequencies(rope_config: Dict[str, Any], device: torch.device) -> torch.Tensor:
rope_type = rope_config.get("rope_type", "default")
if rope_type == "default":
return _compute_default_frequencies(rope_config, device)
elif rope_type == "dynamic":
return _compute_dynamic_ntk_frequencies(rope_config, device)
elif rope_type == "yarn":
return _compute_yarn_frequencies(rope_config, device)
else:
raise ValueError(
f"Unrecognized RoPE type: {rope_type}. If you want to use custom RoPE frequencies, use "
"`model.set_rope_embeddings()`"
)
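
For reference, a small sketch of the dispatcher: `compute_frequencies` reads `rope_type` from the config (falling back to "default") and returns one inverse frequency per pair of head dimensions:

```python
import torch
from transformers.modeling_rope_utils import compute_frequencies

# No "rope_type" key -> the "default" (unscaled) computation is used.
inv_freq = compute_frequencies({"base": 10000.0, "dim": 128}, torch.device("cpu"))
print(inv_freq.shape)  # torch.Size([64]): dim // 2 inverse frequencies
print(inv_freq[0])     # tensor(1.): the highest frequency is always base**0 = 1
```
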


def _compute_default_frequencies(rope_config: Dict[str, Any], device: torch.device) -> torch.Tensor:
# Mandatory config options
required_keys = ["base", "dim"]
for key in required_keys:
if key not in rope_config:
raise ValueError(f"Missing required key '{key}' in RoPE config.")

base = rope_config["base"]
dim = rope_config["dim"]

# Compute the inverse frequencies
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
return inv_freq


def _compute_dynamic_ntk_frequencies(rope_config: Dict[str, Any], device: torch.device) -> torch.Tensor:
"""Credits to the Reddit users /u/bloc97 and /u/emozilla"""
# Mandatory config options
required_keys = ["base", "dim", "scaling_factor", "max_position_embeddings"]
for key in required_keys:
if key not in rope_config:
raise ValueError(f"Missing required key '{key}' in RoPE config for RoPE type = 'dynamic'.")

base = rope_config["base"]
dim = rope_config["dim"]
scaling_factor = rope_config["scaling_factor"]
max_position_embeddings = rope_config["max_position_embeddings"]

# Optional config options
# seq_len: default to max_position_embeddings, e.g. at init time
seq_len = rope_config.get("seq_len") or max_position_embeddings

# Compute the inverse frequencies
base = base * ((scaling_factor * seq_len / max_position_embeddings) - (scaling_factor - 1)) ** (dim / (dim - 2))
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
return inv_freq
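
An illustration (made-up values, CPU) of the dynamic NTK behaviour: at or below `max_position_embeddings` the formula reduces to the default frequencies, and beyond it the effective base grows, so the lowest frequencies shrink:

```python
import torch
from transformers.modeling_rope_utils import _compute_dynamic_ntk_frequencies

config = {"base": 10000.0, "dim": 128, "scaling_factor": 2.0, "max_position_embeddings": 4096}

# seq_len == max_position_embeddings: the scaling term is 1, i.e. the default base.
short = _compute_dynamic_ntk_frequencies({**config, "seq_len": 4096}, torch.device("cpu"))

# seq_len >> max_position_embeddings: the base is rescaled upwards.
long = _compute_dynamic_ntk_frequencies({**config, "seq_len": 16384}, torch.device("cpu"))

print(bool(short[-1] > long[-1]))  # True: the lowest frequency shrinks on long inputs
```
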


def _compute_yarn_frequencies(rope_config: Dict[str, Any], device: torch.device) -> torch.Tensor:
# Mandatory config options
required_keys = ["base", "dim", "scaling_factor", "max_position_embeddings"]
for key in required_keys:
if key not in rope_config:
raise ValueError(f"Missing required key '{key}' in RoPE config for RoPE type = 'dynamic'.")

base = rope_config["base"]
dim = rope_config["dim"]
scaling_factor = rope_config["scaling_factor"]
max_position_embeddings = rope_config["max_position_embeddings"]

# Optional config options
# beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)
beta_fast = rope_config.get("beta_fast") or 32
beta_slow = rope_config.get("beta_slow") or 1

# Compute the inverse frequencies

# Inverse dimension formula to find the dimension based on the number of rotations
def find_correction_dim(num_rotations, dim, base, max_position_embeddings):
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))

# Find dimension range bounds based on rotations
def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings):
low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings))
high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
return max(low, 0), min(high, dim - 1)

def linear_ramp_mask(min, max, dim):
if min == max:
max += 0.001 # Prevent singularity

linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
ramp_func = torch.clamp(linear_func, 0, 1)
return ramp_func

pos_freqs = base ** (torch.arange(0, dim, 2).float().to(device) / dim)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)

low, high = find_correction_range(beta_fast, beta_slow, dim, base, max_position_embeddings)

# Get n-dimensional rotational scaling corrected for extrapolation
inv_freq_mask = 1 - linear_ramp_mask(low, high, dim // 2).float().to(device)
inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask

return inv_freq
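
As a sanity check (illustrative values only): YaRN leaves the high-frequency dimensions extrapolated, i.e. equal to the default frequencies, interpolates the low-frequency dimensions by the scaling factor, and blends the dimensions in between with the linear ramp:

```python
import torch
from transformers.modeling_rope_utils import (
    _compute_default_frequencies,
    _compute_yarn_frequencies,
)

common = {"base": 10000.0, "dim": 128, "max_position_embeddings": 4096}
default = _compute_default_frequencies(common, torch.device("cpu"))
yarn = _compute_yarn_frequencies({**common, "scaling_factor": 4.0}, torch.device("cpu"))

print(torch.allclose(yarn[0], default[0]))          # True: highest frequency untouched
print(torch.allclose(yarn[-1], default[-1] / 4.0))  # True: lowest frequency interpolated
```
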
1 change: 0 additions & 1 deletion src/transformers/models/falcon/modeling_falcon.py
@@ -297,7 +297,6 @@ def __init__(self, config: FalconConfig):
self.attention_dropout = nn.Dropout(config.attention_dropout)
self.num_kv_heads = config.num_kv_heads if (self.new_decoder_architecture or not self.multi_query) else 1

# Copied from transformers.models.llama.modeling_llama.LlamaAttention._init_rope with Llama->Falcon
def _init_rope(self):
if self.config.rope_scaling is None:
self.rotary_emb = FalconRotaryEmbedding(
1 change: 0 additions & 1 deletion src/transformers/models/fuyu/configuration_fuyu.py
@@ -188,7 +188,6 @@ def __init__(
**kwargs,
)

# Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation
def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
1 change: 0 additions & 1 deletion src/transformers/models/gpt_neox/configuration_gpt_neox.py
@@ -154,7 +154,6 @@ def __init__(
"The hidden size is not divisble by the number of attention heads! Make sure to update them!"
)

# Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation
def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
2 changes: 2 additions & 0 deletions src/transformers/models/jamba/modeling_jamba.py
@@ -1812,6 +1812,7 @@ def forward(
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
@@ -1832,6 +1833,7 @@ def forward(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
position_embeddings=position_embeddings,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
2 changes: 2 additions & 0 deletions src/transformers/models/jetmoe/modeling_jetmoe.py
@@ -1520,6 +1520,7 @@ def forward(
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
@@ -1540,6 +1541,7 @@ def forward(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
position_embeddings=position_embeddings,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
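
The new optional `position_embeddings` argument threaded through these forwards is a `(cos, sin)` tuple. The snippet below is only a rough sketch of the intended calling pattern; the way the tuple is produced here (via a model-level rotary embedding module) and the attribute path are assumptions, not something shown in this hunk.

```python
# Rough sketch -- the `rotary_emb` attribute and its call signature are assumed.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("jetmoe/jetmoe-8b")
input_ids = torch.tensor([[1, 2, 3, 4]])
position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0)

# Precompute the (cos, sin) pair once and pass it explicitly (hypothetical path).
cos, sin = model.model.rotary_emb(model.get_input_embeddings()(input_ids), position_ids)
out = model(input_ids, position_embeddings=(cos, sin))
```
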
35 changes: 5 additions & 30 deletions src/transformers/models/llama/configuration_llama.py
@@ -20,14 +20,15 @@
"""LLaMA model configuration"""

from ...configuration_utils import PretrainedConfig
from ...modeling_rope_utils import ROPE_CONFIG_DOCSTRING, rope_config_validation
from ...utils import logging


logger = logging.get_logger(__name__)


class LlamaConfig(PretrainedConfig):
r"""
rf"""
This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the LLaMA-7B.
@@ -83,14 +84,7 @@ class LlamaConfig(PretrainedConfig):
Whether to tie weight embeddings
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
`{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
these scaling strategies behave:
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
experimental feature, subject to breaking API changes in future versions.
{ROPE_CONFIG_DOCSTRING}
attention_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
@@ -158,35 +152,16 @@ def __init__(
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self._rope_scaling_validation()
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.mlp_bias = mlp_bias

rope_config_validation(rope_scaling)

super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)

def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
"""
if self.rope_scaling is None:
return

if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
raise ValueError(
"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, " f"got {self.rope_scaling}"
)
rope_scaling_type = self.rope_scaling.get("type", None)
rope_scaling_factor = self.rope_scaling.get("factor", None)
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
raise ValueError(
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
)
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")