From d6f766e9bc167bd012182108cfdb4aee277f1556 Mon Sep 17 00:00:00 2001 From: Philipp Oppermann Date: Mon, 7 Oct 2024 12:58:14 +0200 Subject: [PATCH 1/3] Create Rust-based openai api proxy server in node hub --- Cargo.lock | 111 ++- Cargo.toml | 1 + node-hub/openai-proxy-server/Cargo.toml | 27 + node-hub/openai-proxy-server/src/error.rs | 75 ++ node-hub/openai-proxy-server/src/main.rs | 441 +++++++++ node-hub/openai-proxy-server/src/message.rs | 935 ++++++++++++++++++++ 6 files changed, 1553 insertions(+), 37 deletions(-) create mode 100644 node-hub/openai-proxy-server/Cargo.toml create mode 100644 node-hub/openai-proxy-server/src/error.rs create mode 100644 node-hub/openai-proxy-server/src/main.rs create mode 100644 node-hub/openai-proxy-server/src/message.rs diff --git a/Cargo.lock b/Cargo.lock index 0956a2d7e..7e92c1df6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -467,7 +467,7 @@ dependencies = [ "arrow-schema", "chrono", "half", - "indexmap 2.2.6", + "indexmap 2.6.0", "lexical-core", "num", "serde", @@ -2504,6 +2504,27 @@ dependencies = [ "serde_yaml 0.8.26", ] +[[package]] +name = "dora-openai-proxy-server" +version = "0.3.6" +dependencies = [ + "chrono", + "dora-node-api", + "eyre", + "futures", + "hyper 0.14.29", + "indexmap 2.6.0", + "mime_guess", + "serde", + "serde_json", + "thiserror", + "tokio", + "tokio-stream", + "tracing", + "url", + "uuid", +] + [[package]] name = "dora-operator-api" version = "0.3.6" @@ -3387,9 +3408,9 @@ checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" [[package]] name = "futures" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" +checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" dependencies = [ "futures-channel", "futures-core", @@ -3413,9 +3434,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78" +checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" dependencies = [ "futures-core", "futures-sink", @@ -3438,15 +3459,15 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" +checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" [[package]] name = "futures-executor" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a576fc72ae164fca6b9db127eaa9a9dda0d61316034f33a0a0d4eda41f02b01d" +checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" dependencies = [ "futures-core", "futures-task", @@ -3456,9 +3477,9 @@ dependencies = [ [[package]] name = "futures-io" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" +checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" [[package]] name = "futures-lite" @@ -3490,9 +3511,9 @@ dependencies = [ [[package]] name = "futures-macro" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" 
-checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" dependencies = [ "proc-macro2", "quote", @@ -3501,15 +3522,15 @@ dependencies = [ [[package]] name = "futures-sink" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5" +checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" [[package]] name = "futures-task" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" +checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" [[package]] name = "futures-timer" @@ -3519,9 +3540,9 @@ checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" [[package]] name = "futures-util" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" +checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" dependencies = [ "futures-channel", "futures-core", @@ -3797,7 +3818,7 @@ dependencies = [ "futures-sink", "futures-util", "http 0.2.12", - "indexmap 2.2.6", + "indexmap 2.6.0", "slab", "tokio", "tokio-util", @@ -3838,6 +3859,12 @@ dependencies = [ "allocator-api2", ] +[[package]] +name = "hashbrown" +version = "0.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e087f84d4f86bf4b218b927129862374b72199ae7d8657835f1e89000eea4fb" + [[package]] name = "hassle-rs" version = "0.11.0" @@ -4016,7 +4043,7 @@ dependencies = [ "httpdate", "itoa", "pin-project-lite", - "socket2 0.4.10", + "socket2 0.5.7", "tokio", "tower-service", "tracing", @@ -4025,9 +4052,9 @@ dependencies = [ [[package]] name = "hyper" -version = "1.3.1" +version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe575dd17d0862a9a33781c8c4696a55c320909004a67a00fb286ba8b1bc496d" +checksum = "50dfd22e0e76d0f662d429a5f80fcaf3855009297eab6a0a9f8543834744ba05" dependencies = [ "bytes", "futures-channel", @@ -4050,7 +4077,7 @@ checksum = "5ee4be2c948921a1a5320b629c4193916ed787a7f7f293fd3f7f5a6c9de74155" dependencies = [ "futures-util", "http 1.1.0", - "hyper 1.3.1", + "hyper 1.4.1", "hyper-util", "rustls 0.23.10", "rustls-pki-types", @@ -4083,7 +4110,7 @@ dependencies = [ "futures-util", "http 1.1.0", "http-body 1.0.0", - "hyper 1.3.1", + "hyper 1.4.1", "pin-project-lite", "socket2 0.5.7", "tokio", @@ -4204,12 +4231,12 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.2.6" +version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "168fb715dda47215e360912c096649d23d58bf392ac62f73919e831745e40f26" +checksum = "707907fe3c25f5424cce2cb7e1cbcafee6bdbe735ca90ef77c29e84591e5b9da" dependencies = [ "equivalent", - "hashbrown 0.14.5", + "hashbrown 0.15.0", "serde", ] @@ -4909,6 +4936,16 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "mime_guess" +version = "2.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e" +dependencies = [ + "mime", + "unicase", +] + [[package]] name = "mime_guess2" version = "2.0.5" @@ -5038,7 +5075,7 @@ dependencies = [ "bitflags 2.6.0", "codespan-reporting", "hexf-parse", - "indexmap 2.2.6", + "indexmap 2.6.0", "log", "num-traits", "rustc-hash", @@ -6109,7 +6146,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db" dependencies = [ "fixedbitset", - "indexmap 2.2.6", + "indexmap 2.6.0", ] [[package]] @@ -6199,7 +6236,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d9d34169e64b3c7a80c8621a48adaf44e0cf62c78a9b25dd9dd35f1881a17cf9" dependencies = [ "base64 0.21.7", - "indexmap 2.2.6", + "indexmap 2.6.0", "line-wrap", "quick-xml", "serde", @@ -7796,7 +7833,7 @@ dependencies = [ "egui_tiles", "glam", "half", - "indexmap 2.2.6", + "indexmap 2.6.0", "itertools 0.12.1", "macaw", "ndarray", @@ -8004,7 +8041,7 @@ dependencies = [ "http 1.1.0", "http-body 1.0.0", "http-body-util", - "hyper 1.3.1", + "hyper 1.4.1", "hyper-rustls", "hyper-util", "ipnet", @@ -8734,7 +8771,7 @@ version = "0.9.34+deprecated" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6a8b1a1a2ebf674015cc02edccce75287f1a0130d394307b36743c2f5d504b47" dependencies = [ - "indexmap 2.2.6", + "indexmap 2.6.0", "itoa", "ryu", "serde", @@ -9633,7 +9670,7 @@ version = "0.19.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b5bb770da30e5cbfde35a2d7b9b8a2c4b8ef89548a7a6aeab5c9a576e3e7421" dependencies = [ - "indexmap 2.2.6", + "indexmap 2.6.0", "toml_datetime", "winnow", ] @@ -9644,7 +9681,7 @@ version = "0.21.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6a8534fd7f78b5405e860340ad6575217ce99f38d4d5c8f2442cb5ecb50090e1" dependencies = [ - "indexmap 2.2.6", + "indexmap 2.6.0", "toml_datetime", "winnow", ] @@ -10457,7 +10494,7 @@ dependencies = [ "bitflags 2.6.0", "cfg_aliases 0.1.1", "codespan-reporting", - "indexmap 2.2.6", + "indexmap 2.6.0", "log", "naga", "once_cell", diff --git a/Cargo.toml b/Cargo.toml index c3feeff7c..a986f0ced 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,6 +33,7 @@ members = [ "node-hub/dora-record", "node-hub/dora-rerun", "node-hub/terminal-print", + "node-hub/openai-proxy-server", "libraries/extensions/ros2-bridge", "libraries/extensions/ros2-bridge/msg-gen", "libraries/extensions/ros2-bridge/python", diff --git a/node-hub/openai-proxy-server/Cargo.toml b/node-hub/openai-proxy-server/Cargo.toml new file mode 100644 index 000000000..d600355f2 --- /dev/null +++ b/node-hub/openai-proxy-server/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "dora-openai-proxy-server" +version.workspace = true +edition = "2021" +documentation.workspace = true +description.workspace = true +license.workspace = true +repository.workspace = true + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +tokio = { version = "1.36.0", features = ["full"] } +dora-node-api = { workspace = true, features = ["tracing"] } +eyre = "0.6.8" +chrono = "0.4.31" +tracing = "0.1.27" +serde = { version = "1.0.130", features = ["derive"] } +serde_json = "1.0.68" +url = "2.2.2" +indexmap = { version = "2.6.0", features = ["serde"] } +hyper = { version = "0.14", features = ["full"] } +thiserror = "1.0.37" +uuid = { version = "1.10", features = ["v4"] } 
+mime_guess = "2.0.4" +futures = "0.3.31" +tokio-stream = "0.1.11" diff --git a/node-hub/openai-proxy-server/src/error.rs b/node-hub/openai-proxy-server/src/error.rs new file mode 100644 index 000000000..4da256a31 --- /dev/null +++ b/node-hub/openai-proxy-server/src/error.rs @@ -0,0 +1,75 @@ +// Forked from https://github.com/LlamaEdge/LlamaEdge/blob/6bfe9c12c85bf390c47d6065686caeca700feffa/llama-api-server/src/error.rs + +use hyper::{Body, Response}; +use tracing::error; + +#[allow(dead_code)] +pub(crate) fn not_implemented() -> Response { + // log error + error!(target: "stdout", "501 Not Implemented"); + + Response::builder() + .header("Access-Control-Allow-Origin", "*") + .header("Access-Control-Allow-Methods", "*") + .header("Access-Control-Allow-Headers", "*") + .status(hyper::StatusCode::NOT_IMPLEMENTED) + .body(Body::from("501 Not Implemented")) + .unwrap() +} + +pub(crate) fn internal_server_error(msg: impl AsRef) -> Response { + let err_msg = match msg.as_ref().is_empty() { + true => "500 Internal Server Error".to_string(), + false => format!("500 Internal Server Error: {}", msg.as_ref()), + }; + + // log error + error!(target: "stdout", "{}", &err_msg); + + Response::builder() + .header("Access-Control-Allow-Origin", "*") + .header("Access-Control-Allow-Methods", "*") + .header("Access-Control-Allow-Headers", "*") + .status(hyper::StatusCode::INTERNAL_SERVER_ERROR) + .body(Body::from(err_msg)) + .unwrap() +} + +pub(crate) fn bad_request(msg: impl AsRef) -> Response { + let err_msg = match msg.as_ref().is_empty() { + true => "400 Bad Request".to_string(), + false => format!("400 Bad Request: {}", msg.as_ref()), + }; + + // log error + error!(target: "stdout", "{}", &err_msg); + + Response::builder() + .header("Access-Control-Allow-Origin", "*") + .header("Access-Control-Allow-Methods", "*") + .header("Access-Control-Allow-Headers", "*") + .status(hyper::StatusCode::BAD_REQUEST) + .body(Body::from(err_msg)) + .unwrap() +} + +pub(crate) fn invalid_endpoint(msg: impl AsRef) -> Response { + let err_msg = match msg.as_ref().is_empty() { + true => "404 The requested service endpoint is not found".to_string(), + false => format!( + "404 The requested service endpoint is not found: {}", + msg.as_ref() + ), + }; + + // log error + error!(target: "stdout", "{}", &err_msg); + + Response::builder() + .header("Access-Control-Allow-Origin", "*") + .header("Access-Control-Allow-Methods", "*") + .header("Access-Control-Allow-Headers", "*") + .status(hyper::StatusCode::NOT_FOUND) + .body(Body::from(err_msg)) + .unwrap() +} diff --git a/node-hub/openai-proxy-server/src/main.rs b/node-hub/openai-proxy-server/src/main.rs new file mode 100644 index 000000000..471cba184 --- /dev/null +++ b/node-hub/openai-proxy-server/src/main.rs @@ -0,0 +1,441 @@ +use dora_node_api::{self, dora_core::config::DataId, merged::MergeExternalSend, DoraNode, Event}; + +use eyre::{Context, ContextCompat}; +use futures::channel::oneshot::{self, Canceled}; +use hyper::{ + body::{to_bytes, Body, HttpBody}, + header, + server::conn::AddrStream, + service::{make_service_fn, service_fn}, + Request, Response, Server, StatusCode, +}; +use message::{ + ChatCompletionObject, ChatCompletionObjectChoice, ChatCompletionObjectMessage, + ChatCompletionRequest, ChatCompletionRequestMessage, Usage, +}; +use std::{ + collections::VecDeque, + net::SocketAddr, + path::{Path, PathBuf}, +}; +use tokio::{net::TcpListener, sync::mpsc}; +use tracing::{error, info}; + +mod error; +pub mod message; + +#[tokio::main] +async fn main() -> 
eyre::Result<()> { + let web_ui = Path::new("chatbot-ui"); + let port = 8000; + let addr = SocketAddr::from(([0, 0, 0, 0], port)); + + let (server_events_tx, server_events_rx) = mpsc::channel(3); + let server_events = tokio_stream::wrappers::ReceiverStream::new(server_events_rx); + + let server_result_tx = server_events_tx.clone(); + let new_service = make_service_fn(move |conn: &AddrStream| { + // log socket address + info!(target: "stdout", "remote_addr: {}, local_addr: {}", conn.remote_addr().to_string(), conn.local_addr().to_string()); + + // web ui + let web_ui = web_ui.to_string_lossy().to_string(); + let server_events_tx = server_events_tx.clone(); + async move { + let service = service_fn(move |req| { + handle_request(req, web_ui.clone(), server_events_tx.clone()) + }); + Ok::<_, eyre::Error>(service) + } + }); + + let tcp_listener = TcpListener::bind(addr).await.unwrap(); + info!(target: "stdout", "Listening on {}", addr); + + let server = Server::from_tcp(tcp_listener.into_std().unwrap()) + .unwrap() + .serve(new_service); + + tokio::spawn(async move { + let result = server.await.context("server task failed"); + if let Err(err) = server_result_tx.send(ServerEvent::Result(result)).await { + tracing::warn!("server result channel closed: {err}"); + } + }); + + let (mut node, events) = DoraNode::init_from_env()?; + + let merged = events.merge_external_send(server_events); + let events = futures::executor::block_on_stream(merged); + + let output_id = DataId::from("chat_completion_request".to_owned()); + let mut reply_channels = VecDeque::new(); + + for event in events { + match event { + dora_node_api::merged::MergedEvent::External(event) => match event { + ServerEvent::Result(server_result) => { + server_result.context("server failed")?; + break; + } + ServerEvent::ChatCompletionRequest { request, reply } => { + let message = request + .messages + .into_iter() + .find_map(|m| match m { + ChatCompletionRequestMessage::User(message) => Some(message), + _ => None, + }) + .context("no user message found"); + match message { + Ok(message) => match message.content() { + message::ChatCompletionUserMessageContent::Text(content) => { + node.send_output_bytes( + output_id.clone(), + Default::default(), + content.len(), + content.as_bytes(), + ) + .context("failed to send dora output")?; + reply_channels.push_back(( + reply, + content.as_bytes().len() as u64, + request.model, + )); + } + message::ChatCompletionUserMessageContent::Parts(_) => { + if reply + .send(Err(eyre::eyre!("unsupported message content"))) + .is_err() + { + tracing::warn!("failed to send chat completion reply because channel closed early"); + }; + } + }, + Err(err) => { + if reply.send(Err(err)).is_err() { + tracing::warn!("failed to send chat completion reply error because channel closed early"); + } + } + } + } + }, + dora_node_api::merged::MergedEvent::Dora(event) => match event { + Event::Input { + id, + data, + metadata: _, + } => { + match id.as_str() { + "completion_reply" => { + let (reply_channel, prompt_tokens, model) = + reply_channels.pop_front().context("no reply channel")?; + let data = TryFrom::try_from(&data) + .with_context(|| format!("invalid reply data: {data:?}")) + .map(|s: &[u8]| ChatCompletionObject { + id: format!("completion-{}", uuid::Uuid::new_v4()), + object: "chat.completion".to_string(), + created: chrono::Utc::now().timestamp() as u64, + model: model.unwrap_or_default(), + choices: vec![ChatCompletionObjectChoice { + index: 0, + message: ChatCompletionObjectMessage { + role: 
message::ChatCompletionRole::Assistant, + content: Some(String::from_utf8_lossy(s).to_string()), + tool_calls: Vec::new(), + function_call: None, + }, + finish_reason: message::FinishReason::stop, + logprobs: None, + }], + usage: Usage { + prompt_tokens, + completion_tokens: s.len() as u64, + total_tokens: prompt_tokens + s.len() as u64, + }, + }); + + if reply_channel.send(data).is_err() { + tracing::warn!("failed to send chat completion reply because channel closed early"); + } + } + _ => eyre::bail!("unexpected input id: {}", id), + }; + } + Event::Stop => { + break; + } + event => { + println!("Event: {event:#?}") + } + }, + } + } + + Ok(()) +} + +enum ServerEvent { + Result(eyre::Result<()>), + ChatCompletionRequest { + request: ChatCompletionRequest, + reply: oneshot::Sender>, + }, +} + +// Forked from https://github.com/LlamaEdge/LlamaEdge/blob/6bfe9c12c85bf390c47d6065686caeca700feffa/llama-api-server/src/main.rs +async fn handle_request( + req: Request, + web_ui: String, + request_tx: mpsc::Sender, +) -> Result, hyper::Error> { + let path_str = req.uri().path(); + let path_buf = PathBuf::from(path_str); + let mut path_iter = path_buf.iter(); + path_iter.next(); // Must be Some(OsStr::new(&path::MAIN_SEPARATOR.to_string())) + let root_path = path_iter.next().unwrap_or_default(); + let root_path = "/".to_owned() + root_path.to_str().unwrap_or_default(); + + // log request + { + let method = hyper::http::Method::as_str(req.method()).to_string(); + let path = req.uri().path().to_string(); + let version = format!("{:?}", req.version()); + if req.method() == hyper::http::Method::POST { + let size: u64 = match req.headers().get("content-length") { + Some(content_length) => content_length.to_str().unwrap().parse().unwrap(), + None => 0, + }; + + info!(target: "stdout", "method: {}, http_version: {}, content-length: {}", method, version, size); + info!(target: "stdout", "endpoint: {}", path); + } else { + info!(target: "stdout", "method: {}, http_version: {}", method, version); + info!(target: "stdout", "endpoint: {}", path); + } + } + + let response = match root_path.as_str() { + "/echo" => Response::new(Body::from("echo test")), + "/v1" => handle_llama_request(req, request_tx).await, + _ => static_response(path_str, web_ui), + }; + + // log response + { + let status_code = response.status(); + if status_code.as_u16() < 400 { + // log response + let response_version = format!("{:?}", response.version()); + info!(target: "stdout", "response_version: {}", response_version); + let response_body_size: u64 = response.body().size_hint().lower(); + info!(target: "stdout", "response_body_size: {}", response_body_size); + let response_status = status_code.as_u16(); + info!(target: "stdout", "response_status: {}", response_status); + let response_is_success = status_code.is_success(); + info!(target: "stdout", "response_is_success: {}", response_is_success); + } else { + let response_version = format!("{:?}", response.version()); + error!(target: "stdout", "response_version: {}", response_version); + let response_body_size: u64 = response.body().size_hint().lower(); + error!(target: "stdout", "response_body_size: {}", response_body_size); + let response_status = status_code.as_u16(); + error!(target: "stdout", "response_status: {}", response_status); + let response_is_success = status_code.is_success(); + error!(target: "stdout", "response_is_success: {}", response_is_success); + let response_is_client_error = status_code.is_client_error(); + error!(target: "stdout", 
"response_is_client_error: {}", response_is_client_error); + let response_is_server_error = status_code.is_server_error(); + error!(target: "stdout", "response_is_server_error: {}", response_is_server_error); + } + } + + Ok(response) +} + +fn static_response(path_str: &str, root: String) -> Response { + let path = match path_str { + "/" => "/index.html", + _ => path_str, + }; + + let mime = mime_guess::from_path(path); + + match std::fs::read(format!("{root}/{path}")) { + Ok(content) => Response::builder() + .status(StatusCode::OK) + .header(header::CONTENT_TYPE, mime.first_or_text_plain().to_string()) + .body(Body::from(content)) + .unwrap(), + Err(_) => { + let body = Body::from(std::fs::read(format!("{root}/404.html")).unwrap_or_default()); + Response::builder() + .status(StatusCode::NOT_FOUND) + .header(header::CONTENT_TYPE, "text/html") + .body(body) + .unwrap() + } + } +} + +// Forked from https://github.com/LlamaEdge/LlamaEdge/blob/6bfe9c12c85bf390c47d6065686caeca700feffa/llama-api-server/src/backend/mod.rs#L8 +async fn handle_llama_request( + req: Request, + request_tx: mpsc::Sender, +) -> Response { + match req.uri().path() { + "/v1/chat/completions" => chat_completions_handler(req, request_tx).await, + // "/v1/completions" => ggml::completions_handler(req).await, + // "/v1/models" => ggml::models_handler().await, + // "/v1/embeddings" => ggml::embeddings_handler(req).await, + // "/v1/files" => ggml::files_handler(req).await, + // "/v1/chunks" => ggml::chunks_handler(req).await, + // "/v1/info" => ggml::server_info_handler().await, + // path if path.starts_with("/v1/files/") => ggml::files_handler(req).await, + path => error::invalid_endpoint(path), + } +} + +// Forked from https://github.com/LlamaEdge/LlamaEdge/blob/6bfe9c12c85bf390c47d6065686caeca700feffa/llama-api-server/src/backend/ggml.rs#L301 +async fn chat_completions_handler( + mut req: Request, + request_tx: mpsc::Sender, +) -> Response { + info!(target: "stdout", "Handling the coming chat completion request."); + + if req.method().eq(&hyper::http::Method::OPTIONS) { + let result = Response::builder() + .header("Access-Control-Allow-Origin", "*") + .header("Access-Control-Allow-Methods", "*") + .header("Access-Control-Allow-Headers", "*") + .header("Content-Type", "application/json") + .body(Body::empty()); + + match result { + Ok(response) => return response, + Err(e) => { + let err_msg = e.to_string(); + + // log + error!(target: "stdout", "{}", &err_msg); + + return error::internal_server_error(err_msg); + } + } + } + + info!(target: "stdout", "Prepare the chat completion request."); + + // parse request + let body_bytes = match to_bytes(req.body_mut()).await { + Ok(body_bytes) => body_bytes, + Err(e) => { + let err_msg = format!("Fail to read buffer from request body. 
{}", e); + + // log + error!(target: "stdout", "{}", &err_msg); + + return error::internal_server_error(err_msg); + } + }; + let mut chat_request: ChatCompletionRequest = match serde_json::from_slice(&body_bytes) { + Ok(chat_request) => chat_request, + Err(e) => { + let mut err_msg = format!("Fail to deserialize chat completion request: {}.", e); + + if let Ok(json_value) = serde_json::from_slice::(&body_bytes) { + err_msg = format!("{}\njson_value: {}", err_msg, json_value); + } + + // log + error!(target: "stdout", "{}", &err_msg); + + return error::bad_request(err_msg); + } + }; + + // check if the user id is provided + if chat_request.user.is_none() { + chat_request.user = Some(gen_chat_id()) + }; + let id = chat_request.user.clone().unwrap(); + + // log user id + info!(target: "stdout", "user: {}", chat_request.user.clone().unwrap()); + + let (tx, rx) = oneshot::channel(); + if let Err(err) = request_tx + .send(ServerEvent::ChatCompletionRequest { + request: chat_request, + reply: tx, + }) + .await + .context("failed to send request") + { + return error::internal_server_error(format!("{err:?}")); + } + + let res = match rx + .await + .unwrap_or_else(|Canceled| Err(eyre::eyre!("result channel closed early"))) + { + Ok(chat_completion_object) => { + // serialize chat completion object + let s = match serde_json::to_string(&chat_completion_object) { + Ok(s) => s, + Err(e) => { + let err_msg = format!("Failed to serialize chat completion object. {}", e); + + // log + error!(target: "stdout", "{}", &err_msg); + + return error::internal_server_error(err_msg); + } + }; + + // return response + let result = Response::builder() + .header("Access-Control-Allow-Origin", "*") + .header("Access-Control-Allow-Methods", "*") + .header("Access-Control-Allow-Headers", "*") + .header("Content-Type", "application/json") + .header("user", id) + .body(Body::from(s)); + + match result { + Ok(response) => { + // log + info!(target: "stdout", "Finish chat completions in non-stream mode"); + + response + } + Err(e) => { + let err_msg = + format!("Failed chat completions in non-stream mode. Reason: {}", e); + + // log + error!(target: "stdout", "{}", &err_msg); + + error::internal_server_error(err_msg) + } + } + } + Err(e) => { + let err_msg = format!("Failed to get chat completions. Reason: {}", e); + + // log + error!(target: "stdout", "{}", &err_msg); + + error::internal_server_error(err_msg) + } + }; + + // log + info!(target: "stdout", "Send the chat completion response."); + + res +} + +pub(crate) fn gen_chat_id() -> String { + format!("chatcmpl-{}", uuid::Uuid::new_v4()) +} diff --git a/node-hub/openai-proxy-server/src/message.rs b/node-hub/openai-proxy-server/src/message.rs new file mode 100644 index 000000000..dff7e101c --- /dev/null +++ b/node-hub/openai-proxy-server/src/message.rs @@ -0,0 +1,935 @@ +use core::fmt; +use std::collections::HashMap; + +use indexmap::IndexMap; +use serde::{ + de::{self, MapAccess, Visitor}, + Deserialize, Deserializer, Serialize, +}; +use serde_json::Value; + +// Forked from https://github.com/LlamaEdge/LlamaEdge/blob/6bfe9c12c85bf390c47d6065686caeca700feffa/crates/endpoints/src/chat.rs#L304 +/// Represents a chat completion request. +#[derive(Debug, Serialize, Default)] +pub struct ChatCompletionRequest { + /// The model to use for generating completions. + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, + /// A list of messages comprising the conversation so far. + pub messages: Vec, + /// Adjust the randomness of the generated text. 
Between 0.0 and 2.0. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. + /// + /// We generally recommend altering this or top_p but not both. + /// Defaults to 1.0. + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + /// Limit the next token selection to a subset of tokens with a cumulative probability above a threshold P. The value should be between 0.0 and 1.0. + /// + /// Top-p sampling, also known as nucleus sampling, is another text generation method that selects the next token from a subset of tokens that together have a cumulative probability of at least p. This method provides a balance between diversity and quality by considering both the probabilities of tokens and the number of tokens to sample from. A higher value for top_p (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. + /// + /// We generally recommend altering this or temperature but not both. + /// Defaults to 1.0. + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, + /// How many chat completion choices to generate for each input message. + /// Defaults to 1. + #[serde(skip_serializing_if = "Option::is_none")] + #[serde(rename = "n")] + pub n_choice: Option, + /// Whether to stream the results as they are generated. Useful for chatbots. + /// Defaults to false. + #[serde(skip_serializing_if = "Option::is_none")] + pub stream: Option, + /// Options for streaming response. Only set this when you set `stream: true`. + #[serde(skip_serializing_if = "Option::is_none")] + pub stream_options: Option, + /// A list of tokens at which to stop generation. If None, no stop tokens are used. Up to 4 sequences where the API will stop generating further tokens. + /// Defaults to None + #[serde(skip_serializing_if = "Option::is_none")] + pub stop: Option>, + /// The maximum number of tokens to generate. The value should be no less than 1. + /// Defaults to 1024. + #[serde(skip_serializing_if = "Option::is_none")] + pub max_tokens: Option, + /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. + /// Defaults to 0.0. + #[serde(skip_serializing_if = "Option::is_none")] + pub presence_penalty: Option, + /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. + /// Defaults to 0.0. + #[serde(skip_serializing_if = "Option::is_none")] + pub frequency_penalty: Option, + /// Modify the likelihood of specified tokens appearing in the completion. + /// + /// Accepts a json object that maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token. + /// Defaults to None. + #[serde(skip_serializing_if = "Option::is_none")] + pub logit_bias: Option>, + /// A unique identifier representing your end-user. 
+ #[serde(skip_serializing_if = "Option::is_none")] + pub user: Option, + + //* OpenAI specific parameters + /// **Deprecated since 0.10.0.** Use `tools` instead. + #[serde(skip_serializing_if = "Option::is_none")] + pub functions: Option>, + /// **Deprecated since 0.10.0.** Use `tool_choice` instead. + #[serde(skip_serializing_if = "Option::is_none")] + pub function_call: Option, + + /// Format that the model must output + #[serde(skip_serializing_if = "Option::is_none")] + pub response_format: Option, + /// A list of tools the model may call. + /// + /// Currently, only functions are supported as a tool. Use this to provide a list of functions the model may generate JSON inputs for. + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option>, + /// Controls which (if any) function is called by the model. + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_choice: Option, +} +impl<'de> Deserialize<'de> for ChatCompletionRequest { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct ChatCompletionRequestVisitor; + + impl<'de> Visitor<'de> for ChatCompletionRequestVisitor { + type Value = ChatCompletionRequest; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("struct ChatCompletionRequest") + } + + fn visit_map(self, mut map: V) -> Result + where + V: MapAccess<'de>, + { + // Initialize all fields as None or empty + let mut model = None; + let mut messages = None; + let mut temperature = None; + let mut top_p = None; + let mut n_choice = None; + let mut stream = None; + let mut stream_options = None; + let mut stop = None; + let mut max_tokens = None; + let mut presence_penalty = None; + let mut frequency_penalty = None; + let mut logit_bias = None; + let mut user = None; + let mut functions = None; + let mut function_call = None; + let mut response_format = None; + let mut tools = None; + let mut tool_choice = None; + + while let Some(key) = map.next_key::()? { + match key.as_str() { + "model" => model = map.next_value()?, + "messages" => messages = map.next_value()?, + "temperature" => temperature = map.next_value()?, + "top_p" => top_p = map.next_value()?, + "n" => n_choice = map.next_value()?, + "stream" => stream = map.next_value()?, + "stream_options" => stream_options = map.next_value()?, + "stop" => stop = map.next_value()?, + "max_tokens" => max_tokens = map.next_value()?, + "presence_penalty" => presence_penalty = map.next_value()?, + "frequency_penalty" => frequency_penalty = map.next_value()?, + "logit_bias" => logit_bias = map.next_value()?, + "user" => user = map.next_value()?, + "functions" => functions = map.next_value()?, + "function_call" => function_call = map.next_value()?, + "response_format" => response_format = map.next_value()?, + "tools" => tools = map.next_value()?, + "tool_choice" => tool_choice = map.next_value()?, + _ => return Err(de::Error::unknown_field(key.as_str(), FIELDS)), + } + } + + // Ensure all required fields are initialized + let messages = messages.ok_or_else(|| de::Error::missing_field("messages"))?; + + // Set default value for `max_tokens` if not provided + if max_tokens.is_none() { + max_tokens = Some(1024); + } + + // Check tools and tool_choice + // `auto` is the default if tools are present. + // `none` is the default when no tools are present. 
+ if tools.is_some() { + if tool_choice.is_none() { + tool_choice = Some(ToolChoice::Auto); + } + } else if tool_choice.is_none() { + tool_choice = Some(ToolChoice::None); + } + + if n_choice.is_none() { + n_choice = Some(1); + } + + if stream.is_none() { + stream = Some(false); + } + + // Construct ChatCompletionRequest with all fields + Ok(ChatCompletionRequest { + model, + messages, + temperature, + top_p, + n_choice, + stream, + stream_options, + stop, + max_tokens, + presence_penalty, + frequency_penalty, + logit_bias, + user, + functions, + function_call, + response_format, + tools, + tool_choice, + }) + } + } + + const FIELDS: &[&str] = &[ + "prompt", + "max_tokens", + "temperature", + "top_p", + "n", + "stream", + "stream_options", + "stop", + "max_tokens", + "presence_penalty", + "frequency_penalty", + "logit_bias", + "user", + "functions", + "function_call", + "response_format", + "tools", + "tool_choice", + ]; + deserializer.deserialize_struct( + "ChatCompletionRequest", + FIELDS, + ChatCompletionRequestVisitor, + ) + } +} + +/// Message for comprising the conversation. +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] +#[serde(tag = "role", rename_all = "lowercase")] +pub enum ChatCompletionRequestMessage { + System(ChatCompletionSystemMessage), + User(ChatCompletionUserMessage), + Assistant(ChatCompletionAssistantMessage), + Tool(ChatCompletionToolMessage), +} +impl ChatCompletionRequestMessage { + /// Creates a new system message. + /// + /// # Arguments + /// + /// * `content` - The contents of the system message. + /// + /// * `name` - An optional name for the participant. Provides the model information to differentiate between participants of the same role. + pub fn new_system_message(content: impl Into, name: Option) -> Self { + ChatCompletionRequestMessage::System(ChatCompletionSystemMessage::new(content, name)) + } + + /// Creates a new user message. + /// + /// # Arguments + /// + /// * `content` - The contents of the user message. + /// + /// * `name` - An optional name for the participant. Provides the model information to differentiate between participants of the same role. + pub fn new_user_message( + content: ChatCompletionUserMessageContent, + name: Option, + ) -> Self { + ChatCompletionRequestMessage::User(ChatCompletionUserMessage::new(content, name)) + } + + /// Creates a new assistant message. + /// + /// # Arguments + /// + /// * `content` - The contents of the assistant message. Required unless `tool_calls` is specified. + /// + /// * `name` - An optional name for the participant. Provides the model information to differentiate between participants of the same role. + /// + /// * `tool_calls` - The tool calls generated by the model. + pub fn new_assistant_message( + content: Option, + name: Option, + tool_calls: Option>, + ) -> Self { + ChatCompletionRequestMessage::Assistant(ChatCompletionAssistantMessage::new( + content, name, tool_calls, + )) + } + + /// Creates a new tool message. + pub fn new_tool_message(content: impl Into, tool_call_id: Option) -> Self { + ChatCompletionRequestMessage::Tool(ChatCompletionToolMessage::new(content, tool_call_id)) + } + + /// The role of the messages author. 
+ pub fn role(&self) -> ChatCompletionRole { + match self { + ChatCompletionRequestMessage::System(_) => ChatCompletionRole::System, + ChatCompletionRequestMessage::User(_) => ChatCompletionRole::User, + ChatCompletionRequestMessage::Assistant(_) => ChatCompletionRole::Assistant, + ChatCompletionRequestMessage::Tool(_) => ChatCompletionRole::Tool, + } + } + + /// The name of the participant. Provides the model information to differentiate between participants of the same role. + pub fn name(&self) -> Option<&String> { + match self { + ChatCompletionRequestMessage::System(message) => message.name(), + ChatCompletionRequestMessage::User(message) => message.name(), + ChatCompletionRequestMessage::Assistant(message) => message.name(), + ChatCompletionRequestMessage::Tool(_) => None, + } + } +} + +/// Sampling methods used for chat completion requests. +#[derive(Debug, Deserialize, Serialize, Clone, Copy, PartialEq)] +pub enum ChatCompletionRequestSampling { + /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. + Temperature(f64), + /// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. + TopP(f64), +} + +/// The role of the messages author. +#[derive(Debug, Deserialize, Serialize, Clone, Copy, PartialEq, Eq)] +#[serde(rename_all = "lowercase")] +pub enum ChatCompletionRole { + System, + User, + Assistant, + /// **Deprecated since 0.10.0.** Use [ChatCompletionRole::Tool] instead. + Function, + Tool, +} +impl std::fmt::Display for ChatCompletionRole { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ChatCompletionRole::System => write!(f, "system"), + ChatCompletionRole::User => write!(f, "user"), + ChatCompletionRole::Assistant => write!(f, "assistant"), + ChatCompletionRole::Function => write!(f, "function"), + ChatCompletionRole::Tool => write!(f, "tool"), + } + } +} + +/// **Deprecated since 0.10.0.** Use [Tool] instead. +#[derive(Debug, Deserialize, Serialize)] +pub struct ChatCompletionRequestFunction { + name: String, + #[serde(skip_serializing_if = "Option::is_none")] + description: Option, + parameters: ChatCompletionRequestFunctionParameters, +} + +/// The parameters the functions accepts, described as a JSON Schema object. +/// +/// See the [guide](https://platform.openai.com/docs/guides/gpt/function-calling) for examples, and the [JSON Schema reference](https://json-schema.org/understanding-json-schema/) for documentation about the format. +/// +/// To describe a function that accepts no parameters, provide the value +/// `{"type": "object", "properties": {}}`. 
+#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ChatCompletionRequestFunctionParameters { + #[serde(rename = "type")] + pub schema_type: JSONSchemaType, + #[serde(skip_serializing_if = "Option::is_none")] + pub properties: Option>>, + #[serde(skip_serializing_if = "Option::is_none")] + pub required: Option>, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "lowercase")] +pub enum JSONSchemaType { + Object, + Number, + Integer, + String, + Array, + Null, + Boolean, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct JSONSchemaDefine { + #[serde(rename = "type")] + pub schema_type: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + #[serde(rename = "enum", skip_serializing_if = "Option::is_none")] + pub enum_values: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub properties: Option>>, + #[serde(skip_serializing_if = "Option::is_none")] + pub required: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub items: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub default: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub maximum: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub minimum: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub title: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub examples: Option>, +} + +/// Represents a chat completion response returned by model, based on the provided input. +#[derive(Debug, Deserialize, Serialize)] +pub struct ChatCompletionObject { + /// A unique identifier for the chat completion. + pub id: String, + /// The object type, which is always `chat.completion`. + pub object: String, + /// The Unix timestamp (in seconds) of when the chat completion was created. + pub created: u64, + /// The model used for the chat completion. + pub model: String, + /// A list of chat completion choices. Can be more than one if `n_choice` is greater than 1. + pub choices: Vec, + /// Usage statistics for the completion request. + pub usage: Usage, +} + +/// Represents a chat completion choice returned by model. +#[derive(Debug, Deserialize, Serialize)] +pub struct ChatCompletionObjectChoice { + /// The index of the choice in the list of choices. + pub index: u32, + /// A chat completion message generated by the model. + pub message: ChatCompletionObjectMessage, + /// The reason the model stopped generating tokens. This will be `stop` if the model hit a natural stop point or a provided stop sequence, `length` if the maximum number of tokens specified in the request was reached, or `function_call` if the model called a function. + pub finish_reason: FinishReason, + /// Log probability information for the choice. + pub logprobs: Option, +} + +/// Token usage +#[derive(Debug, Default, Deserialize, Serialize)] +pub struct Usage { + /// Number of tokens in the prompt. + pub prompt_tokens: u64, + /// Number of tokens in the generated completion. + pub completion_tokens: u64, + /// Total number of tokens used in the request (prompt + completion). + pub total_tokens: u64, +} + +/// The reason the model stopped generating tokens. +#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Copy)] +#[allow(non_camel_case_types)] +pub enum FinishReason { + /// `stop` if the model hit a natural stop point or a provided stop sequence. + stop, + /// `length` if the maximum number of tokens specified in the request was reached. 
+ length, + /// `tool_calls` if the model called a tool. + tool_calls, +} + +/// Represents a chat completion message generated by the model. +#[derive(Debug, Serialize)] +pub struct ChatCompletionObjectMessage { + /// The contents of the message. + pub content: Option, + /// The tool calls generated by the model, such as function calls. + #[serde(skip_serializing_if = "Vec::is_empty")] + pub tool_calls: Vec, + /// The role of the author of this message. + pub role: ChatCompletionRole, + /// Deprecated. The name and arguments of a function that should be called, as generated by the model. + #[serde(skip_serializing_if = "Option::is_none")] + pub function_call: Option, +} +impl<'de> Deserialize<'de> for ChatCompletionObjectMessage { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct ChatCompletionObjectMessageVisitor; + + impl<'de> Visitor<'de> for ChatCompletionObjectMessageVisitor { + type Value = ChatCompletionObjectMessage; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("struct ChatCompletionObjectMessage") + } + + fn visit_map(self, mut map: V) -> Result + where + V: MapAccess<'de>, + { + let mut content = None; + let mut tool_calls = None; + let mut role = None; + let mut function_call = None; + + while let Some(key) = map.next_key::()? { + match key.as_str() { + "content" => content = map.next_value()?, + "tool_calls" => tool_calls = map.next_value()?, + "role" => role = map.next_value()?, + "function_call" => function_call = map.next_value()?, + _ => return Err(de::Error::unknown_field(key.as_str(), FIELDS)), + } + } + + let content = content; + let tool_calls = tool_calls.unwrap_or_default(); + let role = role.ok_or_else(|| de::Error::missing_field("role"))?; + let function_call = function_call; + + Ok(ChatCompletionObjectMessage { + content, + tool_calls, + role, + function_call, + }) + } + } + + const FIELDS: &[&str] = &["content", "tool_calls", "role", "function_call"]; + deserializer.deserialize_struct( + "ChatCompletionObjectMessage", + FIELDS, + ChatCompletionObjectMessageVisitor, + ) + } +} + +/// Options for streaming response. Only set this when you set stream: `true``. +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct StreamOptions { + #[serde(skip_serializing_if = "Option::is_none")] + pub include_usage: Option, +} + +/// The name and arguments of a function that should be called, as generated by the model. +#[derive(Debug, Deserialize, Serialize)] +pub struct ChatMessageFunctionCall { + /// The name of the function to call. + pub name: String, + + /// The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function. + pub arguments: String, +} + +/// Represents a tool call generated by the model. +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] +pub struct ToolCall { + /// The ID of the tool call. + pub id: String, + /// The type of the tool. Currently, only function is supported. + #[serde(rename = "type")] + pub ty: String, + /// The function that the model called. + pub function: Function, +} + +/// Log probability information for the choice. +#[derive(Debug, Deserialize, Serialize)] +pub struct LogProbs; + +/// The function that the model called. 
+#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] +pub struct Function { + /// The name of the function that the model called. + pub name: String, + /// The arguments that the model called the function with. + pub arguments: String, +} + +/// Defines the types of a user message content. +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] +#[serde(untagged)] +pub enum ChatCompletionUserMessageContent { + /// The text contents of the message. + Text(String), + /// An array of content parts with a defined type, each can be of type `text` or `image_url` when passing in images. + /// It is required that there must be one content part of type `text` at least. Multiple images are allowed by adding multiple image_url content parts. + Parts(Vec), +} +impl ChatCompletionUserMessageContent { + pub fn ty(&self) -> &str { + match self { + ChatCompletionUserMessageContent::Text(_) => "text", + ChatCompletionUserMessageContent::Parts(_) => "parts", + } + } +} + +/// Define the content part of a user message. +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] +#[serde(tag = "type", rename_all = "lowercase")] +// #[serde(untagged)] +pub enum ContentPart { + #[serde(rename = "text")] + Text(TextContentPart), + #[serde(rename = "image_url")] + Image(ImageContentPart), +} +impl ContentPart { + pub fn ty(&self) -> &str { + match self { + ContentPart::Text(_) => "text", + ContentPart::Image(_) => "image_url", + } + } +} + +/// Represents the text part of a user message content. +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] +pub struct TextContentPart { + /// The text content. + text: String, +} +impl TextContentPart { + pub fn new(text: impl Into) -> Self { + Self { text: text.into() } + } + + /// The text content. + pub fn text(&self) -> &str { + &self.text + } +} + +/// Represents the image part of a user message content. +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] +pub struct ImageContentPart { + #[serde(rename = "image_url")] + image: Image, +} +impl ImageContentPart { + pub fn new(image: Image) -> Self { + Self { image } + } + + /// The image URL. + pub fn image(&self) -> &Image { + &self.image + } +} + +/// JPEG baseline & progressive (12 bpc/arithmetic not supported, same as stock IJG lib) +/// PNG 1/2/4/8/16-bit-per-channel +/// +/// TGA (not sure what subset, if a subset) +/// BMP non-1bpp, non-RLE +/// PSD (composited view only, no extra channels, 8/16 bit-per-channel) +/// +/// GIF (*comp always reports as 4-channel) +/// HDR (radiance rgbE format) +/// PIC (Softimage PIC) +/// PNM (PPM and PGM binary only) +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] +pub struct Image { + /// Either a URL of the image or the base64 encoded image data. + pub url: String, + /// Specifies the detail level of the image. Defaults to auto. + #[serde(skip_serializing_if = "Option::is_none")] + pub detail: Option, +} +impl Image { + pub fn is_url(&self) -> bool { + url::Url::parse(&self.url).is_ok() + } +} + +/// Defines the content of a tool message. +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] +pub struct ChatCompletionToolMessage { + /// The contents of the tool message. + content: String, + /// Tool call that this message is responding to. + #[serde(skip_serializing_if = "Option::is_none")] + tool_call_id: Option, +} +impl ChatCompletionToolMessage { + /// Creates a new tool message. + /// + /// # Arguments + /// + /// * `content` - The contents of the tool message. 
+ /// + /// * `tool_call_id` - Tool call that this message is responding to. + pub fn new(content: impl Into, tool_call_id: Option) -> Self { + Self { + content: content.into(), + tool_call_id, + } + } + + /// The role of the messages author, in this case `tool`. + pub fn role(&self) -> ChatCompletionRole { + ChatCompletionRole::Tool + } + + /// The contents of the tool message. + pub fn content(&self) -> &str { + &self.content + } + + /// Tool call that this message is responding to. + pub fn tool_call_id(&self) -> Option { + self.tool_call_id.clone() + } +} + +/// Defines the content of an assistant message. +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] +pub struct ChatCompletionAssistantMessage { + /// The contents of the assistant message. Required unless `tool_calls` is specified. + #[serde(skip_serializing_if = "Option::is_none")] + content: Option, + /// An optional name for the participant. Provides the model information to differentiate between participants of the same role. + #[serde(skip_serializing_if = "Option::is_none")] + name: Option, + /// The tool calls generated by the model. + #[serde(skip_serializing_if = "Option::is_none")] + tool_calls: Option>, +} +impl ChatCompletionAssistantMessage { + /// Creates a new assistant message. + /// + /// # Arguments + /// + /// * `content` - The contents of the assistant message. Required unless `tool_calls` is specified. + /// + /// * `name` - An optional name for the participant. Provides the model information to differentiate between participants of the same role. + /// + /// * `tool_calls` - The tool calls generated by the model. + pub fn new( + content: Option, + name: Option, + tool_calls: Option>, + ) -> Self { + match tool_calls.is_some() { + true => Self { + content: None, + name, + tool_calls, + }, + false => Self { + content, + name, + tool_calls: None, + }, + } + } + + /// The role of the messages author, in this case `assistant`. + pub fn role(&self) -> ChatCompletionRole { + ChatCompletionRole::Assistant + } + + /// The contents of the assistant message. If `tool_calls` is specified, then `content` is None. + pub fn content(&self) -> Option<&String> { + self.content.as_ref() + } + + /// An optional name for the participant. + pub fn name(&self) -> Option<&String> { + self.name.as_ref() + } + + /// The tool calls generated by the model. + pub fn tool_calls(&self) -> Option<&Vec> { + self.tool_calls.as_ref() + } +} + +/// Defines the content of a system message. +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] +pub struct ChatCompletionSystemMessage { + /// The contents of the system message. + content: String, + /// An optional name for the participant. Provides the model information to differentiate between participants of the same role. + #[serde(skip_serializing_if = "Option::is_none")] + name: Option, +} +impl ChatCompletionSystemMessage { + /// Creates a new system message. + /// + /// # Arguments + /// + /// * `content` - The contents of the system message. + /// + /// * `name` - An optional name for the participant. Provides the model information to differentiate between participants of the same role. + pub fn new(content: impl Into, name: Option) -> Self { + Self { + content: content.into(), + name, + } + } + + pub fn role(&self) -> ChatCompletionRole { + ChatCompletionRole::System + } + + pub fn content(&self) -> &str { + &self.content + } + + pub fn name(&self) -> Option<&String> { + self.name.as_ref() + } +} + +/// Defines the content of a user message. 
+#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] +pub struct ChatCompletionUserMessage { + /// The contents of the user message. + content: ChatCompletionUserMessageContent, + /// An optional name for the participant. Provides the model information to differentiate between participants of the same role. + #[serde(skip_serializing_if = "Option::is_none")] + name: Option, +} +impl ChatCompletionUserMessage { + /// Creates a new user message. + /// + /// # Arguments + /// + /// * `content` - The contents of the user message. + /// + /// * `name` - An optional name for the participant. Provides the model information to differentiate between participants of the same role. + pub fn new(content: ChatCompletionUserMessageContent, name: Option) -> Self { + Self { content, name } + } + + pub fn role(&self) -> ChatCompletionRole { + ChatCompletionRole::User + } + + pub fn content(&self) -> &ChatCompletionUserMessageContent { + &self.content + } + + pub fn name(&self) -> Option<&String> { + self.name.as_ref() + } +} + +/// An object specifying the format that the model must output. +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ChatResponseFormat { + /// Must be one of `text`` or `json_object`. Defaults to `text`. + #[serde(rename = "type")] + pub ty: String, +} +impl Default for ChatResponseFormat { + fn default() -> Self { + Self { + ty: "text".to_string(), + } + } +} + +/// Controls which (if any) function is called by the model. Defaults to `None`. +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] +pub enum ToolChoice { + /// The model will not call a function and instead generates a message. + #[serde(rename = "none")] + None, + /// The model can pick between generating a message or calling a function. + #[serde(rename = "auto")] + Auto, + /// The model must call one or more tools. + #[serde(rename = "required")] + Required, + /// Specifies a tool the model should use. Use to force the model to call a specific function. + #[serde(untagged)] + Tool(ToolChoiceTool), +} +impl Default for ToolChoice { + fn default() -> Self { + Self::None + } +} + +/// A tool the model should use. +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] +pub struct ToolChoiceTool { + /// The type of the tool. Currently, only `function` is supported. + #[serde(rename = "type")] + pub ty: String, + /// The function the model calls. + pub function: ToolChoiceToolFunction, +} + +/// Represents a tool the model should use. +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] +pub struct ToolChoiceToolFunction { + /// The name of the function to call. + pub name: String, +} + +/// Represents a tool the model may generate JSON inputs for. +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct Tool { + /// The type of the tool. Currently, only `function` is supported. + #[serde(rename = "type")] + pub ty: String, + /// Function the model may generate JSON inputs for. + pub function: ToolFunction, +} + +/// Function the model may generate JSON inputs for. +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ToolFunction { + /// The name of the function to be called. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64. + pub name: String, + /// A description of what the function does, used by the model to choose when and how to call the function. + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + // The parameters the functions accepts, described as a JSON Schema object. 
+
+/// Represents a tool the model may generate JSON inputs for.
+#[derive(Debug, Clone, Deserialize, Serialize)]
+pub struct Tool {
+    /// The type of the tool. Currently, only `function` is supported.
+    #[serde(rename = "type")]
+    pub ty: String,
+    /// Function the model may generate JSON inputs for.
+    pub function: ToolFunction,
+}
+
+/// Function the model may generate JSON inputs for.
+#[derive(Debug, Clone, Deserialize, Serialize)]
+pub struct ToolFunction {
+    /// The name of the function to be called. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64.
+    pub name: String,
+    /// A description of what the function does, used by the model to choose when and how to call the function.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub description: Option<String>,
+    // The parameters the functions accepts, described as a JSON Schema object.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub parameters: Option<ToolFunctionParameters>,
+}
+
+/// The parameters the functions accepts, described as a JSON Schema object.
+///
+/// See the [guide](https://platform.openai.com/docs/guides/gpt/function-calling) for examples, and the [JSON Schema reference](https://json-schema.org/understanding-json-schema/) for documentation about the format.
+///
+/// To describe a function that accepts no parameters, provide the value
+/// `{"type": "object", "properties": {}}`.
+#[derive(Debug, Clone, Deserialize, Serialize)]
+pub struct ToolFunctionParameters {
+    #[serde(rename = "type")]
+    pub schema_type: JSONSchemaType,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub properties: Option<HashMap<String, Box<JSONSchemaDefine>>>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub required: Option<Vec<String>>,
+}

