Adding Rope scaling. #741

Merged (2 commits) on Jul 31, 2023
61 changes: 61 additions & 0 deletions launcher/src/main.rs
@@ -60,6 +60,26 @@ impl std::fmt::Display for Dtype {
    }
}

#[derive(Clone, Copy, Debug, ValueEnum)]
enum RopeScaling {
    Linear,
    Dynamic,
}

impl std::fmt::Display for RopeScaling {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Keep these strings in sync with the `server` side.
        match self {
            RopeScaling::Linear => {
                write!(f, "linear")
            }
            RopeScaling::Dynamic => {
                write!(f, "dynamic")
            }
        }
    }
}

/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
@@ -250,6 +270,26 @@ struct Args {
    #[clap(default_value = "1.0", long, env)]
    cuda_memory_fraction: f32,

    /// Rope scaling is only used for RoPE models
    /// and allows rescaling the position rotary embedding to accommodate
    /// larger prompts.
    ///
    /// Goes together with `rope_factor`.
    ///
    /// `--rope-factor 2.0` gives linear scaling with a factor of 2.0
    /// `--rope-scaling dynamic` gives dynamic scaling with a factor of 1.0
    /// `--rope-scaling linear` gives linear scaling with a factor of 1.0
    /// (which changes nothing)
    ///
    /// `--rope-scaling linear --rope-factor <factor>` fully describes the scaling you want
    #[clap(long, env)]
    rope_scaling: Option<RopeScaling>,

    /// Rope scaling factor, only used for RoPE models.
    /// See `rope_scaling`.
    #[clap(long, env)]
    rope_factor: Option<f32>,

    /// Outputs the logs in JSON format (useful for telemetry)
    #[clap(long, env)]
    json_output: bool,
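For illustration, a hypothetical invocation such as `text-generation-launcher --rope-scaling dynamic --rope-factor 2.0` (binary name assumed) ends up exporting ROPE_SCALING=dynamic and ROPE_FACTOR=2.0 to every shard, as wired up in `shard_manager` below.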
@@ -305,6 +345,8 @@ fn shard_manager(
    watermark_gamma: Option<f32>,
    watermark_delta: Option<f32>,
    cuda_memory_fraction: f32,
    rope_scaling: Option<RopeScaling>,
    rope_factor: Option<f32>,
    otlp_endpoint: Option<String>,
    status_sender: mpsc::Sender<ShardStatus>,
    shutdown: Arc<AtomicBool>,
@@ -358,6 +400,12 @@ fn shard_manager(
        shard_args.push(revision)
    }

    let rope = match (rope_scaling, rope_factor) {
        (None, None) => None,
        (Some(scaling), None) => Some((scaling, 1.0)),
        (Some(scaling), Some(factor)) => Some((scaling, factor)),
        (None, Some(factor)) => Some((RopeScaling::Linear, factor)),
    };
    // OpenTelemetry
    if let Some(otlp_endpoint) = otlp_endpoint {
        shard_args.push("--otlp-endpoint".to_string());
@@ -395,6 +443,15 @@ fn shard_manager(
envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
};

// Detect rope scaling
// Sending as env instead of CLI args to not bloat everything
// those only can be used by RoPE models, so passing information around
// for all models will complexify code unnecessarily
if let Some((scaling, factor)) = rope {
envs.push(("ROPE_SCALING".into(), scaling.to_string().into()));
envs.push(("ROPE_FACTOR".into(), factor.to_string().into()));
}

// If huggingface_hub_cache is some, pass it to the shard
// Useful when running inside a docker container
if let Some(huggingface_hub_cache) = huggingface_hub_cache {
@@ -784,6 +841,8 @@ fn spawn_shards(
        let watermark_gamma = args.watermark_gamma;
        let watermark_delta = args.watermark_delta;
        let cuda_memory_fraction = args.cuda_memory_fraction;
        let rope_scaling = args.rope_scaling;
        let rope_factor = args.rope_factor;
        thread::spawn(move || {
            shard_manager(
                model_id,
@@ -802,6 +861,8 @@
                watermark_gamma,
                watermark_delta,
                cuda_memory_fraction,
                rope_scaling,
                rope_factor,
                otlp_endpoint,
                status_sender,
                shutdown,
@@ -186,7 +186,7 @@ def __init__(
        self.head_size = self.hidden_size // self.num_heads

        self.rotary_emb = PositionRotaryEmbedding.load(
-            prefix=f"{prefix}.rotary_emb", weights=weights
+            config=config, prefix=f"{prefix}.rotary_emb", weights=weights
        )

        self.softmax_scale = self.head_size**-0.5
@@ -102,7 +102,7 @@ def __init__(self, config, prefix, weights):
        self.num_heads = self.num_heads // weights.process_group.size()

        self.rotary_emb = PositionRotaryEmbedding.load(
-            prefix=f"{prefix}.rotary_emb", weights=weights
+            config=config, prefix=f"{prefix}.rotary_emb", weights=weights
        )

        self.softmax_scale = self.head_size ** (-0.5)
@@ -133,7 +133,7 @@ def __init__(
        self.head_size = self.hidden_size // self.num_heads

        self.rotary_emb = PositionRotaryEmbedding.static(
-            dim=self.head_size, base=10000.0, device=weights.device
+            config=config, dim=self.head_size, base=10000.0, device=weights.device
        )
        self.softmax_scale = self.head_size ** (-0.5)

@@ -247,7 +247,7 @@ def __init__(
        self.head_size = hidden_size // num_heads

        self.rotary_emb = PositionRotaryEmbedding.static(
-            self.head_size, base=10000.0, device=weights.device
+            config=config, dim=self.head_size, base=10000.0, device=weights.device
        )
        self.softmax_scale = self.head_size ** (-0.5)

86 changes: 76 additions & 10 deletions server/text_generation_server/utils/layers.py
@@ -381,33 +381,65 @@ def forward(self, hidden_states, residual=None):
    from flash_attn.layers.rotary import RotaryEmbedding
    import rotary_emb

+   def _create_inv_freq(dim, base, device):
+       inv_freq = 1.0 / (
+           base
+           ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
+       )
+       return inv_freq
+
+   def _get_rope_config(config):
+       if os.getenv("ROPE_SCALING", None) is not None:
+           rope_scaling = {
+               "type": os.environ["ROPE_SCALING"],
+               "factor": float(os.environ["ROPE_FACTOR"]),
+           }
+           return rope_scaling
+       return getattr(config, "rope_scaling", None)
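These two helpers carry the core of the change: `_create_inv_freq` builds the standard RoPE frequencies base ** (-2i / dim), and `_get_rope_config` lets the launcher's env vars take precedence over any `rope_scaling` block in the model config. A quick illustration of both, with hypothetical values:

import os
import torch

# Frequencies for a head dim of 8: 1 / base ** ([0, 2, 4, 6] / 8)
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, 8, 2, dtype=torch.float32) / 8))
# tensor([1.0000, 0.1000, 0.0100, 0.0010])

# Env vars set by the launcher win over config.rope_scaling:
os.environ["ROPE_SCALING"], os.environ["ROPE_FACTOR"] = "dynamic", "2.0"
# _get_rope_config(config) -> {"type": "dynamic", "factor": 2.0}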

    class PositionRotaryEmbedding(nn.Module):
-       def __init__(self, inv_freq):
+       def __init__(self, inv_freq, scaling_factor):
            super().__init__()

            self.inv_freq = inv_freq
            self._seq_len_cached = 0
            self._cos_cached = None
            self._sin_cached = None
            self._cos_k_cached = None
            self._sin_k_cached = None
+           self.scaling_factor = scaling_factor
+           self.dynamic_args = None

        @classmethod
-       def static(cls, dim, base, device):
-           inv_freq = 1.0 / (
-               base
-               ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
-           )
-           return cls(inv_freq)
+       def static(cls, config, dim, base, device):
+           inv_freq = _create_inv_freq(dim, base, device)
+           scaling_factor = None
+           rope_scaling = _get_rope_config(config)
+           if rope_scaling is not None:
+               scaling_factor = rope_scaling["factor"]
+               if rope_scaling["type"] == "linear":
+                   pass
+               elif rope_scaling["type"] == "dynamic":
+                   return DynamicPositionRotaryEmbedding(
+                       dim=dim,
+                       max_position_embeddings=config.max_position_embeddings,
+                       base=base,
+                       device=inv_freq.device,
+                       scaling_factor=scaling_factor,
+                   )
+               else:
+                   raise NotImplementedError(
+                       f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
+                   )
+           return cls(inv_freq, scaling_factor)

        @classmethod
-       def load(cls, prefix, weights):
+       def load(cls, config, prefix, weights):
            # XXX: Always load this in float32 !
            dtype = weights.dtype
            weights.dtype = torch.float32
            inv_freq = weights.get_tensor(f"{prefix}.inv_freq")
            weights.dtype = dtype
-           return cls(inv_freq)
+
+           scaling_factor = None
+           rope_scaling = _get_rope_config(config)
+           if rope_scaling is not None:
+               scaling_factor = rope_scaling["factor"]
+               if rope_scaling["type"] == "linear":
+                   pass
+               elif rope_scaling["type"] == "dynamic":
+                   return DynamicPositionRotaryEmbedding(
+                       dim=2 * inv_freq.shape[0],
+                       max_position_embeddings=config.max_position_embeddings,
+                       base=10000.0,
+                       device=inv_freq.device,
+                       scaling_factor=scaling_factor,
+                   )
+               else:
+                   raise NotImplementedError(
+                       f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
+                   )
+           return cls(inv_freq, scaling_factor)

        def _update_cos_sin_cache(self, dtype, device, seqlen):
            # Reset the tables if the sequence length has changed,
@@ -419,8 +451,11 @@ def _update_cos_sin_cache(self, dtype, device, seqlen):
            ):
                self._seq_len_cached = seqlen
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
+               if self.scaling_factor is not None:
+                   t /= self.scaling_factor
                # Don't do einsum, it converts fp32 to fp16
                # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+
                freqs = torch.outer(t, self.inv_freq.to(device=t.device))
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)
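The new `t /= self.scaling_factor` lines implement linear position interpolation: positions are compressed so that the angle at position p matches the unscaled angle at p / factor, letting a factor-times-longer sequence span the trained angle range. A standalone sketch of the effect (assumed sizes, not the class above):

import torch

dim, base, factor = 64, 10000.0, 2.0
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

t = torch.arange(8192, dtype=torch.float32) / factor  # linear interpolation
freqs = torch.outer(t, inv_freq)  # angle at position 8191 == unscaled angle at ~4096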
@@ -446,5 +481,36 @@ def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
            rotary_emb.apply_rotary(x1, x2, cos, sin, x1, x2, False)
            return x
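`rotary_emb.apply_rotary` is the fused in-place kernel shipped with flash-attn; for intuition, an out-of-place pure-PyTorch sketch of the pairwise rotation it applies (shapes assumed broadcastable, not a drop-in replacement):

import torch

def apply_rotary_reference(x1, x2, cos, sin):
    # Rotate each (x1, x2) pair by its position-dependent angle:
    #   x1' = x1 * cos - x2 * sin
    #   x2' = x1 * sin + x2 * cos
    return x1 * cos - x2 * sin, x1 * sin + x2 * cos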

+   class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
+       def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
+           inv_freq = _create_inv_freq(dim, base, device)
+           super().__init__(inv_freq, scaling_factor)
+           self.dim = dim
+           self.max_position_embeddings = max_position_embeddings
+           self.base = base
+
+       def _update_cos_sin_cache(self, dtype, device, seqlen):
+           # Reset the tables if the sequence length has changed,
+           # or if we're on a new device (possibly due to tracing for instance)
+           if (
+               seqlen > self._seq_len_cached
+               or self._cos_cached.device != device
+               or self._cos_cached.dtype != dtype
+           ):
+               if seqlen > self.max_position_embeddings:
+                   newbase = self.base * (
+                       (self.scaling_factor * seqlen / self.max_position_embeddings)
+                       - (self.scaling_factor - 1)
+                   ) ** (self.dim / (self.dim - 2))
+                   self.inv_freq = _create_inv_freq(
+                       self.dim, newbase, self.inv_freq.device
+                   )
+               self._seq_len_cached = seqlen
+               t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
+               if self.scaling_factor is not None:
+                   t /= self.scaling_factor
+               # Don't do einsum, it converts fp32 to fp16
+               # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+
+               freqs = torch.outer(t, self.inv_freq.to(device=t.device))
+               self._cos_cached = torch.cos(freqs).to(dtype)
+               self._sin_cached = torch.sin(freqs).to(dtype)
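With dynamic scaling, once the sequence exceeds the trained window the rotary base is enlarged, which stretches the low-frequency components to cover the longer context. A worked sketch of the base recomputation above, under assumed values:

def ntk_base(base, dim, scaling_factor, seqlen, max_position_embeddings):
    # Same formula as DynamicPositionRotaryEmbedding._update_cos_sin_cache.
    return base * (
        (scaling_factor * seqlen / max_position_embeddings)
        - (scaling_factor - 1)
    ) ** (dim / (dim - 2))

# e.g. dim=128, factor=2.0, trained window 4096, current sequence 8192:
# (2.0 * 8192 / 4096) - 1.0 = 3.0, so the base grows by 3.0 ** (128 / 126),
# i.e. ntk_base(10000.0, 128, 2.0, 8192, 4096) ~= 30_527.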


except ImportError:
    pass