Return n choices for chat completions API #638

Merged
merged 14 commits on Oct 15, 2024
14 changes: 7 additions & 7 deletions .github/workflows/build.yaml
@@ -23,6 +23,13 @@ jobs:
security-events: write

steps:
- name: Log in to GitHub Container Registry
uses: docker/login-action@v1
with:
registry: ghcr.io
username: ${{ github.repository_owner }}
password: ${{ secrets.GHCR_PAT }}

- name: Checkout repository
uses: actions/checkout@v3
with:
@@ -108,13 +115,6 @@ jobs:
echo "Importing $image_path-$tag_hash to Containerd"
sudo ctr i import --no-unpack --all-platforms --digests $image_path-$tag_hash.tar.gz

- name: Log in to GitHub Container Registry
uses: docker/login-action@v1
with:
registry: ghcr.io
username: ${{ github.repository_owner }}
password: ${{ secrets.GHCR_PAT }}

- name: Push image with containerd
env:
tags: ${{ steps.meta.outputs.tags }}
47 changes: 36 additions & 11 deletions router/src/lib.rs
@@ -386,7 +386,7 @@ pub struct SimpleToken {
stop: usize,
}

#[derive(Serialize, ToSchema)]
#[derive(Serialize, ToSchema, Clone)]
#[serde(rename_all(serialize = "snake_case"))]
pub(crate) enum FinishReason {
#[schema(rename = "length")]
@@ -886,21 +886,46 @@ impl From<GenerateResponse> for ChatCompletionResponse {
.unwrap_or(0);
let total_tokens = prompt_tokens + completion_tokens;

// assign choices as the generated text, and include the best of sequences if available
let mut choices = vec![ChatCompletionResponseChoice {
index: 0,
message: ChatMessage {
role: Some("assistant".to_string()),
content: Some(resp.generated_text),
},
finish_reason: resp
.details
.as_ref()
.map(|x| CompletionFinishReason::from(x.finish_reason.clone())),
}];

choices.extend(
resp.details
.as_ref()
.and_then(|x| x.best_of_sequences.as_ref())
.into_iter()
.flat_map(|seqs| {
seqs.iter()
.enumerate()
.map(|(index, seq)| ChatCompletionResponseChoice {
index: index as i32 + 1,
message: ChatMessage {
role: Some("assistant".to_string()),
content: Some(seq.generated_text.clone()),
},
finish_reason: Some(CompletionFinishReason::from(
seq.finish_reason.clone(),
)),
})
}),
);

ChatCompletionResponse {
id: "null".to_string(),
object: "text_completion".to_string(),
created: 0,
model: "null".to_string(),
choices: vec![ChatCompletionResponseChoice {
index: 0,
message: ChatMessage {
role: Some("assistant".to_string()),
content: Some(resp.generated_text),
},
finish_reason: resp
.details
.map(|x| CompletionFinishReason::from(x.finish_reason)),
}],
choices: choices,
usage: UsageInfo {
prompt_tokens: prompt_tokens,
total_tokens: total_tokens,
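With this change, `choices[0]` carries the primary generation and each entry of `best_of_sequences` becomes an additional choice at index 1 and up, each with its own `finish_reason`. Below is a minimal client-side sketch of reading those choices through the OpenAI-compatible chat completions endpoint; the base URL, model name, and the use of the `n` request parameter are illustrative assumptions, not taken from this diff — only the shape of `choices` comes from the PR.

```python
# Sketch: reading multiple choices from the OpenAI-compatible endpoint.
# base_url, model name, and the "n" parameter are assumptions for illustration.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8080/v1", api_key="-")

resp = client.chat.completions.create(
    model="my-base-model",  # hypothetical model name
    messages=[{"role": "user", "content": "Write a haiku about GPUs."}],
    n=2,  # request more than one sequence (assumed to map to best-of sampling)
)

# Choice 0 is the primary generation; choices 1..N-1 come from best_of_sequences.
for choice in resp.choices:
    print(choice.index, choice.finish_reason, choice.message.content)
```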
28 changes: 11 additions & 17 deletions router/src/scheduler.rs
@@ -198,9 +198,6 @@ struct AdapterSchedulerState {
/// Speculation amount
speculate: u32,

/// Prefix caching
prefix_caching: bool,

/// Paged Attention Block Allocation
block_allocator: Option<BlockAllocator>,
}
@@ -242,7 +239,6 @@ impl AdapterSchedulerState {
block_size,
window_size,
speculate,
prefix_caching,
block_allocator,
}
}
@@ -370,19 +366,17 @@ impl AdapterSchedulerState {

// If we're prefix caching, this check could be under-estimating the number of available blocks
// due to shared prefixes, so we'll let the block allocator determine whether we have enough space.
if !self.prefix_caching {
if prefill_tokens > prefill_token_budget
|| (prefill_tokens + decode_tokens + self.speculate) > token_budget
{
// Entry is over budget
// Add it back to the front
tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
self.queues_state
.lock()
.await
.push_front(&adapter, id, entry);
break;
}
if prefill_tokens > prefill_token_budget
|| (prefill_tokens + decode_tokens + self.speculate) > token_budget
{
// Entry is over budget
// Add it back to the front
tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
self.queues_state
.lock()
.await
.push_front(&adapter, id, entry);
break;
}

let tokens = entry.request.input_length()
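The scheduler change drops the `prefix_caching` special case, so the token-budget check now runs unconditionally and over-budget entries are pushed back to the front of the queue. The following is a small conceptual sketch of that admission check in Python; the function and variable names are ours, not the router's.

```python
# Conceptual sketch (not the router's actual code) of the admission check:
# defer an entry when prefill exceeds the prefill budget, or when
# prefill + decode + speculation exceeds the total token budget.
from collections import deque

def try_admit(queue: deque, entry, prefill_tokens: int, decode_tokens: int,
              speculate: int, prefill_token_budget: int, token_budget: int) -> bool:
    over_budget = (
        prefill_tokens > prefill_token_budget
        or (prefill_tokens + decode_tokens + speculate) > token_budget
    )
    if over_budget:
        # Put the entry back at the front so it is retried first next time.
        queue.appendleft(entry)
        return False
    return True
```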
9 changes: 8 additions & 1 deletion router/src/server.rs
@@ -256,10 +256,17 @@ async fn chat_completions_v1(
}
};

let mut adapter_id = Some(req.model.clone());
if req.model == info.model_id.as_str() {
// Allow user to specify the base model, but treat it as an empty adapter_id
tracing::debug!("Replacing base model {0} with empty adapter_id", req.model);
adapter_id = None;
}

let mut gen_req = CompatGenerateRequest {
inputs: inputs.to_string(),
parameters: GenerateParameters {
adapter_id: req.model.parse().ok(),
adapter_id: adapter_id,
adapter_source: req.adapter_source,
adapter_parameters: None,
api_token: req.api_token,
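On the request side, passing the server's base model id in the `model` field is now treated the same as requesting no adapter, while any other value is forwarded as `adapter_id`. A hedged example of the two request shapes, using plain HTTP; the endpoint path and the model/adapter names are illustrative assumptions.

```python
# Sketch of the two request shapes affected by this change; names and URL
# below are assumptions for illustration.
import requests

BASE_URL = "http://localhost:8080/v1/chat/completions"
messages = [{"role": "user", "content": "Hello!"}]

# Case 1: model equals the server's base model id -> treated as no adapter.
requests.post(BASE_URL, json={"model": "mistralai/Mistral-7B-Instruct-v0.1",
                              "messages": messages})

# Case 2: any other model value is forwarded as adapter_id.
requests.post(BASE_URL, json={"model": "my-org/my-lora-adapter",
                              "messages": messages})
```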
@@ -172,7 +172,11 @@ def _load_gqa(config, prefix: str, weights):
dim=0,
)

if config.quantize not in ["gptq", "awq"]:
input_scale, weight_scale = None, None
if type(weight) is tuple:
weight, input_scale, weight_scale = weight

if config.quantize not in ["gptq", "awq", "fp8"]:
weight = weight.to(dtype=weights.dtype).to(device=weights.device)

head_size = config.hidden_size // config.num_attention_heads
@@ -183,7 +187,13 @@ def _load_gqa(config, prefix: str, weights):
config.hidden_size,
], f"{list(weight.shape)} != {[(num_heads + 2 * num_key_value_heads) * head_size, config.hidden_size]}"

return TensorParallelColumnLinear(get_linear(weight, bias=None, quantize=config.quantize))
return TensorParallelColumnLinear(get_linear(
weight,
bias=None,
quantize=config.quantize,
weight_scale=weight_scale,
input_scale=input_scale,
))


def _load_experts(config, prefix, mat, weights):
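The loader change handles quantized weight loaders that return a `(weight, input_scale, weight_scale)` tuple (as with fp8) instead of a bare tensor, and only casts the weight to the model dtype when no quantization path applies. A small sketch of that unpacking pattern in isolation; the helper names here are hypothetical stand-ins, not the repository's.

```python
# Minimal sketch of the tuple-unpacking pattern above: some quantized loaders
# return (weight, input_scale, weight_scale), others just a tensor.
# `load_packed_weight` and `make_linear` are hypothetical stand-ins.
def build_qkv_linear(load_packed_weight, make_linear, quantize, dtype, device):
    weight = load_packed_weight()  # expected to be a torch.Tensor or a tuple

    input_scale, weight_scale = None, None
    if isinstance(weight, tuple):
        # fp8-style loaders hand back the weight together with its scales.
        weight, input_scale, weight_scale = weight

    if quantize not in ["gptq", "awq", "fp8"]:
        # Unquantized path: cast to the model dtype and move to the device.
        weight = weight.to(dtype=dtype).to(device=device)

    return make_linear(weight, bias=None, quantize=quantize,
                       weight_scale=weight_scale, input_scale=input_scale)
```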