diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index c6679fa2f29428..af81be33922e61 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1286,6 +1286,7 @@ ] ) _import_structure["modeling_outputs"] = [] + _import_structure["modeling_rope_utils"] = [] _import_structure["modeling_utils"] = ["PreTrainedModel"] # PyTorch models structure diff --git a/src/transformers/modeling_rope_utils.py b/src/transformers/modeling_rope_utils.py index 3853d8d2c8f34d..f9f658c3ac2800 100644 --- a/src/transformers/modeling_rope_utils.py +++ b/src/transformers/modeling_rope_utils.py @@ -13,17 +13,19 @@ # limitations under the License. import math -from typing import Any, Dict, Set +from typing import Any, Dict, Optional import torch +from .configuration_utils import PretrainedConfig + ROPE_CONFIG_DOCSTRING = r""" rope_scaling (`Dict`, *optional*): Dictionary containing the scaling configuration for the RoPE embeddings. When using this flag, don't update `max_position_embeddings` to the expected new maximum. Expected contents: `type` (`str`): - The scaling strategy to use. Can be one of ['linear', 'dynamic', 'yarn']. + The scaling strategy to use. Can be one of ['linear', 'dynamic', 'yarn', 'llama3']. `factor` (`float`): The scaling factor to apply to the RoPE embeddings. Must be a float greater than 1. `attention_factor` (`float`, *optional*): @@ -38,108 +40,35 @@ """ -def rope_config_validation(rope_scaling): - """ - Validate the `rope_scaling` configuration. - """ - if rope_scaling is None: - return - - if not isinstance(rope_scaling, dict) or len(rope_scaling) < 2: - raise ValueError( - "`rope_scaling` must be a dictionary with a minimum of two fields, `type` and `factor`, " - f"got {rope_scaling}" - ) - rope_scaling_type = rope_scaling.get("type", None) - rope_scaling_factor = rope_scaling.get("factor", None) - if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic", "yarn"]: - raise ValueError( - f"`rope_scaling`'s type field must be one of ['linear', 'dynamic', 'yarn'], got {rope_scaling_type}" - ) - if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: - raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") - - if rope_scaling_type != "yarn": - return - - if not isinstance(rope_scaling, dict) or len(rope_scaling) > 6: - raise ValueError( - "`rope_scaling` with type " - f"{rope_scaling_type}" - " must be a dictionary with a maximum of six fields, `type`, `factor`," - "`original_max_position_embeddings`, `attention_factor`, `beta_fast`, `beta_slow`, " - f"got {rope_scaling}" - ) - original_max_position_embeddings = rope_scaling.get("original_max_position_embeddings", None) - attention_factor = rope_scaling.get("attention_factor", None) - beta_fast = rope_scaling.get("beta_fast", None) - beta_slow = rope_scaling.get("beta_slow", None) - - if original_max_position_embeddings is not None and not isinstance(original_max_position_embeddings, int): - raise ValueError( - "`rope_scaling`'s original_max_position_embeddings field must be an int, got " - f"{original_max_position_embeddings}" - ) - if attention_factor is not None and not isinstance(attention_factor, float) or attention_factor < 0: - raise ValueError( - f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}" - ) - if beta_fast is not None and not isinstance(beta_fast, float): - raise ValueError(f"`rope_scaling`'s beta_fast field must be a 
float, got {beta_fast}") - if beta_slow is not None and not isinstance(beta_slow, float): - raise ValueError(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}") - - b_fast = beta_fast if beta_fast is not None else 32 - b_slow = beta_slow if beta_slow is not None else 1 - if b_fast < b_slow: - raise ValueError( - f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={b_fast} and " - f"beta_slow={b_slow}" - ) - - -def _check_rope_config_keys(rope_config: Dict[str, Any], required_keys: Set, permitted_keys: Set): - """Check if the keys in the RoPE config are valid""" - keys_in_rope_config = set(rope_config.keys()) - required_keys_not_in_config = required_keys - keys_in_rope_config - if len(required_keys_not_in_config) > 0: - raise ValueError( - f"Missing required keys '{required_keys_not_in_config}' in the (internally prepared) RoPE config." - ) - all_permitted_keys = permitted_keys + required_keys - keys_not_permitted = keys_in_rope_config - all_permitted_keys - if len(keys_not_permitted) > 0: - raise ValueError(f"Unrecognized keys '{keys_not_permitted}' in the (internally prepared) RoPE config.") - - -def _compute_default_frequencies(rope_config: Dict[str, Any], device: torch.device) -> torch.Tensor: +def _compute_default_frequencies( + config: PretrainedConfig, device: torch.device, seq_len: Optional[int] +) -> torch.Tensor: """Computes the inverse frequencies according to the original RoPE implementation""" - required_keys = {"base", "dim"} - permitted_keys = {"type", "max_position_embeddings"} - _check_rope_config_keys(rope_config, required_keys, permitted_keys) - - base = rope_config["base"] - dim = rope_config["dim"] + base = config.rope_theta + if hasattr(config, "head_dim"): # TODO (joao): BC -- remove `if` in v4.45, keep `else` + dim = config.head_dim + else: + dim = config.hidden_size // config.num_attention_heads # Compute the inverse frequencies inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim)) return inv_freq -def _compute_dynamic_ntk_frequencies(rope_config: Dict[str, Any], device: torch.device) -> torch.Tensor: - """Computes he inverse frequencies with NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" - required_keys = {"base", "dim", "scaling_factor", "max_position_embeddings"} - permitted_keys = {"type"} - _check_rope_config_keys(rope_config, required_keys, permitted_keys) - - base = rope_config["base"] - dim = rope_config["dim"] - scaling_factor = rope_config["scaling_factor"] - max_position_embeddings = rope_config["max_position_embeddings"] +def _compute_dynamic_ntk_frequencies( + config: PretrainedConfig, device: torch.device, seq_len: Optional[int] +) -> torch.Tensor: + """Computes the inverse frequencies with NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" + base = config.rope_theta + if hasattr(config, "head_dim"): # TODO (joao): BC -- remove `if` in v4.45, keep `else` + dim = config.head_dim + else: + dim = config.hidden_size // config.num_attention_heads + scaling_factor = config.rope_scaling["factor"] + max_position_embeddings = config.max_position_embeddings - # Optional config options # seq_len: default to max_position_embeddings, e.g. 
at init time - seq_len = rope_config.get("seq_len") or max_position_embeddings + seq_len = seq_len if seq_len is not None else max_position_embeddings # Compute the inverse frequencies base = base * ((scaling_factor * seq_len / max_position_embeddings) - (scaling_factor - 1)) ** (dim / (dim - 2)) @@ -147,33 +76,31 @@ def _compute_dynamic_ntk_frequencies(rope_config: Dict[str, Any], device: torch. return inv_freq -def _compute_yarn_frequencies(rope_config: Dict[str, Any], device: torch.device) -> torch.Tensor: +def _compute_yarn_frequencies(config: PretrainedConfig, device: torch.device, seq_len: Optional[int]) -> torch.Tensor: """ - Computes he inverse frequencies with NTK scaling. Please refer to the + Computes the inverse frequencies with NTK scaling. Please refer to the [original paper](https://arxiv.org/abs/2309.00071) """ - required_keys = {"base", "dim", "scaling_factor", "max_position_embeddings"} - permitted_keys = {"type", "beta_fast", "beta_slow", "attention_factor"} - _check_rope_config_keys(rope_config, required_keys, permitted_keys) - - base = rope_config["base"] - dim = rope_config["dim"] - scaling_factor = rope_config["scaling_factor"] - max_position_embeddings = rope_config["max_position_embeddings"] + base = config.rope_theta + if hasattr(config, "head_dim"): # TODO (joao): BC -- remove `if` in v4.45, keep `else` + dim = config.head_dim + else: + dim = config.hidden_size // config.num_attention_heads + scaling_factor = config.rope_scaling["factor"] + max_position_embeddings = config.max_position_embeddings # Optional config options # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly) - beta_fast = rope_config.get("beta_fast") or 32 - beta_slow = rope_config.get("beta_slow") or 1 + beta_fast = config.rope_scaling.get("beta_fast") or 32 + beta_slow = config.rope_scaling.get("beta_slow") or 1 # Compute the inverse frequencies - - # Inverse dimension formula to find the dimension based on the number of rotations def find_correction_dim(num_rotations, dim, base, max_position_embeddings): + """Inverse dimension formula to find the dimension based on the number of rotations""" return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) - # Find dimension range bounds based on rotations def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings): + """Find dimension range bounds based on rotations""" low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings)) high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings)) return max(low, 0), min(high, dim - 1) @@ -203,19 +130,77 @@ def linear_ramp_mask(min, max, dim): # new RoPE types ROPE_TYPE_TO_FUNCTION = { "default": _compute_default_frequencies, + "linear": _compute_default_frequencies, # linear is the same as default, scaling is applied in `position_ids` "dynamic": _compute_dynamic_ntk_frequencies, "yarn": _compute_yarn_frequencies, } -def compute_frequencies(rope_config: Dict[str, Any], device: torch.device) -> torch.Tensor: - rope_type = rope_config.get("type", "default") +def compute_frequencies(config: PretrainedConfig, device: torch.device, seq_len: Optional[int] = None) -> torch.Tensor: + """ + Computes RoPE's inverse frequencies, given the model config. Depending on the parameterization, different + RoPE initialization or scaling strategies are used. 
+ """ + rope_type = config.rope_scaling["type"] if config.rope_scaling is not None else "default" rope_fn = ROPE_TYPE_TO_FUNCTION.get(rope_type) if rope_fn is None: raise ValueError( f"Unrecognized RoPE type: {rope_type}.\n\nIf you want to use custom RoPE frequencies, there are two " - "options: 1: Compute RoPE (cos, sin) externally, passing it through `position_embeddings` to the model's " - "forward method. 2: Update the inverse frequencies in RoPE, updating `ROPE_TYPE_TO_FUNCTION` with " - "{'your_rope_type': Callable[rope_config, device] -> torch.Tensor}." + "options:\n- 1 Compute RoPE (cos, sin) externally, passing it through `position_embeddings` to the model's " + "forward method\n- 2: Update the inverse frequencies in RoPE, updating `ROPE_TYPE_TO_FUNCTION` with " + "{'your_rope_type': your_callable}. your_callable should take `config`, `device`, and `seq_len` and " + "return the inverse frequencies (tensor)." ) - return rope_fn(rope_config, device) + return rope_fn(config, device, seq_len) + + +def rope_config_validation(rope_scaling: Optional[Dict[str, Any]]): + """ + Validate the `rope_scaling` config argument. + """ + if rope_scaling is None: + return + + required_keys = {"type", "factor"} + received_keys = set(rope_scaling.keys()) + + missing_keys = required_keys - received_keys + if missing_keys: + raise ValueError(f"Missing required keys in `rope_scaling`: {missing_keys}") + + rope_type = rope_scaling["type"] + possible_rope_types = set(ROPE_TYPE_TO_FUNCTION.keys()) + if rope_type is None or rope_type not in possible_rope_types: + raise ValueError(f"`rope_scaling`'s 'type' field must be one of {possible_rope_types}, got {rope_type}") + + scaling_factor = rope_scaling["factor"] + if scaling_factor is None or not isinstance(scaling_factor, float) or scaling_factor < 1.0: + raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {scaling_factor}") + + if rope_type in ("linear", "dynamic", "llama3"): + unused_keys = received_keys - received_keys + if unused_keys: + raise ValueError(f"Unrecognized keys in `rope_scaling` for 'type'='{rope_type}': {unused_keys}") + else: # yarn + optional_keys = {"attention_factor", "beta_fast", "beta_slow"} + unused_keys = received_keys - required_keys - optional_keys + if unused_keys: + raise ValueError(f"Unrecognized keys in `rope_scaling` for 'type'='yarn': {unused_keys}") + + attention_factor = rope_scaling.get("attention_factor") + if attention_factor is not None and not isinstance(attention_factor, float) or attention_factor < 0: + raise ValueError( + f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}" + ) + beta_fast = rope_scaling.get("beta_fast") + if beta_fast is not None and not isinstance(beta_fast, float): + raise ValueError(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}") + beta_slow = rope_scaling.get("beta_slow") + if beta_slow is not None and not isinstance(beta_slow, float): + raise ValueError(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}") + + if (beta_fast or 32) < (beta_slow or 1): + raise ValueError( + f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={beta_fast} " + f"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)" + ) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index ab628992202689..54a139bae970f0 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ 
b/src/transformers/models/llama/modeling_llama.py @@ -93,47 +93,66 @@ def forward(self, hidden_states): class LlamaRotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0, **kwargs): + def __init__( + self, + dim=None, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + config: Optional[LlamaConfig] = None, + **kwargs, + ): super().__init__() - self.rope_config = { - "base": base, - "dim": dim, - "max_position_embeddings": max_position_embeddings, - "scaling_factor": scaling_factor, - } - self.rope_config.update(kwargs) - # BC: in absence of "rope_type" in kwargs, set it to "default" - self.rope_config["rope_type"] = self.rope_config.get("rope_type", "default") - - inv_freq = compute_frequencies(self.rope_config, device) + # TODO (joao): remove this `if` in v4.45; the legacy args rebuild a config to power the rest of the class; + if config is None: + logger.warning_once( + "`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the " + "`config` argument. All other arguments will be deprecated in v4.45" + ) + config = LlamaConfig(**kwargs) + config.rope_theta = base + config.max_position_embeddings = max_position_embeddings + config.head_dim = dim # this one doesn't actually exist, will only be used in the deprecation transition + if scaling_factor == 1.0 and len(kwargs) == 0: + config.rope_scaling = None + else: + config.rope_scaling = {"type": "default", "factor": scaling_factor} + config.rope_scaling |= kwargs # may overwrite "type" + + self.config = config + self.rope_type = config.rope_scaling["type"] if config.rope_scaling is not None else "default" + self.scaling_factor = config.rope_scaling["factor"] if config.rope_scaling is not None else 1.0 + + inv_freq = compute_frequencies(config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) - self.max_seq_len_cached = max_position_embeddings + self.max_seq_len_cached = config.max_position_embeddings - # BC: dynamic NTK and yarn used `scaling_factor` as a frequency parameter, not for position_ids scaling - if "dynamic" in self.rope_config["rope_type"] or "yarn" in self.rope_config["rope_type"]: - self.rope_config["scaling_factor"] = 1.0 # Special case: on yarn, `attention_factor` has a default suggested by the paper - if "yarn" in self.rope_config["rope_type"]: - self.rope_config["attention_factor"] = self.rope_config.get( - "attention_factor", 0.1 * math.log(scaling_factor) + 1.0 - ) + if "yarn" in self.rope_type: + attention_scale_default = 0.1 * math.log(self.scaling_factor) + 1.0 + self.attention_scaling = config.rope_scaling.get("attention_factor", attention_scale_default) + else: + self.attention_scaling = 1.0 + # BC: dynamic NTK and yarn used `scaling_factor` as a frequency parameter, not for position_ids scaling + if "dynamic" in self.rope_type or "yarn" in self.rope_type: + self.scaling_factor = 1.0 def dynamic_frequency_update(self, position_ids, device): """dynamic RoPE layers need to recompute `inv_freq` when going beyond the original maximum sequence length""" seq_len = torch.max(position_ids) + 1 if seq_len > self.max_seq_len_cached: - inv_freq = compute_frequencies(self.rope_config | {"seq_len": seq_len}, device) + inv_freq = compute_frequencies(self.config, device, seq_len=seq_len) self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.max_seq_len_cached = seq_len @torch.no_grad() def forward(self, x, position_ids): 
- if "dynamic" in self.rope_config["rope_type"]: + if "dynamic" in self.rope_type: self.dynamic_frequency_update(position_ids, device=x.device) # Core RoPE block - if self.rope_config["scaling_factor"] != 1.0: - position_ids = position_ids.float() / self.rope_config["scaling_factor"] + position_ids = position_ids.float() / self.scaling_factor inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() # Force float32 (see https://github.com/huggingface/transformers/pull/29285) @@ -146,9 +165,8 @@ def forward(self, x, position_ids): sin = emb.sin() # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention - if self.rope_config.get("attention_factor") is not None: - cos = cos * self.rope_config["attention_factor"] - sin = sin * self.rope_config["attention_factor"] + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) @@ -156,27 +174,25 @@ def forward(self, x, position_ids): class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" - def __init__(self, **kwargs): + def __init__(self, *args, **kwargs): logger.warning_once( "`LlamaLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.45. Please use " - "`LlamaRotaryEmbedding`, which now also does linear scaling (pass in the same kwargs plus " - "`rope_type='default'`)." + "`LlamaRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)." ) - kwargs["rope_type"] = "default" - super().__init__(**kwargs) + super().__init__(*args, **kwargs) class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" - def __init__(self, **kwargs): + def __init__(self, *args, **kwargs): logger.warning_once( "`LlamaDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.45. Please use " - "`LlamaRotaryEmbedding`, which now also does dynamic ntk scaling (pass in the same kwargs plus " - "`rope_type='dynamic'`)." + "`LlamaRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to " + "__init__)." 
) - kwargs["rope_type"] = "dynamic" - super().__init__(**kwargs) + kwargs["type"] = "dynamic" + super().__init__(*args, **kwargs) def rotate_half(x): @@ -295,13 +311,7 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) # TODO (joao): remove in v4.45 (RoPE is computed in the model, not in the decoder layers) - rope_config = self.config.get("rope_scaling", {"type": "default"}).copy() - rope_config |= { - "dim": self.head_dim, - "max_position_embeddings": self.max_position_embeddings, - "base": self.rope_theta, - } - self.rotary_emb = LlamaRotaryEmbedding(**rope_config) + self.rotary_emb = LlamaRotaryEmbedding(config=self.config) def forward( self, @@ -930,16 +940,9 @@ def __init__(self, config: LlamaConfig): [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = LlamaRotaryEmbedding(config=config) self.gradient_checkpointing = False - rope_config = self.config.get("rope_scaling", {"type": "default"}).copy() - rope_config |= { - "dim": self.config.hidden_size // self.config.num_attention_heads, - "max_position_embeddings": self.config.max_position_embeddings, - "base": self.config.rope_theta, - } - self.rotary_emb = LlamaRotaryEmbedding(**rope_config) - # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 234ffd4ef03027..d6576cc505cebf 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -106,12 +106,11 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): @torch.no_grad() # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward def forward(self, x, position_ids): - if "dynamic" in self.rope_config["rope_type"]: + if "dynamic" in self.rope_type: self.dynamic_frequency_update(position_ids, device=x.device) # Core RoPE block - if self.rope_config["scaling_factor"] != 1.0: - position_ids = position_ids.float() / self.rope_config["scaling_factor"] + position_ids = position_ids.float() / self.scaling_factor inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() # Force float32 (see https://github.com/huggingface/transformers/pull/29285) @@ -124,9 +123,8 @@ def forward(self, x, position_ids): sin = emb.sin() # Advanced RoPE types (e.g. 
yarn) apply a post-processing scaling factor, equivalent to scaling attention - if self.rope_config.get("attention_factor") is not None: - cos = cos * self.rope_config["attention_factor"] - sin = sin * self.rope_config["attention_factor"] + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 74976e76f63f61..5ceb288ed08ae5 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -91,47 +91,66 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Olmo class OlmoRotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0, **kwargs): + def __init__( + self, + dim=None, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + config: Optional[OlmoConfig] = None, + **kwargs, + ): super().__init__() - self.rope_config = { - "base": base, - "dim": dim, - "max_position_embeddings": max_position_embeddings, - "scaling_factor": scaling_factor, - } - self.rope_config.update(kwargs) - # BC: in absence of "rope_type" in kwargs, set it to "default" - self.rope_config["rope_type"] = self.rope_config.get("rope_type", "default") - - inv_freq = compute_frequencies(self.rope_config, device) + # TODO (joao): remove this `if` in v4.45; the legacy args rebuild a config to power the rest of the class; + if config is None: + logger.warning_once( + "`OlmoRotaryEmbedding` can now be fully parameterized by passing the model config through the " + "`config` argument. 
All other arguments will be deprecated in v4.45" + ) + config = OlmoConfig(**kwargs) + config.rope_theta = base + config.max_position_embeddings = max_position_embeddings + config.head_dim = dim # this one doesn't actually exist, will only be used in the deprecation transition + if scaling_factor == 1.0 and len(kwargs) == 0: + config.rope_scaling = None + else: + config.rope_scaling = {"type": "default", "factor": scaling_factor} + config.rope_scaling |= kwargs # may overwrite "type" + + self.config = config + self.rope_type = config.rope_scaling["type"] if config.rope_scaling is not None else "default" + self.scaling_factor = config.rope_scaling["factor"] if config.rope_scaling is not None else 1.0 + + inv_freq = compute_frequencies(config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) - self.max_seq_len_cached = max_position_embeddings + self.max_seq_len_cached = config.max_position_embeddings - # BC: dynamic NTK and yarn used `scaling_factor` as a frequency parameter, not for position_ids scaling - if "dynamic" in self.rope_config["rope_type"] or "yarn" in self.rope_config["rope_type"]: - self.rope_config["scaling_factor"] = 1.0 # Special case: on yarn, `attention_factor` has a default suggested by the paper - if "yarn" in self.rope_config["rope_type"]: - self.rope_config["attention_factor"] = self.rope_config.get( - "attention_factor", 0.1 * math.log(scaling_factor) + 1.0 - ) + if "yarn" in self.rope_type: + attention_scale_default = 0.1 * math.log(self.scaling_factor) + 1.0 + self.attention_scaling = config.rope_scaling.get("attention_factor", attention_scale_default) + else: + self.attention_scaling = 1.0 + # BC: dynamic NTK and yarn used `scaling_factor` as a frequency parameter, not for position_ids scaling + if "dynamic" in self.rope_type or "yarn" in self.rope_type: + self.scaling_factor = 1.0 def dynamic_frequency_update(self, position_ids, device): """dynamic RoPE layers need to recompute `inv_freq` when going beyond the original maximum sequence length""" seq_len = torch.max(position_ids) + 1 if seq_len > self.max_seq_len_cached: - inv_freq = compute_frequencies(self.rope_config | {"seq_len": seq_len}, device) + inv_freq = compute_frequencies(self.config, device, seq_len=seq_len) self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.max_seq_len_cached = seq_len @torch.no_grad() def forward(self, x, position_ids): - if "dynamic" in self.rope_config["rope_type"]: + if "dynamic" in self.rope_type: self.dynamic_frequency_update(position_ids, device=x.device) # Core RoPE block - if self.rope_config["scaling_factor"] != 1.0: - position_ids = position_ids.float() / self.rope_config["scaling_factor"] + position_ids = position_ids.float() / self.scaling_factor inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() # Force float32 (see https://github.com/huggingface/transformers/pull/29285) @@ -144,9 +163,8 @@ def forward(self, x, position_ids): sin = emb.sin() # Advanced RoPE types (e.g. 
yarn) apply a post-processing scaling factor, equivalent to scaling attention - if self.rope_config.get("attention_factor") is not None: - cos = cos * self.rope_config["attention_factor"] - sin = sin * self.rope_config["attention_factor"] + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) @@ -155,28 +173,26 @@ def forward(self, x, position_ids): class OlmoLinearScalingRotaryEmbedding(OlmoRotaryEmbedding): """OlmoRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" - def __init__(self, **kwargs): + def __init__(self, *args, **kwargs): logger.warning_once( "`OlmoLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.45. Please use " - "`OlmoRotaryEmbedding`, which now also does linear scaling (pass in the same kwargs plus " - "`rope_type='default'`)." + "`OlmoRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)." ) - kwargs["rope_type"] = "default" - super().__init__(**kwargs) + super().__init__(*args, **kwargs) # Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Olmo class OlmoDynamicNTKScalingRotaryEmbedding(OlmoRotaryEmbedding): """OlmoRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" - def __init__(self, **kwargs): + def __init__(self, *args, **kwargs): logger.warning_once( "`OlmoDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.45. Please use " - "`OlmoRotaryEmbedding`, which now also does dynamic ntk scaling (pass in the same kwargs plus " - "`rope_type='dynamic'`)." + "`OlmoRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to " + "__init__)." 
) - kwargs["rope_type"] = "dynamic" - super().__init__(**kwargs) + kwargs["type"] = "dynamic" + super().__init__(*args, **kwargs) # Copied from transformers.models.llama.modeling_llama.rotate_half @@ -280,13 +296,7 @@ def __init__(self, config: OlmoConfig, layer_idx: Optional[int] = None): self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) # TODO (joao): remove in v4.45 (RoPE is computed in the model, not in the decoder layers) - rope_config = self.config.get("rope_scaling", {"type": "default"}).copy() - rope_config |= { - "dim": self.head_dim, - "max_position_embeddings": self.max_position_embeddings, - "base": self.rope_theta, - } - self.rotary_emb = OlmoRotaryEmbedding(**rope_config) + self.rotary_emb = OlmoRotaryEmbedding(config=self.config) def _init_rope(self): if self.config.rope_scaling is None: diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index c14067efd9889b..c7bbcedcf23067 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -51,12 +51,7 @@ LlamaModel, LlamaTokenizer, ) - from transformers.models.llama.modeling_llama import ( - LlamaDynamicNTKScalingRotaryEmbedding, - LlamaLinearScalingRotaryEmbedding, - LlamaRotaryEmbedding, - LlamaYarnScalingRotaryEmbedding, - ) + from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding class LlamaModelTester: @@ -431,9 +426,6 @@ def test_model_rope_scaling_from_config(self, scaling_type): def test_model_rope_scaling(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() - hidden_size = config.hidden_size - num_heads = config.num_attention_heads - head_dim = hidden_size // num_heads scaling_factor = 10 short_input_length = 10 long_input_length = int(config.max_position_embeddings * 1.5) @@ -446,11 +438,7 @@ def test_model_rope_scaling(self): position_ids_long = position_ids_long.unsqueeze(0) # Sanity check original RoPE - original_rope = LlamaRotaryEmbedding( - head_dim, - max_position_embeddings=config.max_position_embeddings, - base=config.rope_theta, - ).to(torch_device) + original_rope = LlamaRotaryEmbedding(config=config).to(torch_device) original_cos_short, original_sin_short = original_rope(x, position_ids_short) original_cos_long, original_sin_long = original_rope(x, position_ids_long) torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :]) @@ -458,12 +446,8 @@ def test_model_rope_scaling(self): # Sanity check linear RoPE scaling # New position "x" should match original position with index "x/scaling_factor" - linear_scaling_rope = LlamaLinearScalingRotaryEmbedding( - head_dim, - max_position_embeddings=config.max_position_embeddings, - base=config.rope_theta, - scaling_factor=scaling_factor, - ).to(torch_device) + config.rope_scaling = {"type": "linear", "factor": scaling_factor} + linear_scaling_rope = LlamaRotaryEmbedding(config=config).to(torch_device) linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short) linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long) torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :]) @@ -476,12 +460,8 @@ def test_model_rope_scaling(self): # Sanity check Dynamic NTK RoPE scaling # Scaling should only be observed after a long input is fed. 
We can observe that the frequencies increase # with scaling_factor (or that `inv_freq` decreases) - ntk_scaling_rope = LlamaDynamicNTKScalingRotaryEmbedding( - head_dim, - max_position_embeddings=config.max_position_embeddings, - base=config.rope_theta, - scaling_factor=scaling_factor, - ).to(torch_device) + config.rope_scaling = {"type": "dynamic", "factor": scaling_factor} + ntk_scaling_rope = LlamaRotaryEmbedding(config=config).to(torch_device) ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short) ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long) torch.testing.assert_close(ntk_cos_short, original_cos_short) @@ -493,12 +473,9 @@ def test_model_rope_scaling(self): self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all()) # Sanity check Yarn RoPE scaling - yarn_scaling_rope = LlamaYarnScalingRotaryEmbedding( - head_dim, - max_position_embeddings=config.max_position_embeddings, - base=config.rope_theta, - scaling_factor=scaling_factor, - ).to(torch_device) + # Scaling should be over the entire input + config.rope_scaling = {"type": "yarn", "factor": scaling_factor} + yarn_scaling_rope = LlamaRotaryEmbedding(config=config).to(torch_device) yarn_cos_short, yarn_sin_short = yarn_scaling_rope(x, position_ids_short) yarn_cos_long, yarn_sin_long = yarn_scaling_rope(x, position_ids_long) torch.testing.assert_close(yarn_cos_short, yarn_cos_long[:, :short_input_length, :])
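Taken together, the hunks above move RoPE from a kwargs-driven helper to a config-driven one. Below is a minimal usage sketch of that flow as it stands at this point of the branch; function names such as `compute_frequencies` may still change before merge, and the config sizes and scaling values are illustrative only.

import torch

from transformers import LlamaConfig
from transformers.modeling_rope_utils import compute_frequencies, rope_config_validation
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

# Small config with YaRN scaling; `factor` must be a float >= 1 per `rope_config_validation`.
config = LlamaConfig(hidden_size=128, num_attention_heads=4, max_position_embeddings=256)
config.rope_scaling = {"type": "yarn", "factor": 2.0, "attention_factor": 1.0, "beta_fast": 32.0, "beta_slow": 1.0}
rope_config_validation(config.rope_scaling)  # raises ValueError if the dict is malformed

# Inverse frequencies now come straight from the config; `seq_len` only matters for the "dynamic" type.
inv_freq = compute_frequencies(config, device=torch.device("cpu"))
print(inv_freq.shape)  # torch.Size([16]) == head_dim // 2, with head_dim = 128 // 4

# The rotary embedding module is parameterized by the config alone.
rope = LlamaRotaryEmbedding(config=config)
x = torch.randn(1, 10, config.hidden_size)
position_ids = torch.arange(10).unsqueeze(0)
cos, sin = rope(x, position_ids)
print(cos.shape, sin.shape)  # torch.Size([1, 10, 32]) each (last dim is head_dim)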
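The error message in `compute_frequencies` also documents an extension point: unrecognized RoPE types can be supported by registering a callable in `ROPE_TYPE_TO_FUNCTION`. A sketch of that hook follows; the "my_ntk" type name and its frequency formula are hypothetical, not part of this diff.

from typing import Optional

import torch

from transformers import LlamaConfig, PretrainedConfig
from transformers.modeling_rope_utils import ROPE_TYPE_TO_FUNCTION, compute_frequencies


def _compute_my_ntk_frequencies(
    config: PretrainedConfig, device: torch.device, seq_len: Optional[int]
) -> torch.Tensor:
    """Toy variant (hypothetical): default RoPE frequencies with the base multiplied by the factor."""
    base = config.rope_theta * config.rope_scaling["factor"]
    dim = config.hidden_size // config.num_attention_heads
    return 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))


# Custom callables take (config, device, seq_len) and return the inverse frequency tensor.
ROPE_TYPE_TO_FUNCTION["my_ntk"] = _compute_my_ntk_frequencies

config = LlamaConfig(hidden_size=128, num_attention_heads=4)
config.rope_scaling = {"type": "my_ntk", "factor": 4.0}
inv_freq = compute_frequencies(config, device=torch.device("cpu"))  # dispatches to the custom callable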