diff --git a/router/src/infer.rs b/router/src/infer.rs
index 8e074745b..c950e311f 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -181,6 +181,27 @@ impl Infer {
         Ok((permit, response_rx.into_stream()))
     }
 
+    /// Tokenize the input
+    #[instrument(skip_all)]
+    pub(crate) async fn tokenize(
+        &self,
+        request: GenerateRequest,
+    ) -> Result<Option<tokenizers::Encoding>, InferError> {
+        // Tokenize request
+        let inputs = request.inputs;
+        let truncate = request.parameters.truncate;
+        let encoding = self
+            .validation
+            .tokenize(inputs, truncate)
+            .await
+            .map_err(|err| {
+                tracing::error!("Error occurred during tokenization. {err}");
+                err
+            })?;
+
+        // Return Encoding
+        Ok(encoding.map(|(encoding, _)| encoding))
+    }
+
     /// Add a new request to the queue and return a InferResponse
     #[instrument(skip(self))]
     pub(crate) async fn generate(
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 701e6ff1c..728e13252 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -323,6 +323,14 @@ impl From<CompatGenerateRequest> for GenerateRequest {
     }
 }
 
+#[derive(Clone, Debug, Deserialize, ToSchema)]
+pub(crate) struct TokenizeRequest {
+    #[schema(example = "My name is Olivier and I")]
+    pub inputs: String,
+    #[schema(nullable = true, example = true)]
+    pub add_special_tokens: Option<bool>,
+}
+
 #[derive(Debug, Serialize, ToSchema)]
 pub struct PrefillToken {
     #[schema(example = 0)]
@@ -358,6 +366,18 @@ pub struct Token {
     alternative_tokens: Option<Vec<AlternativeToken>>,
 }
 
+#[derive(Debug, Serialize, ToSchema)]
+pub struct SimpleToken {
+    #[schema(example = 0)]
+    id: u32,
+    #[schema(example = "test")]
+    text: String,
+    #[schema(example = 0)]
+    start: usize,
+    #[schema(example = 2)]
+    stop: usize,
+}
+
 #[derive(Serialize, ToSchema)]
 #[serde(rename_all(serialize = "snake_case"))]
 pub(crate) enum FinishReason {
@@ -408,6 +428,10 @@ pub(crate) struct GenerateResponse {
     pub details: Option<Details>,
 }
 
+#[derive(Serialize, ToSchema)]
+#[serde(transparent)]
+pub(crate) struct TokenizeResponse(Vec<SimpleToken>);
+
 #[derive(Serialize, ToSchema)]
 pub(crate) struct StreamDetails {
     #[schema(example = "length")]
@@ -551,7 +575,9 @@ struct CompletionStreamResponse {
 
 #[derive(Serialize, ToSchema)]
 struct ChatMessage {
+    #[serde(skip_serializing_if = "Option::is_none")]
     role: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
     content: Option<String>,
 }
diff --git a/router/src/server.rs b/router/src/server.rs
index ad85ea801..57bd2335f 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -5,10 +5,14 @@ use crate::infer::{InferError, InferResponse, InferStreamResponse};
 use crate::json;
 use crate::validation::ValidationError;
 use crate::{
-    BestOfSequence, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionStreamResponse,
-    CompatGenerateRequest, CompletionRequest, CompletionResponse, CompletionStreamResponse,
-    Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse,
-    HubModelInfo, Infer, Info, PrefillToken, StreamDetails, StreamResponse, Token, Validation,
+    AdapterParameters, AlternativeToken, BestOfSequence, ChatCompletionRequest,
+    ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionStreamResponse,
+    ChatCompletionStreamResponseChoice, ChatMessage, CompatGenerateRequest, CompletionFinishReason,
+    CompletionRequest, CompletionResponse, CompletionResponseChoice,
+    CompletionResponseStreamChoice, CompletionStreamResponse, Details, ErrorResponse, FinishReason,
+    GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo, Infer, Info, LogProbs,
+    PrefillToken, ResponseFormat, ResponseFormatType, SimpleToken, StreamDetails, StreamResponse,
+    Token, TokenizeRequest, TokenizeResponse, UsageInfo, Validation,
 };
 use axum::extract::Extension;
 use axum::http::{request, HeaderMap, Method, StatusCode};
@@ -28,6 +32,7 @@ use std::convert::Infallible;
 use std::net::SocketAddr;
 use std::sync::atomic::AtomicBool;
 use std::sync::Arc;
+use std::sync::Mutex;
 use tokenizers::Tokenizer;
 use tokio::signal;
 use tokio::sync::mpsc;
@@ -108,7 +113,7 @@ async fn compat_generate(
 /// OpenAI compatible completions endpoint
 #[utoipa::path(
 post,
-tag = "LoRAX",
+tag = "OpenAI Compatible",
 path = "/v1/completions",
 request_body = CompletionRequest,
 responses(
@@ -190,7 +195,7 @@ async fn completions_v1(
 /// OpenAI compatible chat completions endpoint
 #[utoipa::path(
 post,
-tag = "LoRAX",
+tag = "OpenAI Compatible",
 path = "/v1/chat/completions",
 request_body = ChatCompletionRequest,
 responses(
@@ -899,16 +904,27 @@ pub async fn run(
         compat_generate,
         generate,
         generate_stream,
+        completions_v1,
+        chat_completions_v1,
+        tokenize,
         metrics,
     ),
     components(
         schemas(
             Info,
+            UsageInfo,
+            ResponseFormat,
+            ResponseFormatType,
             CompatGenerateRequest,
             GenerateRequest,
             GenerateParameters,
+            AdapterParameters,
+            AlternativeToken,
             PrefillToken,
             Token,
+            SimpleToken,
+            TokenizeRequest,
+            TokenizeResponse,
             GenerateResponse,
             BestOfSequence,
             Details,
@@ -916,10 +932,25 @@ pub async fn run(
             StreamResponse,
             StreamDetails,
             ErrorResponse,
+            ChatMessage,
+            LogProbs,
+            CompletionRequest,
+            CompletionResponse,
+            CompletionResponseChoice,
+            CompletionStreamResponse,
+            CompletionResponseStreamChoice,
+            CompletionFinishReason,
+            ChatCompletionRequest,
+            ChatCompletionResponse,
+            ChatCompletionResponseChoice,
+            ChatCompletionStreamResponse,
+            ChatCompletionStreamResponseChoice,
         )
     ),
     tags(
-        (name = "LoRAX", description = "LoRAX API")
+        (name = "LoRAX", description = "LoRAX API"),
(name = "OpenAI Compatible", description = "OpenAI compatible API"), + (name = "Tokenization", description = "Tokenizer API"), ), info( title = "LoRAX", @@ -931,6 +962,8 @@ pub async fn run( )] struct ApiDoc; + let cloned_tokenizer = tokenizer.clone().map(|t| Arc::new(Mutex::new(t))); + // Create state let validation = Validation::new( validation_workers, @@ -1087,6 +1120,7 @@ pub async fn run( .route("/ping", get(health)) // Prometheus metrics route .route("/metrics", get(metrics)) + .route("/tokenize", post(tokenize)) .layer(Extension(info)) .layer(Extension(request_logger_sender.clone())) .layer(Extension(health_ext.clone())) @@ -1094,7 +1128,8 @@ pub async fn run( .layer(Extension(infer)) .layer(Extension(prom_handle.clone())) .layer(opentelemetry_tracing_layer()) - .layer(cors_layer); + .layer(cors_layer) + .layer(Extension(cloned_tokenizer)); if ngrok { #[cfg(feature = "ngrok")] @@ -1229,3 +1264,56 @@ impl From for Event { .unwrap() } } + +/// Tokenize inputs +#[utoipa::path( + post, + tag = "Tokenization", + path = "/tokenize", + request_body = TokenizeRequest, + responses( + (status = 200, description = "Tokenized ids", body = TokenizeResponse), + (status = 404, description = "No tokenizer found", body = ErrorResponse, + example = json ! ({"error": "No fast tokenizer available"})), + ) + )] +#[instrument(skip_all)] +async fn tokenize( + Extension(cloned_tokenizer): Extension>>>, + Json(req): Json, +) -> Result, (StatusCode, Json)> { + if let Some(tokenizer) = cloned_tokenizer { + let input = req.inputs.clone(); + let add_special_tokens = match req.add_special_tokens { + None => true, + _ => req.add_special_tokens.unwrap(), + }; + let tokenizer = tokenizer.lock().unwrap(); + let char_offset = tokenizer + .encode_char_offsets(&input[..], add_special_tokens) + .unwrap(); + let tokens: Vec = char_offset + .get_ids() + .iter() + .zip(char_offset.get_offsets().iter()) + .map(|(&id, &(start, stop))| { + let text: String = tokenizer.id_to_token(id).unwrap(); + SimpleToken { + id, + text, + start, + stop, + } + }) + .collect(); + Ok(Json(TokenizeResponse(tokens))) + } else { + Err(( + StatusCode::NOT_FOUND, + Json(ErrorResponse { + error: "No fast tokenizer or tokenizer.json for this model".to_string(), + error_type: "no fast tokenizer".to_string(), + }), + )) + } +} diff --git a/router/src/validation.rs b/router/src/validation.rs index 0a7c5f7f5..8264ea0a7 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -60,13 +60,12 @@ impl Validation { } } - #[instrument(skip_all)] - async fn validate_input( + #[instrument(skip(self, inputs))] + pub async fn tokenize( &self, inputs: String, truncate: Option, - max_new_tokens: Option, - ) -> Result<(String, usize), ValidationError> { + ) -> Result, ValidationError> { // If we have a fast tokenizer if let Some(sender) = &self.sender { // Create response channel @@ -79,7 +78,24 @@ impl Validation { // Await on response channel // Unwrap is safe here - let (inputs, input_length) = response_receiver.await.unwrap()?; + let encoding = response_receiver.await.unwrap()?; + Ok(Some(encoding)) + } else { + Ok(None) + } + } + + #[instrument(skip(self, inputs))] + async fn validate_input( + &self, + inputs: String, + truncate: Option, + max_new_tokens: Option, + ) -> Result<(String, usize), ValidationError> { + // If we have a fast tokenizer + if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? 
+        if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? {
+            // Get the number of input tokens from the encoding
+            let input_length = encoding.len();
 
             if let Some(max_new_tokens) = max_new_tokens {
                 // Get total tokens
@@ -352,36 +368,31 @@ fn tokenizer_worker(tokenizer: Tokenizer, receiver: flume::Receiver<TokenizerRequest>) {
 /// Get input length and optionally truncate it
 fn prepare_input(
-    inputs: String,
+    mut inputs: String,
     truncate: Option<usize>,
     tokenizer: &Tokenizer,
-) -> Result<(String, usize), ValidationError> {
+) -> Result<(tokenizers::Encoding, String), ValidationError> {
     // Get the number of tokens in the input
     let mut encoding = tokenizer
         .encode(inputs.clone(), true)
         .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
 
     // Optionally truncate
-    let (inputs, input_length) = match truncate {
-        // Truncate is some and < encoding length
-        Some(truncate) if truncate < encoding.len() => {
-            // truncate encoding and decode new inputs
+    if let Some(truncate) = truncate {
+        if truncate < encoding.len() {
             encoding.truncate(truncate, 0, TruncationDirection::Left);
-            let inputs = tokenizer
+            inputs = tokenizer
                 .decode(encoding.get_ids(), false)
                 .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
-            (inputs, encoding.len())
         }
-        // Nothing to do
-        _ => (inputs, encoding.len()),
-    };
+    }
 
-    Ok((inputs, input_length))
+    Ok((encoding, inputs))
 }
 
 type TokenizerRequest = (
     (String, Option<usize>),
-    oneshot::Sender<Result<(String, usize), ValidationError>>,
+    oneshot::Sender<Result<(tokenizers::Encoding, String), ValidationError>>,
     Span,
 );
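
A minimal client-side sketch for exercising the new `/tokenize` route by hand. Nothing below is part of the patch: the base URL and the `reqwest` (blocking + json features) and `serde_json` dependencies are assumptions.

```rust
use serde_json::json;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // POST a TokenizeRequest; `add_special_tokens` may be omitted and
    // defaults to true on the server side.
    let client = reqwest::blocking::Client::new();
    let resp: serde_json::Value = client
        .post("http://localhost:8080/tokenize")
        .json(&json!({
            "inputs": "My name is Olivier and I",
            "add_special_tokens": true
        }))
        .send()?
        .error_for_status()?
        .json()?;

    // TokenizeResponse is #[serde(transparent)], so the body is a bare JSON
    // array of SimpleToken objects: {"id", "text", "start", "stop"}.
    println!("{resp:#}");
    Ok(())
}
```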
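One detail of the handler worth noting: `encode_char_offsets` yields character offsets into the original input (not byte offsets), while `text` is the raw vocab string from `id_to_token`, which may carry subword markers such as `▁` or `##`. A standalone sketch of the same mapping with the `tokenizers` crate, assuming a local `tokenizer.json` (hypothetical path):

```rust
use tokenizers::Tokenizer;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Hypothetical path; any HF fast-tokenizer `tokenizer.json` works here.
    let tokenizer = Tokenizer::from_file("tokenizer.json")?;

    let input = "My name is Olivier and I";
    let encoding = tokenizer.encode_char_offsets(input, true)?;

    for (&id, &(start, stop)) in encoding.get_ids().iter().zip(encoding.get_offsets()) {
        // `start..stop` indexes characters, so slice by chars (not bytes) to
        // recover the surface form each token covers.
        let surface: String = input.chars().skip(start).take(stop - start).collect();
        println!("{id}\t{:?}\t{start}..{stop}\t{surface:?}", tokenizer.id_to_token(id));
    }
    Ok(())
}
```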