From 10f17a1cd2f93137126c84c2c21ba28c70b3f954 Mon Sep 17 00:00:00 2001
From: Philipp Oppermann
Date: Mon, 7 Oct 2024 17:13:54 +0200
Subject: [PATCH 2/3] Add dataflow file for Rust-based openai-server to
 `openai-server` example

---
 examples/openai-server/dataflow-rust.yml | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)
 create mode 100644 examples/openai-server/dataflow-rust.yml

diff --git a/examples/openai-server/dataflow-rust.yml b/examples/openai-server/dataflow-rust.yml
new file mode 100644
index 000000000..8c6a1d8d0
--- /dev/null
+++ b/examples/openai-server/dataflow-rust.yml
@@ -0,0 +1,16 @@
+nodes:
+  - id: dora-openai-server
+    build: cargo build -p dora-openai-proxy-server --release
+    path: ../../target/release/dora-openai-proxy-server
+    outputs:
+      - chat_completion_request
+    inputs:
+      completion_reply: dora-echo/echo
+
+  - id: dora-echo
+    build: pip install -e ../../node-hub/dora-echo
+    path: dora-echo
+    inputs:
+      echo: dora-openai-server/chat_completion_request
+    outputs:
+      - echo

From f2793f5a8871d8b54b530177cdc709341ef72d48 Mon Sep 17 00:00:00 2001
From: Philipp Oppermann
Date: Mon, 7 Oct 2024 18:02:02 +0200
Subject: [PATCH 3/3] Add basic support for stream-based completion interface

