Fix chat completion and docs #358

Merged · 6 commits · Mar 27, 2024
Changes from 3 commits
21 changes: 21 additions & 0 deletions router/src/infer.rs
@@ -181,6 +181,27 @@ impl Infer {
Ok((permit, response_rx.into_stream()))
}

/// Tokenize the input
#[instrument(skip_all)]
pub(crate) async fn tokenize(
&self,
request: GenerateRequest,
) -> Result<Option<tokenizers::Encoding>, InferError> {
// Tokenize request
let inputs = request.inputs;
let truncate = request.parameters.truncate;
let encoding = self
.validation
.tokenize(inputs, truncate)
.await
.map_err(|err| {
tracing::error!("Error occured during tokenization. {err}");
err
})?;

// Return Encoding
Ok(encoding.map(|(encoding, _)| encoding))
}
/// Add a new request to the queue and return an InferResponse
#[instrument(skip(self))]
pub(crate) async fn generate(
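
For context, a minimal sketch of how a caller might use the new method. This helper is hypothetical and not part of the diff; it assumes `GenerateParameters` implements `Default` (treat that as an assumption):

```rust
// Hypothetical caller, not part of this PR. Assumes `GenerateParameters`
// implements `Default`; `infer` is the router's shared `Infer` handle.
async fn count_input_tokens(infer: &Infer, text: String) -> Result<Option<usize>, InferError> {
    let request = GenerateRequest {
        inputs: text,
        parameters: GenerateParameters::default(),
    };
    // `tokenize` yields `None` when the router has no fast tokenizer.
    let encoding = infer.tokenize(request).await?;
    Ok(encoding.map(|e| e.len()))
}
```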
24 changes: 24 additions & 0 deletions router/src/lib.rs
@@ -323,6 +323,12 @@ impl From<CompatGenerateRequest> for GenerateRequest {
}
}

#[derive(Clone, Debug, Deserialize, ToSchema)]
pub(crate) struct TokenizeRequest {
#[schema(example = "My name is Olivier and I")]
pub inputs: String,
}

#[derive(Debug, Serialize, ToSchema)]
pub struct PrefillToken {
#[schema(example = 0)]
@@ -358,6 +364,18 @@ pub struct Token {
alternative_tokens: Option<Vec<AlternativeToken>>,
}

#[derive(Debug, Serialize, ToSchema)]
pub struct SimpleToken {
#[schema(example = 0)]
id: u32,
#[schema(example = "test")]
text: String,
#[schema(example = 0)]
start: usize,
#[schema(example = 2)]
stop: usize,
}

#[derive(Serialize, ToSchema)]
#[serde(rename_all(serialize = "snake_case"))]
pub(crate) enum FinishReason {
@@ -408,6 +426,10 @@ pub(crate) struct GenerateResponse {
pub details: Option<Details>,
}

#[derive(Serialize, ToSchema)]
#[serde(transparent)]
pub(crate) struct TokenizeResponse(Vec<SimpleToken>);
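
The wire format these schemas imply, using the `#[schema(example)]` values above (illustrative only, a sketch rather than captured output):

```rust
use serde_json::json;

fn main() {
    // POST /tokenize request body (TokenizeRequest):
    let request = json!({ "inputs": "My name is Olivier and I" });
    // TokenizeResponse is #[serde(transparent)], so it serializes as a
    // bare array of SimpleToken objects rather than a wrapper object:
    let response = json!([
        { "id": 0, "text": "test", "start": 0, "stop": 2 }
    ]);
    println!("{request}\n{response}");
}
```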

#[derive(Serialize, ToSchema)]
pub(crate) struct StreamDetails {
#[schema(example = "length")]
@@ -551,7 +573,9 @@ struct CompletionStreamResponse {

#[derive(Serialize, ToSchema)]
struct ChatMessage {
#[serde(skip_serializing_if = "Option::is_none")]
role: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
content: Option<String>,
}
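
The added `skip_serializing_if` attributes are likely the chat-completion fix referenced in the PR title: `None` fields are now omitted instead of serialized as `null`, which matches OpenAI-style payloads more closely. A standalone sketch with a local copy of the struct:

```rust
use serde::Serialize;

#[derive(Serialize)]
struct ChatMessage {
    #[serde(skip_serializing_if = "Option::is_none")]
    role: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<String>,
}

fn main() {
    let delta = ChatMessage { role: None, content: Some("Hi".into()) };
    // Prints {"content":"Hi"} instead of {"role":null,"content":"Hi"}.
    println!("{}", serde_json::to_string(&delta).unwrap());
}
```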

98 changes: 90 additions & 8 deletions router/src/server.rs
@@ -5,10 +5,14 @@ use crate::infer::{InferError, InferResponse, InferStreamResponse};
use crate::json;
use crate::validation::ValidationError;
use crate::{
BestOfSequence, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionStreamResponse,
CompatGenerateRequest, CompletionRequest, CompletionResponse, CompletionStreamResponse,
Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse,
HubModelInfo, Infer, Info, PrefillToken, StreamDetails, StreamResponse, Token, Validation,
AdapterParameters, AlternativeToken, BestOfSequence, ChatCompletionRequest,
ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionStreamResponse,
ChatCompletionStreamResponseChoice, ChatMessage, CompatGenerateRequest, CompletionFinishReason,
CompletionRequest, CompletionResponse, CompletionResponseChoice,
CompletionResponseStreamChoice, CompletionStreamResponse, Details, ErrorResponse, FinishReason,
GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo, Infer, Info, LogProbs,
PrefillToken, ResponseFormat, ResponseFormatType, SimpleToken, StreamDetails, StreamResponse,
Token, TokenizeRequest, TokenizeResponse, UsageInfo, Validation,
};
use axum::extract::Extension;
use axum::http::{request, HeaderMap, Method, StatusCode};
@@ -28,6 +32,7 @@ use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use std::sync::Mutex;
use tokenizers::Tokenizer;
use tokio::signal;
use tokio::sync::mpsc;
@@ -108,7 +113,7 @@ async fn compat_generate(
/// OpenAI compatible completions endpoint
#[utoipa::path(
post,
tag = "LoRAX",
tag = "OpenAI Compatible",
path = "/v1/completions",
request_body = CompletionRequest,
responses(
@@ -190,7 +195,7 @@ async fn completions_v1(
/// OpenAI compatible chat completions endpoint
#[utoipa::path(
post,
tag = "LoRAX",
tag = "OpenAI Compatible",
path = "/v1/chat/completions",
request_body = ChatCompletionRequest,
responses(
@@ -899,27 +904,53 @@ pub async fn run(
compat_generate,
generate,
generate_stream,
completions_v1,
chat_completions_v1,
tokenize,
metrics,
),
components(
schemas(
Info,
UsageInfo,
ResponseFormat,
ResponseFormatType,
CompatGenerateRequest,
GenerateRequest,
GenerateParameters,
AdapterParameters,
AlternativeToken,
PrefillToken,
Token,
SimpleToken,
TokenizeRequest,
TokenizeResponse,
GenerateResponse,
BestOfSequence,
Details,
FinishReason,
StreamResponse,
StreamDetails,
ErrorResponse,
ChatMessage,
LogProbs,
CompletionRequest,
CompletionResponse,
CompletionResponseChoice,
CompletionStreamResponse,
CompletionResponseStreamChoice,
CompletionFinishReason,
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatCompletionStreamResponse,
ChatCompletionStreamResponseChoice,
)
),
tags(
(name = "LoRAX", description = "LoRAX API")
(name = "LoRAX", description = "LoRAX API"),
(name = "OpenAI Compatible", description = "OpenAI compatible API"),
(name = "Tokenization", description = "Tokenizer API"),
),
info(
title = "LoRAX",
@@ -931,6 +962,8 @@ pub async fn run(
)]
struct ApiDoc;

let cloned_tokenizer = tokenizer.clone().map(|t| Arc::new(Mutex::new(t)));

// Create state
let validation = Validation::new(
validation_workers,
@@ -1087,14 +1120,16 @@ pub async fn run(
.route("/ping", get(health))
// Prometheus metrics route
.route("/metrics", get(metrics))
.route("/tokenize", post(tokenize))
.layer(Extension(info))
.layer(Extension(request_logger_sender.clone()))
.layer(Extension(health_ext.clone()))
.layer(Extension(compat_return_full_text))
.layer(Extension(infer))
.layer(Extension(prom_handle.clone()))
.layer(opentelemetry_tracing_layer())
.layer(cors_layer);
.layer(cors_layer)
.layer(Extension(cloned_tokenizer));

if ngrok {
#[cfg(feature = "ngrok")]
@@ -1229,3 +1264,50 @@ impl From<InferError> for Event {
.unwrap()
}
}

/// Tokenize inputs
#[utoipa::path(
post,
tag = "Tokenization",
path = "/tokenize",
request_body = TokenizeRequest,
responses(
(status = 200, description = "Tokenized ids", body = TokenizeResponse),
(status = 404, description = "No tokenizer found", body = ErrorResponse,
example = json ! ({"error": "No fast tokenizer available"})),
)
)]
#[instrument(skip_all)]
async fn tokenize(
Extension(cloned_tokenizer): Extension<Option<Arc<Mutex<Tokenizer>>>>,
Json(req): Json<TokenizeRequest>,
) -> Result<Json<TokenizeResponse>, (StatusCode, Json<ErrorResponse>)> {
if let Some(tokenizer) = cloned_tokenizer {
let input = req.inputs.clone();
let tokenizer = tokenizer.lock().unwrap();
let char_offset = tokenizer.encode_char_offsets(&input[..], false).unwrap();
let tokens: Vec<SimpleToken> = char_offset
.get_ids()
.iter()
.zip(char_offset.get_offsets().iter())
.map(|(&id, &(start, stop))| {
let text: String = input.chars().skip(start).take(stop - start).collect();
SimpleToken {
id,
text,
start,
stop,
}
})
.collect();
Ok(Json(TokenizeResponse(tokens)))
} else {
Err((
StatusCode::NOT_FOUND,
Json(ErrorResponse {
error: "No fast tokenizer or tokenizer.json for this model".to_string(),
error_type: "no fast tokenizer".to_string(),
}),
))
}
}
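
A hypothetical client call against the new route, assuming a router listening on localhost:8080 and the `reqwest`, `tokio`, and `serde_json` crates:

```rust
use serde_json::{json, Value};

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    let tokens: Value = reqwest::Client::new()
        .post("http://localhost:8080/tokenize")
        .json(&json!({ "inputs": "My name is Olivier and I" }))
        .send()
        .await?
        .json()
        .await?;
    // Each array element carries id, text, and the char-based
    // start/stop offsets computed by encode_char_offsets.
    println!("{tokens:#}");
    Ok(())
}
```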
49 changes: 30 additions & 19 deletions router/src/validation.rs
@@ -60,13 +60,12 @@ impl Validation {
}
}

#[instrument(skip_all)]
async fn validate_input(
#[instrument(skip(self, inputs))]
pub async fn tokenize(
&self,
inputs: String,
truncate: Option<usize>,
max_new_tokens: Option<u32>,
) -> Result<(String, usize), ValidationError> {
) -> Result<Option<(tokenizers::Encoding, String)>, ValidationError> {
// If we have a fast tokenizer
if let Some(sender) = &self.sender {
// Create response channel
@@ -79,7 +78,24 @@ impl Validation {

// Await on response channel
// Unwrap is safe here
let (inputs, input_length) = response_receiver.await.unwrap()?;
let encoding = response_receiver.await.unwrap()?;
Ok(Some(encoding))
} else {
Ok(None)
}
}

#[instrument(skip(self, inputs))]
async fn validate_input(
&self,
inputs: String,
truncate: Option<usize>,
max_new_tokens: Option<u32>,
) -> Result<(String, usize), ValidationError> {
// If we have a fast tokenizer
if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? {
// Get the number of input tokens from the encoding
let input_length = encoding.len();

if let Some(max_new_tokens) = max_new_tokens {
// Get total tokens
@@ -352,36 +368,31 @@ fn tokenizer_worker(tokenizer: Tokenizer, receiver: flume::Receiver<TokenizerRequest>) {

/// Get input length and optionally truncate it
fn prepare_input(
inputs: String,
mut inputs: String,
truncate: Option<usize>,
tokenizer: &Tokenizer,
) -> Result<(String, usize), ValidationError> {
) -> Result<(tokenizers::Encoding, String), ValidationError> {
// Get the number of tokens in the input
let mut encoding = tokenizer
.encode(inputs.clone(), true)
.encode(inputs.clone(), false)
Contributor

The documentation for add_special_tokens is pretty unclear, but is there a difference in the generated input lengths when this param is set to false? The one thing I would want to double-check here is that we don't under-count the number of tokens in the input during validation; otherwise we could exceed the max positional embeddings during inference and cause segfaults.

Contributor Author

That's right. The default value of add_special_tokens was changed for a quick test during development and I forgot to revert it. I'll fix it back to the original.
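
To make the concern concrete, a standalone sketch using the `tokenizers` crate directly (the tokenizer.json path is a placeholder), showing how the flag can change the reported length:

```rust
use tokenizers::Tokenizer;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let tokenizer = Tokenizer::from_file("tokenizer.json")?;
    let with = tokenizer.encode("My name is Olivier", true)?.len();
    let without = tokenizer.encode("My name is Olivier", false)?.len();
    // With add_special_tokens = true the template may prepend/append
    // markers such as <s>, so `with` can exceed `without`; validating
    // with `false` would under-count what generation actually sees.
    assert!(with >= without);
    println!("with specials: {with}, without: {without}");
    Ok(())
}
```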

.map_err(|err| ValidationError::Tokenizer(err.to_string()))?;

// Optionally truncate
let (inputs, input_length) = match truncate {
// Truncate is some and < encoding length
Some(truncate) if truncate < encoding.len() => {
// truncate encoding and decode new inputs
if let Some(truncate) = truncate {
if truncate < encoding.len() {
encoding.truncate(truncate, 0, TruncationDirection::Left);
let inputs = tokenizer
inputs = tokenizer
.decode(encoding.get_ids(), false)
.map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
(inputs, encoding.len())
}
// Nothing to do
_ => (inputs, encoding.len()),
};
}

Ok((inputs, input_length))
Ok((encoding, inputs))
}

type TokenizerRequest = (
(String, Option<usize>),
oneshot::Sender<Result<(String, usize), ValidationError>>,
oneshot::Sender<Result<(tokenizers::Encoding, String), ValidationError>>,
Span,
);

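As a worked example of the left-truncation path in prepare_input above (hypothetical tokenizer.json path, truncating to 4 tokens):

```rust
use tokenizers::{Tokenizer, TruncationDirection};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let tokenizer = Tokenizer::from_file("tokenizer.json")?;
    let mut encoding = tokenizer.encode("one two three four five six", false)?;
    if 4 < encoding.len() {
        // Drop tokens from the left so the end of the prompt survives,
        // mirroring prepare_input's TruncationDirection::Left.
        encoding.truncate(4, 0, TruncationDirection::Left);
    }
    // Decode back so the truncated string and its encoding stay in sync.
    let inputs = tokenizer.decode(encoding.get_ids(), false)?;
    println!("{} tokens after truncation: {inputs:?}", encoding.len());
    Ok(())
}
```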