Return n choices for chat completions API (#638)
tgaddair authored Oct 15, 2024
1 parent 2a22063 commit bea8834
Showing 5 changed files with 74 additions and 38 deletions.
14 changes: 7 additions & 7 deletions .github/workflows/build.yaml
@@ -23,6 +23,13 @@ jobs:
       security-events: write

     steps:
+      - name: Log in to GitHub Container Registry
+        uses: docker/login-action@v1
+        with:
+          registry: ghcr.io
+          username: ${{ github.repository_owner }}
+          password: ${{ secrets.GHCR_PAT }}
+
       - name: Checkout repository
         uses: actions/checkout@v3
         with:
@@ -108,13 +115,6 @@ jobs:
           echo "Importing $image_path-$tag_hash to Containerd"
           sudo ctr i import --no-unpack --all-platforms --digests $image_path-$tag_hash.tar.gz
-      - name: Log in to GitHub Container Registry
-        uses: docker/login-action@v1
-        with:
-          registry: ghcr.io
-          username: ${{ github.repository_owner }}
-          password: ${{ secrets.GHCR_PAT }}
-
       - name: Push image with containerd
         env:
           tags: ${{ steps.meta.outputs.tags }}
47 changes: 36 additions & 11 deletions router/src/lib.rs
@@ -386,7 +386,7 @@ pub struct SimpleToken {
     stop: usize,
 }

-#[derive(Serialize, ToSchema)]
+#[derive(Serialize, ToSchema, Clone)]
 #[serde(rename_all(serialize = "snake_case"))]
 pub(crate) enum FinishReason {
     #[schema(rename = "length")]
@@ -886,21 +886,46 @@ impl From<GenerateResponse> for ChatCompletionResponse {
             .unwrap_or(0);
         let total_tokens = prompt_tokens + completion_tokens;

+        // assign choices as the generated text, and include the best of sequences if available
+        let mut choices = vec![ChatCompletionResponseChoice {
+            index: 0,
+            message: ChatMessage {
+                role: Some("assistant".to_string()),
+                content: Some(resp.generated_text),
+            },
+            finish_reason: resp
+                .details
+                .as_ref()
+                .map(|x| CompletionFinishReason::from(x.finish_reason.clone())),
+        }];
+
+        choices.extend(
+            resp.details
+                .as_ref()
+                .and_then(|x| x.best_of_sequences.as_ref())
+                .into_iter()
+                .flat_map(|seqs| {
+                    seqs.iter()
+                        .enumerate()
+                        .map(|(index, seq)| ChatCompletionResponseChoice {
+                            index: index as i32 + 1,
+                            message: ChatMessage {
+                                role: Some("assistant".to_string()),
+                                content: Some(seq.generated_text.clone()),
+                            },
+                            finish_reason: Some(CompletionFinishReason::from(
+                                seq.finish_reason.clone(),
+                            )),
+                        })
+                }),
+        );
+
         ChatCompletionResponse {
             id: "null".to_string(),
             object: "text_completion".to_string(),
             created: 0,
             model: "null".to_string(),
-            choices: vec![ChatCompletionResponseChoice {
-                index: 0,
-                message: ChatMessage {
-                    role: Some("assistant".to_string()),
-                    content: Some(resp.generated_text),
-                },
-                finish_reason: resp
-                    .details
-                    .map(|x| CompletionFinishReason::from(x.finish_reason)),
-            }],
+            choices: choices,
             usage: UsageInfo {
                 prompt_tokens: prompt_tokens,
                 total_tokens: total_tokens,
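The router change above is what lets a single chat completion response carry more than one choice: index 0 is the primary generation, and any `best_of_sequences` in the generation details become choices 1..n. The sketch below shows how a client might consume that; it assumes the router runs locally on port 8080, that the OpenAI-compatible route is `/v1/chat/completions`, and that the request's `n` parameter is what populates the best-of sequences — those request details are assumptions, while the shape of the `choices` list follows from the diff.

```python
# Hypothetical client call; host, port, and the `n` parameter are assumptions,
# not confirmed by this diff. Iterating over `choices` is what the lib.rs
# change above makes useful.
import requests

resp = requests.post(
    "http://localhost:8080/v1/chat/completions",
    json={
        "model": "my-adapter",  # adapter id, or the base model name
        "messages": [{"role": "user", "content": "Write a haiku about autumn."}],
        "n": 3,                 # assumed to map to best-of sampling on the server
        "max_tokens": 64,
    },
    timeout=60,
)
resp.raise_for_status()

# choice 0 is the primary generation; the rest come from best_of_sequences
for choice in resp.json()["choices"]:
    print(choice["index"], choice["finish_reason"], choice["message"]["content"])
```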
28 changes: 11 additions & 17 deletions router/src/scheduler.rs
@@ -198,9 +198,6 @@ struct AdapterSchedulerState {
     /// Speculation amount
     speculate: u32,

-    /// Prefix caching
-    prefix_caching: bool,
-
     /// Paged Attention Block Allocation
     block_allocator: Option<BlockAllocator>,
 }
@@ -242,7 +239,6 @@ impl AdapterSchedulerState {
             block_size,
             window_size,
             speculate,
-            prefix_caching,
             block_allocator,
         }
     }
@@ -370,19 +366,17 @@ impl AdapterSchedulerState {

             // If we're prefix caching, this check could be under-estimating the number of available blocks
             // due to shared prefixes, so we'll let the block allocator determine whether we have enough space.
-            if !self.prefix_caching {
-                if prefill_tokens > prefill_token_budget
-                    || (prefill_tokens + decode_tokens + self.speculate) > token_budget
-                {
-                    // Entry is over budget
-                    // Add it back to the front
-                    tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
-                    self.queues_state
-                        .lock()
-                        .await
-                        .push_front(&adapter, id, entry);
-                    break;
-                }
-            }
+            if prefill_tokens > prefill_token_budget
+                || (prefill_tokens + decode_tokens + self.speculate) > token_budget
+            {
+                // Entry is over budget
+                // Add it back to the front
+                tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
+                self.queues_state
+                    .lock()
+                    .await
+                    .push_front(&adapter, id, entry);
+                break;
+            }

             let tokens = entry.request.input_length()
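With the `prefix_caching` gate removed, the over-budget check above runs for every entry. For reference, the condition itself reduces to the following arithmetic — a standalone restatement of the check in the diff, not code from the repository:

```python
def is_over_budget(
    prefill_tokens: int,
    decode_tokens: int,
    speculate: int,
    prefill_token_budget: int,
    token_budget: int,
) -> bool:
    """Mirror of the scheduler's over-budget condition: an entry is pushed
    back to the front of the queue if its prefill cost exceeds the prefill
    budget, or if prefill + decode + speculative tokens exceed the total
    token budget."""
    return (
        prefill_tokens > prefill_token_budget
        or (prefill_tokens + decode_tokens + speculate) > token_budget
    )


# Example: 600 prefill + 200 decode + 2 speculative tokens overruns a 768-token budget.
assert is_over_budget(600, 200, 2, prefill_token_budget=1024, token_budget=768)
```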
9 changes: 8 additions & 1 deletion router/src/server.rs
@@ -256,10 +256,17 @@ async fn chat_completions_v1(
         }
     };

+    let mut adapter_id = Some(req.model.clone());
+    if req.model == info.model_id.as_str() {
+        // Allow user to specify the base model, but treat it as an empty adapter_id
+        tracing::debug!("Replacing base model {0} with empty adapter_id", req.model);
+        adapter_id = None;
+    }
+
     let mut gen_req = CompatGenerateRequest {
         inputs: inputs.to_string(),
         parameters: GenerateParameters {
-            adapter_id: req.model.parse().ok(),
+            adapter_id: adapter_id,
             adapter_source: req.adapter_source,
             adapter_parameters: None,
             api_token: req.api_token,
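The new resolution in `chat_completions_v1` amounts to: if the requested `model` equals the server's base model id, no adapter is applied; otherwise the model name is passed through as the adapter id. A small illustrative restatement (not repository code; the model names in the asserts are placeholders):

```python
from typing import Optional


def resolve_adapter_id(requested_model: str, base_model_id: str) -> Optional[str]:
    """Mirror of the server.rs logic above: requesting the base model by
    name behaves the same as sending no adapter at all."""
    if requested_model == base_model_id:
        return None           # base model requested -> empty adapter_id
    return requested_model    # otherwise treat the model name as an adapter id


assert resolve_adapter_id("mistralai/Mistral-7B-v0.1", "mistralai/Mistral-7B-v0.1") is None
assert resolve_adapter_id("my-org/my-lora", "mistralai/Mistral-7B-v0.1") == "my-org/my-lora"
```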
@@ -172,7 +172,11 @@ def _load_gqa(config, prefix: str, weights):
         dim=0,
     )

-    if config.quantize not in ["gptq", "awq"]:
+    input_scale, weight_scale = None, None
+    if type(weight) is tuple:
+        weight, input_scale, weight_scale = weight
+
+    if config.quantize not in ["gptq", "awq", "fp8"]:
         weight = weight.to(dtype=weights.dtype).to(device=weights.device)

     head_size = config.hidden_size // config.num_attention_heads
@@ -183,7 +187,13 @@ def _load_gqa(config, prefix: str, weights):
         config.hidden_size,
     ], f"{list(weight.shape)} != {[(num_heads + 2 * num_key_value_heads) * head_size, config.hidden_size]}"

-    return TensorParallelColumnLinear(get_linear(weight, bias=None, quantize=config.quantize))
+    return TensorParallelColumnLinear(get_linear(
+        weight,
+        bias=None,
+        quantize=config.quantize,
+        weight_scale=weight_scale,
+        input_scale=input_scale,
+    ))


 def _load_experts(config, prefix, mat, weights):
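The pattern in the `_load_gqa` change — the weight loader may hand back either a bare tensor or a `(weight, input_scale, weight_scale)` triple for fp8 — can be restated as a small helper. This is an illustrative sketch, not repository code; the helper name and example tensors are placeholders:

```python
import torch


def unpack_weight(weight):
    """Split an fp8-style (weight, input_scale, weight_scale) triple,
    mirroring the tuple check added in _load_gqa; a plain tensor passes
    through with no scales."""
    input_scale, weight_scale = None, None
    if isinstance(weight, tuple):
        weight, input_scale, weight_scale = weight
    return weight, input_scale, weight_scale


# A bare tensor comes back unchanged with no scales...
w, in_s, w_s = unpack_weight(torch.zeros(4, 4))
assert in_s is None and w_s is None

# ...while a quantized triple is split into its parts for get_linear-style consumers.
w, in_s, w_s = unpack_weight((torch.zeros(4, 4), torch.tensor(0.5), torch.tensor(0.25)))
assert in_s is not None and w_s is not None
```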