---
 node-hub/openai-proxy-server/src/main.rs | 129 +++++++++++++++--------
 1 file changed, 85 insertions(+), 44 deletions(-)

diff --git a/node-hub/openai-proxy-server/src/main.rs b/node-hub/openai-proxy-server/src/main.rs
index 471cba184..c0714886d 100644
--- a/node-hub/openai-proxy-server/src/main.rs
+++ b/node-hub/openai-proxy-server/src/main.rs
@@ -1,7 +1,10 @@
 use dora_node_api::{self, dora_core::config::DataId, merged::MergeExternalSend, DoraNode, Event};
 use eyre::{Context, ContextCompat};

-use futures::channel::oneshot::{self, Canceled};
+use futures::{
+    channel::oneshot::{self, Canceled},
+    TryStreamExt,
+};
 use hyper::{
     body::{to_bytes, Body, HttpBody},
     header,
@@ -362,6 +365,7 @@ async fn chat_completions_handler(
     // log user id
     info!(target: "stdout", "user: {}", chat_request.user.clone().unwrap());

+    let stream = chat_request.stream;
     let (tx, rx) = oneshot::channel();

     if let Err(err) = request_tx
@@ -375,58 +379,95 @@ async fn chat_completions_handler(
         return error::internal_server_error(format!("{err:?}"));
     }

-    let res = match rx
-        .await
-        .unwrap_or_else(|Canceled| Err(eyre::eyre!("result channel closed early")))
-    {
-        Ok(chat_completion_object) => {
-            // serialize chat completion object
-            let s = match serde_json::to_string(&chat_completion_object) {
-                Ok(s) => s,
-                Err(e) => {
-                    let err_msg = format!("Failed to serialize chat completion object.
{}", e); + let res = if let Some(true) = stream { + let result = async { + let chat_completion_object = rx + .await + .unwrap_or_else(|Canceled| Err(eyre::eyre!("result channel closed early")))?; + serde_json::to_string(&chat_completion_object).context("failed to serialize response") + }; + let stream = futures::stream::once(result).map_err(|e| e.to_string()); - // log - error!(target: "stdout", "{}", &err_msg); + let result = Response::builder() + .header("Access-Control-Allow-Origin", "*") + .header("Access-Control-Allow-Methods", "*") + .header("Access-Control-Allow-Headers", "*") + .header("Content-Type", "text/event-stream") + .header("Cache-Control", "no-cache") + .header("Connection", "keep-alive") + .header("user", id) + .body(Body::wrap_stream(stream)); - return error::internal_server_error(err_msg); - } - }; + match result { + Ok(response) => { + // log + info!(target: "stdout", "finish chat completions in stream mode"); - // return response - let result = Response::builder() - .header("Access-Control-Allow-Origin", "*") - .header("Access-Control-Allow-Methods", "*") - .header("Access-Control-Allow-Headers", "*") - .header("Content-Type", "application/json") - .header("user", id) - .body(Body::from(s)); - - match result { - Ok(response) => { - // log - info!(target: "stdout", "Finish chat completions in non-stream mode"); - - response - } - Err(e) => { - let err_msg = - format!("Failed chat completions in non-stream mode. Reason: {}", e); + response + } + Err(e) => { + let err_msg = format!("Failed chat completions in stream mode. Reason: {}", e); - // log - error!(target: "stdout", "{}", &err_msg); + // log + error!(target: "stdout", "{}", &err_msg); - error::internal_server_error(err_msg) - } + error::internal_server_error(err_msg) } } - Err(e) => { - let err_msg = format!("Failed to get chat completions. Reason: {}", e); + } else { + match rx + .await + .unwrap_or_else(|Canceled| Err(eyre::eyre!("result channel closed early"))) + { + Ok(chat_completion_object) => { + // serialize chat completion object + let s = match serde_json::to_string(&chat_completion_object) { + Ok(s) => s, + Err(e) => { + let err_msg = format!("Failed to serialize chat completion object. {}", e); + + // log + error!(target: "stdout", "{}", &err_msg); + + return error::internal_server_error(err_msg); + } + }; + + // return response + let result = Response::builder() + .header("Access-Control-Allow-Origin", "*") + .header("Access-Control-Allow-Methods", "*") + .header("Access-Control-Allow-Headers", "*") + .header("Content-Type", "application/json") + .header("user", id) + .body(Body::from(s)); + + match result { + Ok(response) => { + // log + info!(target: "stdout", "Finish chat completions in non-stream mode"); + + response + } + Err(e) => { + let err_msg = + format!("Failed chat completions in non-stream mode. Reason: {}", e); - // log - error!(target: "stdout", "{}", &err_msg); + // log + error!(target: "stdout", "{}", &err_msg); - error::internal_server_error(err_msg) + error::internal_server_error(err_msg) + } + } + } + Err(e) => { + let err_msg = format!("Failed to get chat completions. Reason: {}", e); + + // log + error!(target: "stdout", "{}", &err_msg); + + error::internal_server_error(err_msg) + } } };