From f01c11bd0cc47419f9c8a2bd984b45e7e4a2c4e4 Mon Sep 17 00:00:00 2001 From: Ian Date: Sat, 1 Jul 2023 02:18:03 +0000 Subject: [PATCH] Implement scaled and dynamically scaled RoPE --- launcher/src/main.rs | 28 ++++++++++- router/src/main.rs | 4 +- .../custom_modeling/flash_llama_modeling.py | 21 ++++++-- .../custom_modeling/flash_neox_modeling.py | 23 +++++++-- .../custom_modeling/flash_rw_modeling.py | 17 ++++++- .../models/custom_modeling/neox_modeling.py | 8 +++ server/text_generation_server/utils/layers.py | 50 +++++++++++++++---- 7 files changed, 129 insertions(+), 22 deletions(-) diff --git a/launcher/src/main.rs b/launcher/src/main.rs index d690a7c4931..4f6fcf0223b 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -162,7 +162,7 @@ struct Args { /// Limits the number of tokens for the prefill operation. /// Since this operation take the most memory and is compute bound, it is interesting /// to limit the number of requests that can be sent. - #[clap(default_value = "4096", long, env)] + #[clap(default_value = "2048", long, env)] max_batch_prefill_tokens: u32, /// **IMPORTANT** This is one critical control to allow maximum usage @@ -182,7 +182,7 @@ struct Args { /// depends on other parameters like if you're using quantization, flash attention /// or the model implementation, text-generation-inference cannot infer this number /// automatically. - #[clap(default_value = "16000", long, env)] + #[clap(default_value = "8192", long, env)] max_batch_total_tokens: u32, /// This setting defines how many tokens can be passed before forcing the waiting @@ -280,6 +280,19 @@ struct Args { /// Display a lot of information about your runtime environment #[clap(long, short, action)] env: bool, + + /// NTK-Aware Scaled Rope is a method proposed in https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ + /// The scale factor, or "α", is used in combination with a non linearity to scale the base used to calculate the parameter "θ", the angle of rotation in RoPE. + /// This increases how many input tokens can be represented within the same portion of a positional embedding, with the non linearity used to increase token seprability. + #[clap(default_value="1", long, env)] + rope_scale_factor: usize, + + /// Dynamic scaling of the "α" factor in NTK-Aware Scaled Rope was introduced in https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/ + /// The idea being instead of setting alpha statically, it is calculated as a function of the current sequence length and the model's base sequence length. + /// This is a means to both increase performance on shorter sequence lengths and smooth the perplexity explosion experienced by both linearly scaled and NTK-Aware scaled RoPE. + /// If this is enabled the above "rope_scale_factor" will be ignored. + #[clap(default_value="false", long, env)] + rope_dynamic_scaling: bool } #[derive(Debug)] @@ -293,6 +306,8 @@ fn shard_manager( model_id: String, revision: Option, quantize: Option, + rope_scale_factor: usize, + rope_dynamic_scaling: bool, dtype: Option, trust_remote_code: bool, uds_path: String, @@ -422,6 +437,10 @@ fn shard_manager( envs.push(("WATERMARK_DELTA".into(), watermark_delta.to_string().into())) } + // RoPE Scaling + env.push(("ROPE_SCALE_FACTOR".into(), rope_scale_factor.to_string().into())); + env.push(("ROPE_DYNAMIC_SCALING".into(), rope_dynamic_scaling.to_string().into())); + // Start process tracing::info!("Starting shard {rank}"); let mut p = match Command::new("text-generation-server") @@ -776,11 +795,16 @@ fn spawn_shards( let disable_custom_kernels = args.disable_custom_kernels; let watermark_gamma = args.watermark_gamma; let watermark_delta = args.watermark_delta; + let rope_scale_factor = args.rope_scale_factor; + let rope_dynamic_scaling = args.rope_dynamic_scaling; + thread::spawn(move || { shard_manager( model_id, revision, quantize, + rope_scale_factor, + rope_dynamic_scaling, dtype, trust_remote_code, uds_path, diff --git a/router/src/main.rs b/router/src/main.rs index 178c249c324..d1a5c472dfb 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -35,9 +35,9 @@ struct Args { max_total_tokens: usize, #[clap(default_value = "1.2", long, env)] waiting_served_ratio: f32, - #[clap(default_value = "4096", long, env)] + #[clap(default_value = "2048", long, env)] max_batch_prefill_tokens: u32, - #[clap(default_value = "16000", long, env)] + #[clap(default_value = "8192", long, env)] max_batch_total_tokens: u32, #[clap(default_value = "20", long, env)] max_waiting_tokens: usize, diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index d9f3c7b83de..28e6f4847b0 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -18,6 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import torch import torch.distributed @@ -41,6 +42,12 @@ TensorParallelHead, ) +ROPE_SCALE_FACTOR = int(os.getenv("ROPE_SCALE_FACTOR", 1)) + +if os.getenv("ROPE_DYNAMIC_SCALING", False).lower() == "true": + ROPE_DYNAMIC_SCALING = True +else: + ROPE_DYNAMIC_SCALING = False class LlamaRMSNorm(nn.Module): def __init__(self, prefix, weights, eps=1e-6): @@ -105,10 +112,18 @@ def __init__( self.num_heads = config.num_attention_heads self.hidden_size = config.hidden_size self.head_size = self.hidden_size // self.num_heads + self.scale_factor = ROPE_SCALE_FACTOR + self.dynamic_scaling = ROPE_DYNAMIC_SCALING - self.rotary_emb = PositionRotaryEmbedding.load( - prefix=f"{prefix}.rotary_emb", weights=weights - ) + if self.scale_factor > 1: + # Base before scaling is 10000 per the original RoPE paper + self.rotary_emb = PositionRotaryEmbedding.static( + self.head_size, 10000, weights.device, self.scale_factor, self.dynamic_scaling + ) + else: + self.rotary_emb = PositionRotaryEmbedding.load( + prefix=f"{prefix}.rotary_emb", weights=weights + ) self.softmax_scale = self.head_size**-0.5 diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index b2dce226d82..36660a22a99 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -18,6 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import torch import torch.distributed @@ -45,6 +46,14 @@ ) +ROPE_SCALE_FACTOR = int(os.getenv("ROPE_SCALE_FACTOR", 1)) + +if os.getenv("ROPE_DYNAMIC_SCALING", False).lower() == "true": + ROPE_DYNAMIC_SCALING = True +else: + ROPE_DYNAMIC_SCALING = False + + def load_row(config, prefix: str, weights, bias: bool): weight = weights.get_multi_weights_row(prefix, quantize=config.quantize) @@ -102,10 +111,18 @@ def __init__(self, config, prefix, weights): f"and `num_shards`: {weights.process_group.size()}" ) self.num_heads = self.num_heads // weights.process_group.size() + self.scale_factor = ROPE_SCALE_FACTOR + self.dynamic_scaling = ROPE_DYNAMIC_SCALING - self.rotary_emb = PositionRotaryEmbedding.load( - prefix=f"{prefix}.rotary_emb", weights=weights - ) + if self.scale_factor > 1: + # Base before scaling is 10000 per the original RoPE paper + self.rotary_emb = PositionRotaryEmbedding.static( + self.head_size, 10000, weights.device, self.scale_factor, self.dynamic_scaling, config.max_position_embeddings + ) + else: + self.rotary_emb = PositionRotaryEmbedding.load( + prefix=f"{prefix}.rotary_emb", weights=weights + ) self.softmax_scale = self.head_size ** (-0.5) diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index acac27446f7..1e9915ccff5 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -1,5 +1,7 @@ +import os import torch import torch.distributed +import warnings from torch import nn from transformers.modeling_utils import PreTrainedModel @@ -23,6 +25,12 @@ get_linear, ) +ROPE_SCALE_FACTOR = int(os.getenv("ROPE_SCALE_FACTOR", 1)) + +if os.getenv("ROPE_DYNAMIC_SCALING", False).lower() == "true": + ROPE_DYNAMIC_SCALING = True +else: + ROPE_DYNAMIC_SCALING = False def load_row(config, prefix: str, weights, bias: bool): weight = weights.get_multi_weights_row(prefix, quantize=config.quantize) @@ -113,10 +121,13 @@ def __init__( self.num_heads_kv = config.n_head_kv self.hidden_size = config.hidden_size self.head_size = self.hidden_size // self.num_heads + self.scale_factor = ROPE_SCALE_FACTOR + self.dynamic_scaling = ROPE_DYNAMIC_SCALING self.rotary_emb = PositionRotaryEmbedding.static( - dim=self.head_size, base=10000.0, device=weights.device + dim=self.head_size, base=10000.0, device=weights.device, scale_factor=self.scale_factor, dynamic_scaling=self.dynamic_scaling ) + self.softmax_scale = self.head_size ** (-0.5) if self.num_heads % weights.process_group.size() != 0: @@ -239,9 +250,11 @@ def __init__( self.hidden_size = hidden_size self.head_size = hidden_size // num_heads + self.scale_factor = ROPE_SCALE_FACTOR + self.dynamic_scaling = ROPE_DYNAMIC_SCALING self.rotary_emb = PositionRotaryEmbedding.static( - self.head_size, base=10000.0, device=weights.device + dim=self.head_size, base=10000.0, device=weights.device, scale_factor=self.scale_factor, dynamic_scaling=self.dynamic_scaling ) self.softmax_scale = self.head_size ** (-0.5) diff --git a/server/text_generation_server/models/custom_modeling/neox_modeling.py b/server/text_generation_server/models/custom_modeling/neox_modeling.py index 1951b171cf9..b0eef460e09 100644 --- a/server/text_generation_server/models/custom_modeling/neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/neox_modeling.py @@ -61,6 +61,14 @@ logger.warning("We're not using custom kernels.") +ROPE_SCALE_FACTOR = int(os.getenv("ROPE_SCALE_FACTOR", 1)) + +if os.getenv("ROPE_DYNAMIC_SCALING", False).lower() == "true": + ROPE_DYNAMIC_SCALING = True +else: + ROPE_DYNAMIC_SCALING = False + + def make_causal_mask( input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int ) -> torch.BoolTensor: diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index 4f65446e6e4..1e4cd4a4e52 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -369,7 +369,7 @@ def forward(self, hidden_states, residual=None): import rotary_emb class PositionRotaryEmbedding(nn.Module): - def __init__(self, inv_freq): + def __init__(self, inv_freq, scale_factor=1, dynamic_scaling=False, max_seq_len=2048, dim=None, base=None): super().__init__() self.inv_freq = inv_freq @@ -379,32 +379,62 @@ def __init__(self, inv_freq): self._cos_k_cached = None self._sin_k_cached = None - @classmethod - def static(cls, dim, base, device): - inv_freq = 1.0 / ( - base - ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim) - ) - return cls(inv_freq) + self.scale_factor = scale_factor + self.dynamic_scaling = dynamic_scaling + self.original_max_seq_len = max_seq_len + self.max_seq_len = max_seq_len * scale_factor + self.dim = dim + self.base = base + @classmethod + def static(cls, dim, base, device, scale_factor=1, dynamic_scaling=False, max_seq_len=2048): + inv_freq = cls._get_inv_freq(dim, base, device, scale_factor) + return cls(inv_freq, scale_factor, dynamic_scaling, max_seq_len, dim, base) + @classmethod def load(cls, prefix, weights): # XXX: Always load this in float32 ! dtype = weights.dtype weights.dtype = torch.float32 + inv_freq = weights.get_tensor(f"{prefix}.inv_freq") weights.dtype = dtype return cls(inv_freq) + @staticmethod + def _get_inv_freq(dim, base, device, scale_factor=1): + base = base * scale_factor ** (dim / (dim-2)) + + inv_freq = 1.0 / ( + base + ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim) + ) + + return inv_freq + def _update_cos_sin_cache(self, dtype, device, seqlen): # Reset the tables if the sequence length has changed, # or if we're on a new device (possibly due to tracing for instance) + + length = seqlen + max_seq_len = self.max_seq_len + inv_freq = self.inv_freq + + if self.dynamic_scaling: + scale_factor = (self.scale_factor * length / self.original_max_seq_len) - (self.scale_factor - 1) + max_seq_len = self.original_max_seq_len * scale_factor + inv_freq = self._get_inv_freq(self.dim, self.base, inv_freq.device, scale_factor) + self.register_buffer("inv_freq", inv_freq) + + if self.scale_factor > 1: + length = max(seqlen, max_seq_len) + if ( - seqlen > self._seq_len_cached + length > self._seq_len_cached or self._cos_cached.device != device or self._cos_cached.dtype != dtype ): - self._seq_len_cached = seqlen + self._seq_len_cached = length t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) # Don't do einsum, it converts fp32 to fp16 # freqs = torch.einsum("i,j->ij", t, self.inv_freq)