Falcon: Add RoPE scaling (huggingface#25878)
gante authored and EduardoPach committed Nov 18, 2023
1 parent ddc306d commit b442c9f
Showing 6 changed files with 193 additions and 23 deletions.
@@ -154,14 +154,14 @@ def _rope_scaling_validation(self):

if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
raise ValueError(
"`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, "
"`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
f"got {self.rope_scaling}"
)
rope_scaling_type = self.rope_scaling.get("type", None)
rope_scaling_factor = self.rope_scaling.get("factor", None)
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
raise ValueError(
f"`rope_scaling`'s name field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
)
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}")
44 changes: 44 additions & 0 deletions src/transformers/models/falcon/configuration_falcon.py
@@ -72,6 +72,19 @@ class FalconConfig(PretrainedConfig):
instead, as in the original Transformer architecture. Ignored when `new_decoder_architecture` is `True`.
bias (`bool`, *optional*, defaults to `False`):
Whether to use bias on Linear layers.
max_position_embeddings (`int`, *optional*, defaults to 2048):
The maximum sequence length that this model might ever be used with, when `alibi` is `False`. Pretrained
Falcon models with RoPE support up to 2048 tokens.
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format
is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
these scaling strategies behave:
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
experimental feature, subject to breaking API changes in future versions.
bos_token_id (`int`, *optional*, defaults to 11):
The id of the "beginning-of-sequence" token.
eos_token_id (`int`, *optional*, defaults to 11):
@@ -111,6 +124,9 @@ def __init__(
multi_query=True,
parallel_attn=True,
bias=False,
max_position_embeddings=2048,
rope_theta=10000.0,
rope_scaling=None,
bos_token_id=11,
eos_token_id=11,
**kwargs,
@@ -135,6 +151,10 @@ def __init__(
self.multi_query = multi_query # Ignored when new_decoder_architecture is True
self.parallel_attn = parallel_attn
self.bias = bias
self.max_position_embeddings = max_position_embeddings
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self._rope_scaling_validation()

super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)

@@ -145,3 +165,27 @@ def head_dim(self):
@property
def rotary(self):
return not self.alibi

def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
"""
if self.rope_scaling is None:
return

if self.alibi:
raise ValueError("`rope_scaling` is not supported when `alibi` is `True`.")

if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
raise ValueError(
"`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
f"got {self.rope_scaling}"
)
rope_scaling_type = self.rope_scaling.get("type", None)
rope_scaling_factor = self.rope_scaling.get("factor", None)
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
raise ValueError(
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
)
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}")
125 changes: 109 additions & 16 deletions src/transformers/models/falcon/modeling_falcon.py
@@ -71,32 +71,36 @@ class FalconRotaryEmbedding(nn.Module):
n_heads_per_partition, seq_len, head_dim]` (e.g. MinGPTAttention format).
"""

def __init__(self, head_dim: int, base=10000):
def __init__(self, head_dim: int, base=10000, max_position_embeddings=2048):
super().__init__()
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
self.base = base
self.max_position_embeddings = max_position_embeddings
inv_freq = 1.0 / (self.base ** (torch.arange(0, head_dim, 2).float() / head_dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.head_dim = head_dim
self.seq_len_cached = -1
self.cos_cached: torch.Tensor | None = None
self.sin_cached: torch.Tensor | None = None

def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype=torch.bfloat16) -> torch.Tensor:
total_length = seq_len + past_key_values_length
if total_length > self.seq_len_cached:
self.seq_len_cached = total_length
t = torch.arange(total_length, device=device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1).to(device)
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.seq_len_cached = seq_len
t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1).to(device)

if dtype in [torch.float16, torch.bfloat16]:
emb = emb.float()
if dtype in [torch.float16, torch.bfloat16]:
emb = emb.float()

self.cos_cached = emb.cos()[None, :, :]
self.sin_cached = emb.sin()[None, :, :]
self.cos_cached = emb.cos()[None, :, :]
self.sin_cached = emb.sin()[None, :, :]

self.cos_cached = self.cos_cached.type(dtype)
self.sin_cached = self.sin_cached.type(dtype)
self.cos_cached = self.cos_cached.type(dtype)
self.sin_cached = self.sin_cached.type(dtype)

def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype=torch.bfloat16) -> torch.Tensor:
total_length = seq_len + past_key_values_length
if total_length > self.seq_len_cached:
self._set_cos_sin_cache(total_length, device, dtype)
return (
self.cos_cached[:, past_key_values_length : seq_len + past_key_values_length],
self.sin_cached[:, past_key_values_length : seq_len + past_key_values_length],
@@ -108,6 +112,66 @@ def forward(self, query, key, past_key_values_length=0):
return (query * cos) + (rotate_half(query) * sin), (key * cos) + (rotate_half(key) * sin)


class FalconLinearScalingRotaryEmbedding(FalconRotaryEmbedding):
"""FalconRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

def __init__(self, head_dim: int, base=10000, max_position_embeddings=2048, scaling_factor=1.0):
self.scaling_factor = scaling_factor
super().__init__(head_dim, base, max_position_embeddings)

def _set_cos_sin_cache(self, seq_len, device, dtype):
self.seq_len_cached = seq_len
t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
# This line is the only difference from FalconRotaryEmbedding._set_cos_sin_cache
t = t / self.scaling_factor

freqs = torch.einsum("i,j->ij", t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1).to(device)

if dtype in [torch.float16, torch.bfloat16]:
emb = emb.float()

self.cos_cached = emb.cos()[None, :, :]
self.sin_cached = emb.sin()[None, :, :]

self.cos_cached = self.cos_cached.type(dtype)
self.sin_cached = self.sin_cached.type(dtype)


class FalconDynamicNTKScalingRotaryEmbedding(FalconRotaryEmbedding):
"""
FalconRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla
"""

def __init__(self, head_dim: int, base=10000, max_position_embeddings=2048, scaling_factor=1.0):
self.scaling_factor = scaling_factor
super().__init__(head_dim, base, max_position_embeddings)

def _set_cos_sin_cache(self, seq_len, device, dtype):
self.seq_len_cached = seq_len

# This if block is the only difference from FalconRotaryEmbedding._set_cos_sin_cache
if seq_len > self.max_position_embeddings:
base = self.base * (
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
) ** (self.head_dim / (self.head_dim - 2))
inv_freq = 1.0 / (base ** (torch.arange(0, self.head_dim, 2).float().to(device) / self.head_dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)

t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1).to(device)

if dtype in [torch.float16, torch.bfloat16]:
emb = emb.float()

self.cos_cached = emb.cos()[None, :, :]
self.sin_cached = emb.sin()[None, :, :]

self.cos_cached = self.cos_cached.type(dtype)
self.sin_cached = self.sin_cached.type(dtype)


def _make_causal_mask(
input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
) -> torch.BoolTensor:
@@ -191,6 +255,7 @@ class FalconAttention(nn.Module):
def __init__(self, config: FalconConfig):
super().__init__()

self.config = config
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
@@ -203,7 +268,7 @@ def __init__(self, config: FalconConfig):
f" {self.num_heads})."
)

self.maybe_rotary = FalconRotaryEmbedding(config.head_dim) if config.rotary else lambda q, k, t: (q, k)
self.maybe_rotary = self._init_rope() if config.rotary else lambda q, k, t: (q, k)

# Layer-wise attention scaling
self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
@@ -221,6 +286,34 @@ def __init__(self, config: FalconConfig):
self.attention_dropout = nn.Dropout(config.attention_dropout)
self.num_kv_heads = config.num_kv_heads if (self.new_decoder_architecture or not self.multi_query) else 1

def _init_rope(self):
if self.config.rope_scaling is None:
rotary_emb = FalconRotaryEmbedding(
self.head_dim,
base=self.config.rope_theta,
max_position_embeddings=self.config.max_position_embeddings,
)
else:
scaling_type = self.config.rope_scaling["type"]
scaling_factor = self.config.rope_scaling["factor"]
if scaling_type == "linear":
rotary_emb = FalconLinearScalingRotaryEmbedding(
self.head_dim,
base=self.config.rope_theta,
max_position_embeddings=self.config.max_position_embeddings,
scaling_factor=scaling_factor,
)
elif scaling_type == "dynamic":
rotary_emb = FalconDynamicNTKScalingRotaryEmbedding(
self.head_dim,
base=self.config.rope_theta,
max_position_embeddings=self.config.max_position_embeddings,
scaling_factor=scaling_factor,
)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
return rotary_emb

def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Split the last dimension into (num_heads, head_dim), results share same memory storage as `fused_qkv`
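
For reference, a standalone sketch of what the two scaling rules above do to the rotary frequencies and position grid; the constants are illustrative and the snippet is not part of the commit:

import torch

head_dim, base, max_pos, factor, seq_len = 64, 10000.0, 2048, 4.0, 8192

# Linear scaling (FalconLinearScalingRotaryEmbedding): keep the inverse frequencies,
# stretch the position grid by the factor.
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
t_linear = torch.arange(seq_len).float() / factor  # 0.0, 0.25, 0.5, ...

# Dynamic NTK scaling (FalconDynamicNTKScalingRotaryEmbedding): keep integer positions and
# enlarge the base once seq_len exceeds max_position_embeddings, which lowers the rotation
# frequencies of the higher dimensions.
ntk_base = base * ((factor * seq_len / max_pos) - (factor - 1)) ** (head_dim / (head_dim - 2))
inv_freq_dynamic = 1.0 / (ntk_base ** (torch.arange(0, head_dim, 2).float() / head_dim))
t_dynamic = torch.arange(seq_len).float()

print(base, ntk_base)  # the dynamic base grows with the requested sequence length
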
4 changes: 2 additions & 2 deletions src/transformers/models/gpt_neox/configuration_gpt_neox.py
@@ -163,14 +163,14 @@ def _rope_scaling_validation(self):

if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
raise ValueError(
"`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, "
"`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
f"got {self.rope_scaling}"
)
rope_scaling_type = self.rope_scaling.get("type", None)
rope_scaling_factor = self.rope_scaling.get("factor", None)
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
raise ValueError(
f"`rope_scaling`'s name field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
)
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}")
4 changes: 2 additions & 2 deletions src/transformers/models/llama/configuration_llama.py
@@ -165,14 +165,14 @@ def _rope_scaling_validation(self):

if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
raise ValueError(
"`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, "
"`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
f"got {self.rope_scaling}"
)
rope_scaling_type = self.rope_scaling.get("type", None)
rope_scaling_factor = self.rope_scaling.get("factor", None)
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
raise ValueError(
f"`rope_scaling`'s name field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
)
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}")
35 changes: 34 additions & 1 deletion tests/models/falcon/test_modeling_falcon.py
@@ -17,7 +17,9 @@

import unittest

from transformers import AutoTokenizer, FalconConfig, is_torch_available
from parameterized import parameterized

from transformers import AutoTokenizer, FalconConfig, is_torch_available, set_seed
from transformers.testing_utils import require_torch, slow, torch_device

from ...generation.test_utils import GenerationTesterMixin
@@ -410,6 +412,37 @@ def test_past_key_values_format(self):
past_kv[i][1].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim)
)

@parameterized.expand([("linear",), ("dynamic",)])
def test_model_rope_scaling(self, scaling_type):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
short_input = ids_tensor([1, 10], config.vocab_size)
long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size)

set_seed(42) # Fixed seed at init time so the two models get the same random weights
original_model = FalconModel(config)
original_model.to(torch_device)
original_model.eval()
original_short_output = original_model(short_input).last_hidden_state
original_long_output = original_model(long_input).last_hidden_state

set_seed(42) # Fixed seed at init time so the two models get the same random weights
config.rope_scaling = {"type": scaling_type, "factor": 10.0}
scaled_model = FalconModel(config)
scaled_model.to(torch_device)
scaled_model.eval()
scaled_short_output = scaled_model(short_input).last_hidden_state
scaled_long_output = scaled_model(long_input).last_hidden_state

# Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original
# maximum sequence length, so the outputs for the short input should match.
if scaling_type == "dynamic":
self.assertTrue(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
else:
self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))

# The output should be different for long inputs
self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))


@require_torch
class FalconLanguageGenerationTest(unittest.TestCase):
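
The new option can also be exercised end to end when loading a pretrained checkpoint, since config overrides passed to `from_pretrained` are forwarded to `FalconConfig`. A sketch, assuming network access and enough memory for the checkpoint; the model id, factor, and prompt are illustrative:

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "tiiuae/falcon-7b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    rope_scaling={"type": "dynamic", "factor": 2.0},  # overrides the value stored in config.json
)

inputs = tokenizer("Rotary position embeddings can be stretched by", return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output[0], skip_special_tokens=True))
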
