Skip to content

Commit

Permalink
Fix chat completion and docs (#358)
Browse files Browse the repository at this point in the history
  • Loading branch information
GirinMan authored Mar 27, 2024
1 parent e4bad4d commit 0b9117f
Show file tree
Hide file tree
Showing 4 changed files with 172 additions and 26 deletions.
21 changes: 21 additions & 0 deletions router/src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,27 @@ impl Infer {
Ok((permit, response_rx.into_stream()))
}

/// Tokenizer the input
#[instrument(skip_all)]
pub(crate) async fn tokenize(
&self,
request: GenerateRequest,
) -> Result<Option<tokenizers::Encoding>, InferError> {
// Tokenize request
let inputs = request.inputs;
let truncate = request.parameters.truncate;
let encoding = self
.validation
.tokenize(inputs, truncate)
.await
.map_err(|err| {
tracing::error!("Error occured during tokenization. {err}");
err
})?;

// Return Encoding
Ok(encoding.map(|(encoding, _)| encoding))
}
/// Add a new request to the queue and return a InferResponse
#[instrument(skip(self))]
pub(crate) async fn generate(
Expand Down
26 changes: 26 additions & 0 deletions router/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,14 @@ impl From<CompatGenerateRequest> for GenerateRequest {
}
}

/// Request body for the `/tokenize` endpoint.
#[derive(Clone, Debug, Deserialize, ToSchema)]
pub(crate) struct TokenizeRequest {
/// Text to tokenize.
#[schema(example = "My name is Olivier and I")]
pub inputs: String,
/// Whether to prepend/append the model's special tokens; treated as `true` when omitted.
#[schema(nullable = true, example = true)]
pub add_special_tokens: Option<bool>,
}

#[derive(Debug, Serialize, ToSchema)]
pub struct PrefillToken {
#[schema(example = 0)]
Expand Down Expand Up @@ -358,6 +366,18 @@ pub struct Token {
alternative_tokens: Option<Vec<AlternativeToken>>,
}

/// A single token returned by the `/tokenize` endpoint: the vocabulary id,
/// its text, and the character offsets of the token within the input.
#[derive(Debug, Serialize, ToSchema)]
pub struct SimpleToken {
// Vocabulary id of the token.
#[schema(example = 0)]
id: u32,
// String form of the token as stored in the tokenizer vocabulary.
#[schema(example = "test")]
text: String,
// Character offset (inclusive) where the token starts in the input.
#[schema(example = 0)]
start: usize,
// Character offset (exclusive) where the token ends in the input.
#[schema(example = 2)]
stop: usize,
}

#[derive(Serialize, ToSchema)]
#[serde(rename_all(serialize = "snake_case"))]
pub(crate) enum FinishReason {
Expand Down Expand Up @@ -408,6 +428,10 @@ pub(crate) struct GenerateResponse {
pub details: Option<Details>,
}

/// Response body for `/tokenize`: serializes transparently as a bare JSON
/// array of `SimpleToken` (no wrapping object), one entry per token.
#[derive(Serialize, ToSchema)]
#[serde(transparent)]
pub(crate) struct TokenizeResponse(Vec<SimpleToken>);

#[derive(Serialize, ToSchema)]
pub(crate) struct StreamDetails {
#[schema(example = "length")]
Expand Down Expand Up @@ -551,7 +575,9 @@ struct CompletionStreamResponse {

/// A chat message in OpenAI-compatible responses. Both fields are optional
/// and omitted from the serialized JSON when `None` (stream deltas may carry
/// only a role or only content).
#[derive(Serialize, ToSchema)]
struct ChatMessage {
#[serde(skip_serializing_if = "Option::is_none")]
role: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
content: Option<String>,
}

Expand Down
104 changes: 96 additions & 8 deletions router/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,14 @@ use crate::infer::{InferError, InferResponse, InferStreamResponse};
use crate::json;
use crate::validation::ValidationError;
use crate::{
BestOfSequence, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionStreamResponse,
CompatGenerateRequest, CompletionRequest, CompletionResponse, CompletionStreamResponse,
Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse,
HubModelInfo, Infer, Info, PrefillToken, StreamDetails, StreamResponse, Token, Validation,
AdapterParameters, AlternativeToken, BestOfSequence, ChatCompletionRequest,
ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionStreamResponse,
ChatCompletionStreamResponseChoice, ChatMessage, CompatGenerateRequest, CompletionFinishReason,
CompletionRequest, CompletionResponse, CompletionResponseChoice,
CompletionResponseStreamChoice, CompletionStreamResponse, Details, ErrorResponse, FinishReason,
GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo, Infer, Info, LogProbs,
PrefillToken, ResponseFormat, ResponseFormatType, SimpleToken, StreamDetails, StreamResponse,
Token, TokenizeRequest, TokenizeResponse, UsageInfo, Validation,
};
use axum::extract::Extension;
use axum::http::{request, HeaderMap, Method, StatusCode};
Expand All @@ -28,6 +32,7 @@ use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use std::sync::Mutex;
use tokenizers::Tokenizer;
use tokio::signal;
use tokio::sync::mpsc;
Expand Down Expand Up @@ -108,7 +113,7 @@ async fn compat_generate(
/// OpenAI compatible completions endpoint
#[utoipa::path(
post,
tag = "LoRAX",
tag = "OpenAI Compatible",
path = "/v1/completions",
request_body = CompletionRequest,
responses(
Expand Down Expand Up @@ -190,7 +195,7 @@ async fn completions_v1(
/// OpenAI compatible chat completions endpoint
#[utoipa::path(
post,
tag = "LoRAX",
tag = "OpenAI Compatible",
path = "/v1/chat/completions",
request_body = ChatCompletionRequest,
responses(
Expand Down Expand Up @@ -899,27 +904,53 @@ pub async fn run(
compat_generate,
generate,
generate_stream,
completions_v1,
chat_completions_v1,
tokenize,
metrics,
),
components(
schemas(
Info,
UsageInfo,
ResponseFormat,
ResponseFormatType,
CompatGenerateRequest,
GenerateRequest,
GenerateParameters,
AdapterParameters,
AlternativeToken,
PrefillToken,
Token,
SimpleToken,
TokenizeRequest,
TokenizeResponse,
GenerateResponse,
BestOfSequence,
Details,
FinishReason,
StreamResponse,
StreamDetails,
ErrorResponse,
ChatMessage,
LogProbs,
CompletionRequest,
CompletionResponse,
CompletionResponseChoice,
CompletionStreamResponse,
CompletionResponseStreamChoice,
CompletionFinishReason,
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatCompletionStreamResponse,
ChatCompletionStreamResponseChoice,
)
),
tags(
(name = "LoRAX", description = "LoRAX API")
(name = "LoRAX", description = "LoRAX API"),
(name = "OpenAI Compatible", description = "OpenAI compatible API"),
(name = "Tokenization", description = "Tokenizer API"),
),
info(
title = "LoRAX",
Expand All @@ -931,6 +962,8 @@ pub async fn run(
)]
struct ApiDoc;

let cloned_tokenizer = tokenizer.clone().map(|t| Arc::new(Mutex::new(t)));

// Create state
let validation = Validation::new(
validation_workers,
Expand Down Expand Up @@ -1087,14 +1120,16 @@ pub async fn run(
.route("/ping", get(health))
// Prometheus metrics route
.route("/metrics", get(metrics))
.route("/tokenize", post(tokenize))
.layer(Extension(info))
.layer(Extension(request_logger_sender.clone()))
.layer(Extension(health_ext.clone()))
.layer(Extension(compat_return_full_text))
.layer(Extension(infer))
.layer(Extension(prom_handle.clone()))
.layer(opentelemetry_tracing_layer())
.layer(cors_layer);
.layer(cors_layer)
.layer(Extension(cloned_tokenizer));

if ngrok {
#[cfg(feature = "ngrok")]
Expand Down Expand Up @@ -1229,3 +1264,56 @@ impl From<InferError> for Event {
.unwrap()
}
}

/// Tokenize inputs
///
/// Encodes the request's `inputs` with the model's fast tokenizer and returns
/// one `SimpleToken` (id, text, char offsets) per produced token. Responds
/// 404 when the model ships no fast tokenizer.
#[utoipa::path(
    post,
    tag = "Tokenization",
    path = "/tokenize",
    request_body = TokenizeRequest,
    responses(
        (status = 200, description = "Tokenized ids", body = TokenizeResponse),
        (status = 404, description = "No tokenizer found", body = ErrorResponse,
        example = json ! ({"error": "No fast tokenizer available"})),
    )
)]
#[instrument(skip_all)]
async fn tokenize(
    Extension(cloned_tokenizer): Extension<Option<Arc<Mutex<Tokenizer>>>>,
    Json(req): Json<TokenizeRequest>,
) -> Result<Json<TokenizeResponse>, (StatusCode, Json<ErrorResponse>)> {
    if let Some(tokenizer) = cloned_tokenizer {
        // Special tokens are added by default when the caller does not specify.
        let add_special_tokens = req.add_special_tokens.unwrap_or(true);
        let tokenizer = tokenizer.lock().unwrap();
        // Surface tokenizer failures as a 500 instead of panicking the handler.
        let char_offset = tokenizer
            .encode_char_offsets(req.inputs.as_str(), add_special_tokens)
            .map_err(|err| {
                (
                    StatusCode::INTERNAL_SERVER_ERROR,
                    Json(ErrorResponse {
                        error: err.to_string(),
                        error_type: "tokenizer error".to_string(),
                    }),
                )
            })?;
        let tokens: Vec<SimpleToken> = char_offset
            .get_ids()
            .iter()
            .zip(char_offset.get_offsets().iter())
            .map(|(&id, &(start, stop))| {
                // Ids without a vocabulary entry become an empty string rather than panicking.
                let text: String = tokenizer.id_to_token(id).unwrap_or_default();
                SimpleToken {
                    id,
                    text,
                    start,
                    stop,
                }
            })
            .collect();
        Ok(Json(TokenizeResponse(tokens)))
    } else {
        Err((
            StatusCode::NOT_FOUND,
            Json(ErrorResponse {
                error: "No fast tokenizer or tokenizer.json for this model".to_string(),
                error_type: "no fast tokenizer".to_string(),
            }),
        ))
    }
}
47 changes: 29 additions & 18 deletions router/src/validation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,12 @@ impl Validation {
}
}

#[instrument(skip_all)]
async fn validate_input(
#[instrument(skip(self, inputs))]
pub async fn tokenize(
&self,
inputs: String,
truncate: Option<usize>,
max_new_tokens: Option<u32>,
) -> Result<(String, usize), ValidationError> {
) -> Result<Option<(tokenizers::Encoding, String)>, ValidationError> {
// If we have a fast tokenizer
if let Some(sender) = &self.sender {
// Create response channel
Expand All @@ -79,7 +78,24 @@ impl Validation {

// Await on response channel
// Unwrap is safe here
let (inputs, input_length) = response_receiver.await.unwrap()?;
let encoding = response_receiver.await.unwrap()?;
Ok(Some(encoding))
} else {
Ok(None)
}
}

#[instrument(skip(self, inputs))]
async fn validate_input(
&self,
inputs: String,
truncate: Option<usize>,
max_new_tokens: Option<u32>,
) -> Result<(String, usize), ValidationError> {
// If we have a fast tokenizer
if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? {
// Create response channel
let input_length = encoding.len();

if let Some(max_new_tokens) = max_new_tokens {
// Get total tokens
Expand Down Expand Up @@ -352,36 +368,31 @@ fn tokenizer_worker(tokenizer: Tokenizer, receiver: flume::Receiver<TokenizerReq

/// Get input length and optionally truncate it
fn prepare_input(
inputs: String,
mut inputs: String,
truncate: Option<usize>,
tokenizer: &Tokenizer,
) -> Result<(String, usize), ValidationError> {
) -> Result<(tokenizers::Encoding, String), ValidationError> {
// Get the number of tokens in the input
let mut encoding = tokenizer
.encode(inputs.clone(), true)
.map_err(|err| ValidationError::Tokenizer(err.to_string()))?;

// Optionally truncate
let (inputs, input_length) = match truncate {
// Truncate is some and < encoding length
Some(truncate) if truncate < encoding.len() => {
// truncate encoding and decode new inputs
if let Some(truncate) = truncate {
if truncate < encoding.len() {
encoding.truncate(truncate, 0, TruncationDirection::Left);
let inputs = tokenizer
inputs = tokenizer
.decode(encoding.get_ids(), false)
.map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
(inputs, encoding.len())
}
// Nothing to do
_ => (inputs, encoding.len()),
};
}

Ok((inputs, input_length))
Ok((encoding, inputs))
}

type TokenizerRequest = (
(String, Option<usize>),
oneshot::Sender<Result<(String, usize), ValidationError>>,
oneshot::Sender<Result<(tokenizers::Encoding, String), ValidationError>>,
Span,
);

Expand Down

0 comments on commit 0b9117f

Please sign in to comment.