From 0ec4d8182f199babc726e5ca3349acd40fc9031c Mon Sep 17 00:00:00 2001
From: Ian
Date: Mon, 17 Jul 2023 01:17:02 +0000
Subject: [PATCH] Update conditionals for dynamic scaling

---
 .../models/custom_modeling/flash_llama_modeling.py | 2 +-
 .../models/custom_modeling/flash_neox_modeling.py  | 9 ++-------
 .../models/custom_modeling/flash_rw_modeling.py    | 6 +-----
 .../models/custom_modeling/neox_modeling.py        | 8 +-------
 server/text_generation_server/utils/layers.py      | 5 ++---
 5 files changed, 7 insertions(+), 23 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 28e6f4847b0..36b73e102f6 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -115,7 +115,7 @@ def __init__(
         self.scale_factor = ROPE_SCALE_FACTOR
         self.dynamic_scaling = ROPE_DYNAMIC_SCALING
 
-        if self.scale_factor > 1:
+        if self.scale_factor > 1 or self.dynamic_scaling:
             # Base before scaling is 10000 per the original RoPE paper
             self.rotary_emb = PositionRotaryEmbedding.static(
                 self.head_size, 10000, weights.device, self.scale_factor, self.dynamic_scaling
diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
index 36660a22a99..f86e03442a0 100644
--- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -45,13 +45,8 @@
     get_linear,
 )
 
-
 ROPE_SCALE_FACTOR = int(os.getenv("ROPE_SCALE_FACTOR", 1))
-
-if os.getenv("ROPE_DYNAMIC_SCALING", False).lower() == "true":
-    ROPE_DYNAMIC_SCALING = True
-else:
-    ROPE_DYNAMIC_SCALING = False
+ROPE_DYNAMIC_SCALING = os.getenv("ROPE_DYNAMIC_SCALING", "false").lower() == "true"
 
 
 def load_row(config, prefix: str, weights, bias: bool):
@@ -114,7 +109,7 @@ def __init__(self, config, prefix, weights):
         self.scale_factor = ROPE_SCALE_FACTOR
         self.dynamic_scaling = ROPE_DYNAMIC_SCALING
 
-        if self.scale_factor > 1:
+        if self.scale_factor > 1 or self.dynamic_scaling:
             # Base before scaling is 10000 per the original RoPE paper
             self.rotary_emb = PositionRotaryEmbedding.static(
                 self.head_size, 10000, weights.device, self.scale_factor, self.dynamic_scaling, config.max_position_embeddings
diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index 1e9915ccff5..aaf8a3c4a24 100644
--- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -26,11 +26,7 @@
 )
 
 ROPE_SCALE_FACTOR = int(os.getenv("ROPE_SCALE_FACTOR", 1))
-
-if os.getenv("ROPE_DYNAMIC_SCALING", False).lower() == "true":
-    ROPE_DYNAMIC_SCALING = True
-else:
-    ROPE_DYNAMIC_SCALING = False
+ROPE_DYNAMIC_SCALING = os.getenv("ROPE_DYNAMIC_SCALING", "false").lower() == "true"
 
 
 def load_row(config, prefix: str, weights, bias: bool):
diff --git a/server/text_generation_server/models/custom_modeling/neox_modeling.py b/server/text_generation_server/models/custom_modeling/neox_modeling.py
index b0eef460e09..7c359026765 100644
--- a/server/text_generation_server/models/custom_modeling/neox_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/neox_modeling.py
@@ -60,14 +60,8 @@
 if not CUSTOM_KERNELS_ENABLED:
     logger.warning("We're not using custom kernels.")
 
-
 ROPE_SCALE_FACTOR = int(os.getenv("ROPE_SCALE_FACTOR", 1))
-
-if os.getenv("ROPE_DYNAMIC_SCALING", False).lower() == "true":
-    ROPE_DYNAMIC_SCALING = True
-else:
-    ROPE_DYNAMIC_SCALING = False
-
+ROPE_DYNAMIC_SCALING = os.getenv("ROPE_DYNAMIC_SCALING", "false").lower() == "true"
 
 def make_causal_mask(
     input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 1e4cd4a4e52..ea3748ff035 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -423,10 +423,9 @@ def _update_cos_sin_cache(self, dtype, device, seqlen):
         if self.dynamic_scaling:
             scale_factor = (self.scale_factor * length / self.original_max_seq_len) - (self.scale_factor - 1)
             max_seq_len = self.original_max_seq_len * scale_factor
-            inv_freq = self._get_inv_freq(self.dim, self.base, inv_freq.device, scale_factor)
-            self.register_buffer("inv_freq", inv_freq)
+            self.inv_freq = self._get_inv_freq(self.dim, self.base, inv_freq.device, scale_factor)
 
-        if self.scale_factor > 1:
+        if self.scale_factor > 1 and not self.dynamic_scaling:
             length = max(seqlen, max_seq_len)
 
         if (
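
A minimal standalone sketch of the flag handling these hunks converge on; it is illustrative, not additional repository code. The removed form called .lower() on the boolean default, so an unset ROPE_DYNAMIC_SCALING raised AttributeError; parsing against a string default fixes that, and the widened conditionals let dynamic scaling take effect even when ROPE_SCALE_FACTOR is left at 1.

    import os

    # Same parsing as the "+" lines above: integer scale factor, case-insensitive boolean flag.
    ROPE_SCALE_FACTOR = int(os.getenv("ROPE_SCALE_FACTOR", 1))
    ROPE_DYNAMIC_SCALING = os.getenv("ROPE_DYNAMIC_SCALING", "false").lower() == "true"

    # The attention constructors now pick the scaled rotary embedding when either knob is set,
    # e.g. launching with ROPE_DYNAMIC_SCALING=true while keeping the default scale factor of 1.
    use_scaled_rope = ROPE_SCALE_FACTOR > 1 or ROPE_DYNAMIC_SCALING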
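
A rough sketch of the per-call scale factor that _update_cos_sin_cache now derives under dynamic scaling. The formula mirrors the layers.py hunk; the NTK-style base adjustment inside the helper is an assumption, since _get_inv_freq's body is not part of this diff, and the function name below is illustrative only.

    import torch

    def dynamic_inv_freq_sketch(dim, base, seqlen, original_max_seq_len, scale_factor=1, device="cpu"):
        # The scale grows with the current sequence length, as in the layers.py hunk:
        # with scale_factor=1 and original_max_seq_len=2048, a 4096-token sequence gives 2.0.
        dyn_scale = (scale_factor * seqlen / original_max_seq_len) - (scale_factor - 1)
        dyn_scale = max(dyn_scale, 1.0)  # clamp added here for safety; the patch uses the raw value
        # Assumed NTK-style adjustment: fold the factor into the rotary base
        # (a guess at what _get_inv_freq does with its scale argument).
        adjusted_base = base * dyn_scale ** (dim / (dim - 2))
        return 1.0 / (adjusted_base ** (torch.arange(0, dim, 2, device=device).float() / dim))

With dynamic scaling active, the second change in layers.py also skips the static length = max(seqlen, max_seq_len) padding, so only the fixed-factor path stretches the cached length up front.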