Implement scaled and dynamically scaled RoPE
iantbutler01 committed Jul 17, 2023
1 parent a2cf1bd commit f01c11b
Showing 7 changed files with 129 additions and 22 deletions.
28 changes: 26 additions & 2 deletions launcher/src/main.rs
@@ -162,7 +162,7 @@ struct Args {
/// Limits the number of tokens for the prefill operation.
/// Since this operation takes the most memory and is compute bound, it is worth
/// limiting the number of requests that can be sent.
#[clap(default_value = "4096", long, env)]
#[clap(default_value = "2048", long, env)]
max_batch_prefill_tokens: u32,

/// **IMPORTANT** This is one critical control to allow maximum usage
@@ -182,7 +182,7 @@ struct Args {
/// depends on other parameters, such as whether you're using quantization, flash attention,
/// or a particular model implementation; text-generation-inference cannot infer this number
/// automatically.
#[clap(default_value = "16000", long, env)]
#[clap(default_value = "8192", long, env)]
max_batch_total_tokens: u32,

/// This setting defines how many tokens can be passed before forcing the waiting
@@ -280,6 +280,19 @@ struct Args {
/// Display a lot of information about your runtime environment
#[clap(long, short, action)]
env: bool,

/// NTK-Aware Scaled RoPE is a method proposed in https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
/// The scale factor, or "α", is used in combination with a non-linearity to scale the base used to calculate the parameter "θ", the angle of rotation in RoPE.
/// This increases how many input tokens can be represented within the same portion of a positional embedding, with the non-linearity used to increase token separability.
#[clap(default_value = "1", long, env)]
rope_scale_factor: usize,

/// Dynamic scaling of the "α" factor in NTK-Aware Scaled RoPE was introduced in https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/
/// Instead of setting "α" statically, it is calculated as a function of the current sequence length and the model's base sequence length.
/// This both improves performance on shorter sequence lengths and smooths the perplexity explosion seen with both linearly scaled and NTK-Aware scaled RoPE.
/// If this is enabled, the "rope_scale_factor" above is ignored.
#[clap(default_value = "false", long, env)]
rope_dynamic_scaling: bool,
}
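The two options above correspond directly to the scaling formulas added to server/text_generation_server/utils/layers.py later in this commit. As a rough standalone sketch of the math (illustrative only, not code from this repository; the function names are made up for the example):

def ntk_scaled_base(base: float, alpha: float, dim: int) -> float:
    # NTK-aware scaling: grow the RoPE base by alpha ** (dim / (dim - 2)),
    # mirroring `base = base * scale_factor ** (dim / (dim - 2))` in layers.py.
    return base * alpha ** (dim / (dim - 2))

def dynamic_alpha(alpha: float, seq_len: int, original_max_seq_len: int) -> float:
    # Dynamic NTK scaling: derive alpha from the current sequence length,
    # mirroring `(scale_factor * length / original_max_seq_len) - (scale_factor - 1)`.
    return (alpha * seq_len / original_max_seq_len) - (alpha - 1)

if __name__ == "__main__":
    dim, base, alpha, original_max = 128, 10000.0, 4, 2048
    print(ntk_scaled_base(base, alpha, dim))  # ~40890: a larger base rotates less per position
    for seq_len in (2048, 4096, 8192):
        # alpha stays at 1.0 at the original length, then grows: 5.0 at 2x, 13.0 at 4x
        print(seq_len, dynamic_alpha(alpha, seq_len, original_max))

A larger base means each position advances the rotation angle "θ" by less, so more tokens fit inside the angular range the model saw during training; the dynamic variant keeps "α" at 1 for sequences at or below the model's base length and only scales up as the context grows.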

#[derive(Debug)]
@@ -293,6 +306,8 @@ fn shard_manager(
model_id: String,
revision: Option<String>,
quantize: Option<Quantization>,
rope_scale_factor: usize,
rope_dynamic_scaling: bool,
dtype: Option<Dtype>,
trust_remote_code: bool,
uds_path: String,
@@ -422,6 +437,10 @@ fn shard_manager(
envs.push(("WATERMARK_DELTA".into(), watermark_delta.to_string().into()))
}

// RoPE scaling: forward the settings to the Python shards as environment variables
envs.push(("ROPE_SCALE_FACTOR".into(), rope_scale_factor.to_string().into()));
envs.push(("ROPE_DYNAMIC_SCALING".into(), rope_dynamic_scaling.to_string().into()));

// Start process
tracing::info!("Starting shard {rank}");
let mut p = match Command::new("text-generation-server")
@@ -776,11 +795,16 @@ fn spawn_shards(
let disable_custom_kernels = args.disable_custom_kernels;
let watermark_gamma = args.watermark_gamma;
let watermark_delta = args.watermark_delta;
let rope_scale_factor = args.rope_scale_factor;
let rope_dynamic_scaling = args.rope_dynamic_scaling;

thread::spawn(move || {
shard_manager(
model_id,
revision,
quantize,
rope_scale_factor,
rope_dynamic_scaling,
dtype,
trust_remote_code,
uds_path,
4 changes: 2 additions & 2 deletions router/src/main.rs
@@ -35,9 +35,9 @@ struct Args {
max_total_tokens: usize,
#[clap(default_value = "1.2", long, env)]
waiting_served_ratio: f32,
#[clap(default_value = "4096", long, env)]
#[clap(default_value = "2048", long, env)]
max_batch_prefill_tokens: u32,
#[clap(default_value = "16000", long, env)]
#[clap(default_value = "8192", long, env)]
max_batch_total_tokens: u32,
#[clap(default_value = "20", long, env)]
max_waiting_tokens: usize,
Changes to an additional model implementation file (filename not shown):
@@ -18,6 +18,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import torch
import torch.distributed

@@ -41,6 +42,12 @@
TensorParallelHead,
)

ROPE_SCALE_FACTOR = int(os.getenv("ROPE_SCALE_FACTOR", 1))

if os.getenv("ROPE_DYNAMIC_SCALING", "false").lower() == "true":
ROPE_DYNAMIC_SCALING = True
else:
ROPE_DYNAMIC_SCALING = False

class LlamaRMSNorm(nn.Module):
def __init__(self, prefix, weights, eps=1e-6):
@@ -105,10 +112,18 @@ def __init__(
self.num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.num_heads
self.scale_factor = ROPE_SCALE_FACTOR
self.dynamic_scaling = ROPE_DYNAMIC_SCALING

self.rotary_emb = PositionRotaryEmbedding.load(
prefix=f"{prefix}.rotary_emb", weights=weights
)
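# With a scale factor above 1, build the rotary embedding from the scaled base
# instead of loading the stock inv_freq from the checkpoint weights.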
if self.scale_factor > 1:
# Base before scaling is 10000 per the original RoPE paper
self.rotary_emb = PositionRotaryEmbedding.static(
self.head_size, 10000, weights.device, self.scale_factor, self.dynamic_scaling
)
else:
self.rotary_emb = PositionRotaryEmbedding.load(
prefix=f"{prefix}.rotary_emb", weights=weights
)

self.softmax_scale = self.head_size**-0.5

Changes to an additional model implementation file (filename not shown):
@@ -18,6 +18,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import torch
import torch.distributed

@@ -45,6 +46,14 @@
)


ROPE_SCALE_FACTOR = int(os.getenv("ROPE_SCALE_FACTOR", 1))

if os.getenv("ROPE_DYNAMIC_SCALING", "false").lower() == "true":
ROPE_DYNAMIC_SCALING = True
else:
ROPE_DYNAMIC_SCALING = False


def load_row(config, prefix: str, weights, bias: bool):
weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)

@@ -102,10 +111,18 @@ def __init__(self, config, prefix, weights):
f"and `num_shards`: {weights.process_group.size()}"
)
self.num_heads = self.num_heads // weights.process_group.size()
self.scale_factor = ROPE_SCALE_FACTOR
self.dynamic_scaling = ROPE_DYNAMIC_SCALING

self.rotary_emb = PositionRotaryEmbedding.load(
prefix=f"{prefix}.rotary_emb", weights=weights
)
if self.scale_factor > 1:
# Base before scaling is 10000 per the original RoPE paper
self.rotary_emb = PositionRotaryEmbedding.static(
self.head_size, 10000, weights.device, self.scale_factor, self.dynamic_scaling, config.max_position_embeddings
)
else:
self.rotary_emb = PositionRotaryEmbedding.load(
prefix=f"{prefix}.rotary_emb", weights=weights
)

self.softmax_scale = self.head_size ** (-0.5)

Changes to an additional model implementation file (filename not shown):
@@ -1,5 +1,7 @@
import os
import torch
import torch.distributed
import warnings

from torch import nn
from transformers.modeling_utils import PreTrainedModel
@@ -23,6 +25,12 @@
get_linear,
)

ROPE_SCALE_FACTOR = int(os.getenv("ROPE_SCALE_FACTOR", 1))

if os.getenv("ROPE_DYNAMIC_SCALING", "false").lower() == "true":
ROPE_DYNAMIC_SCALING = True
else:
ROPE_DYNAMIC_SCALING = False

def load_row(config, prefix: str, weights, bias: bool):
weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
@@ -113,10 +121,13 @@ def __init__(
self.num_heads_kv = config.n_head_kv
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.num_heads
self.scale_factor = ROPE_SCALE_FACTOR
self.dynamic_scaling = ROPE_DYNAMIC_SCALING

self.rotary_emb = PositionRotaryEmbedding.static(
dim=self.head_size, base=10000.0, device=weights.device
dim=self.head_size, base=10000.0, device=weights.device, scale_factor=self.scale_factor, dynamic_scaling=self.dynamic_scaling
)

self.softmax_scale = self.head_size ** (-0.5)

if self.num_heads % weights.process_group.size() != 0:
@@ -239,9 +250,11 @@ def __init__(

self.hidden_size = hidden_size
self.head_size = hidden_size // num_heads
self.scale_factor = ROPE_SCALE_FACTOR
self.dynamic_scaling = ROPE_DYNAMIC_SCALING

self.rotary_emb = PositionRotaryEmbedding.static(
self.head_size, base=10000.0, device=weights.device
dim=self.head_size, base=10000.0, device=weights.device, scale_factor=self.scale_factor, dynamic_scaling=self.dynamic_scaling
)
self.softmax_scale = self.head_size ** (-0.5)

Changes to an additional model implementation file (filename not shown):
@@ -61,6 +61,14 @@
logger.warning("We're not using custom kernels.")


ROPE_SCALE_FACTOR = int(os.getenv("ROPE_SCALE_FACTOR", 1))

if os.getenv("ROPE_DYNAMIC_SCALING", "false").lower() == "true":
ROPE_DYNAMIC_SCALING = True
else:
ROPE_DYNAMIC_SCALING = False


def make_causal_mask(
input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
) -> torch.BoolTensor:
Expand Down
50 changes: 40 additions & 10 deletions server/text_generation_server/utils/layers.py
@@ -369,7 +369,7 @@ def forward(self, hidden_states, residual=None):
import rotary_emb

class PositionRotaryEmbedding(nn.Module):
def __init__(self, inv_freq):
def __init__(self, inv_freq, scale_factor=1, dynamic_scaling=False, max_seq_len=2048, dim=None, base=None):
super().__init__()

self.inv_freq = inv_freq
@@ -379,32 +379,62 @@ def __init__(self, inv_freq):
self._cos_k_cached = None
self._sin_k_cached = None

@classmethod
def static(cls, dim, base, device):
inv_freq = 1.0 / (
base
** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
)
return cls(inv_freq)
self.scale_factor = scale_factor
self.dynamic_scaling = dynamic_scaling
self.original_max_seq_len = max_seq_len
self.max_seq_len = max_seq_len * scale_factor
self.dim = dim
self.base = base

@classmethod
def static(cls, dim, base, device, scale_factor=1, dynamic_scaling=False, max_seq_len=2048):
inv_freq = cls._get_inv_freq(dim, base, device, scale_factor)
return cls(inv_freq, scale_factor, dynamic_scaling, max_seq_len, dim, base)

@classmethod
def load(cls, prefix, weights):
# XXX: Always load this in float32 !
dtype = weights.dtype
weights.dtype = torch.float32

inv_freq = weights.get_tensor(f"{prefix}.inv_freq")
weights.dtype = dtype
return cls(inv_freq)

@staticmethod
def _get_inv_freq(dim, base, device, scale_factor=1):
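# NTK-aware scaling: multiplying the base by scale_factor ** (dim / (dim - 2))
# lowers the rotation frequencies, so more positions fit within the angular
# range the model was trained on.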
base = base * scale_factor ** (dim / (dim-2))

inv_freq = 1.0 / (
base
** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
)

return inv_freq

def _update_cos_sin_cache(self, dtype, device, seqlen):
# Reset the tables if the sequence length has changed,
# or if we're on a new device (possibly due to tracing for instance)

length = seqlen
max_seq_len = self.max_seq_len
inv_freq = self.inv_freq

if self.dynamic_scaling:
scale_factor = (self.scale_factor * length / self.original_max_seq_len) - (self.scale_factor - 1)
max_seq_len = self.original_max_seq_len * scale_factor
inv_freq = self._get_inv_freq(self.dim, self.base, inv_freq.device, scale_factor)
self.inv_freq = inv_freq

if self.scale_factor > 1:
length = max(seqlen, max_seq_len)

if (
seqlen > self._seq_len_cached
length > self._seq_len_cached
or self._cos_cached.device != device
or self._cos_cached.dtype != dtype
):
self._seq_len_cached = seqlen
self._seq_len_cached = length
t = torch.arange(length, device=device, dtype=self.inv_freq.dtype)
# Don't do einsum, it converts fp32 to fp16
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
