api-server Streamer #67

Open · wants to merge 9 commits into base: main
10 changes: 7 additions & 3 deletions candle_demo/Cargo.toml
@@ -25,6 +25,10 @@ intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
rand = "0.8.5"
owo-colors = "4.0.0"
codegeex4 = {path = "./codegeex4"}



api-server = {path = "./api-server"}
flume = "0.11.0"
serde = { version = "1.0.204", features = ["derive"] }
futures = "0.3.30"
axum = "0.7.5"
tokio = {version = "1.39.1", features = ["full"]}
uuid = { version = "1.10.0", features = ["v4"] }
10 changes: 9 additions & 1 deletion candle_demo/README.org
@@ -7,7 +7,14 @@ THUDM/CodeGeeX4 is a versatile model for all AI software development scenarios,
- [[https://huggingface.co/THUDM/codegeex4-all-9b][huggingface]]
- [[https://github.com/huggingface/candle/blob/main/candle-examples/examples/codegeex4-9b/README.org][Candle]]

- The openai-api (OpenAI-compatible endpoint) is currently under development
** api-server
#+begin_src shell
cargo build --release -p api-server --features cuda
./target/release/api-server 0.0.0.0:3000
#+end_src

[[file:../resources/rust-api-server.png][file:../resources/rust-api-server.png]]
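
Once the server is running, the =/v1/chat/completions= route registered in =candle_demo/api-server/src/main.rs= (shown further down in this diff) accepts OpenAI-style requests. The following client sketch is hypothetical and not part of this PR; it assumes the =reqwest= crate (with its =json= feature), =serde_json= and =tokio= on the caller's side, and reuses the field names of the request type from the removed =api.rs=:

#+begin_src rust
// Hypothetical client sketch, not shipped with this PR.
// Assumes reqwest (json feature), serde_json and tokio as dependencies;
// the body mirrors the ChatCompletionRequest fields from the removed api.rs.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let body = serde_json::json!({
        "model": "codegeex4",
        "messages": [{ "role": "user", "content": "write a quicksort in rust" }],
        "temperature": 0.2,
        "top_p": 0.2,
        "max_tokens": 1024,
        "stop": ["<|user|>", "<|assistant|>", "<|observation|>", "<|endoftext|>"],
        "stream": false,
        "presence_penalty": null
    });
    let response = reqwest::Client::new()
        .post("http://0.0.0.0:3000/v1/chat/completions")
        .json(&body)
        .send()
        .await?;
    println!("{}", response.text().await?);
    Ok(())
}
#+end_src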

** Cli
#+begin_src shell
cargo build --release -p codegeex4-cli # Cpu
@@ -86,3 +93,4 @@ THUDM/CodeGeeX4 is a versatile model for all AI software development scenarios,
year={2023}
}
#+end_src
** Candle-vllm: this project reuses part of the candle-vllm code
16 changes: 11 additions & 5 deletions candle_demo/api-server/Cargo.toml
@@ -23,11 +23,17 @@ intel-mkl-src = { workspace = true ,optional = true}
rand = { workspace = true}
owo-colors = {workspace = true}
codegeex4 = {workspace = true}
tokio = {version = "1.39.1", features = ["full"]}
actix-web = "4.8.0"
serde = { version = "1.0.204", features = ["derive"] }
shortuuid = "0.0.1"
short-uuid = "0.1.2"
# for async runtime and executor
tokio = { workspace = true }

# for api-server
axum = { workspace = true }
serde = { workspace = true }
tower-http = { version = "0.5.2", features = ["cors"] }
# for uuid generation
uuid = { workspace = true }
futures = { workspace = true }
flume = { workspace = true }
[build-dependencies]
bindgen_cuda = { version = "0.1.1", optional = true }
[features]
123 changes: 0 additions & 123 deletions candle_demo/api-server/src/api.rs
@@ -1,123 +0,0 @@
use actix_web::{
get, post,
web::{self, Data},
HttpRequest, Responder,
};
use owo_colors::OwoColorize;
use serde::{Deserialize, Serialize};

#[derive(Debug, Serialize, Deserialize)]
pub struct ChatMessage {
pub role: String,
pub content: String,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct ChatCompletionRequest {
pub model: String,
pub messages: Vec<ChatMessage>,
pub temperature: f64,
pub top_p: f64,
pub max_tokens: usize,
pub stop: Vec<String>,
pub stream: bool,
pub presence_penalty: Option<f32>,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct DeltaMessage {
pub role: String,
pub content: String,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct ChatCompletionResponseStreamChoice {
pub index: i32,
pub delta: DeltaMessage,
pub finish_reason: Option<String>,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct ChatCompletionStreamResponse {
pub id: String,
pub object: String,
pub created: i32,
pub model: String,
pub choices: Vec<ChatCompletionResponseStreamChoice>,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct ChatCompletionResponseChoice {
pub index: i32,
pub message: ChatMessage,
pub finish_reason: Option<FinishResaon>,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct ChatCompletionResponse {
pub id: String,
pub object: String,
pub created: u64,
pub model: String,
pub choices: Vec<ChatCompletionResponseChoice>,
}

#[derive(Debug, Serialize, Deserialize)]
pub enum FinishResaon{
STOP,
LENGTH,
}
use std::time::{SystemTime, UNIX_EPOCH};
impl ChatCompletionResponse {
pub fn empty() -> Self {
let current_time = SystemTime::now();
Self {
id: format!("chatcmpl-{}", short_uuid::ShortUuid::generate()),
object: "chat.completion".to_string(),
created: current_time
.duration_since(UNIX_EPOCH)
.expect("failed to get time")
.as_secs()
.into(),
model: "codegeex4".to_string(),
choices: vec![ChatCompletionResponseChoice::empty()],
}
}
}

impl ChatCompletionResponseChoice {
pub fn empty() -> Self {
Self {
index: 0,
message: ChatMessage {
role: "assistant".to_string(),
content: "".to_string(),
},
finish_reason: None,
}
}
}

impl ChatCompletionRequest {
pub fn empty() -> Self {
Self{
model: "codegeex4".to_string(),
messages: vec!(ChatMessage {
role: "assistant".to_string(),
content: "".to_string(),
}),
temperature: 0.2_f64,
top_p: 0.2_f64,
max_tokens: 1024_usize,
stop: vec!("<|user|>".to_string(), "<|assistant|>".to_string(), "<|observation|>".to_string(), "<|endoftext|>".to_string()),
stream: true,
presence_penalty: None,
}
}
}

// impl DeltaMessage {
// pub fn new() -> Self {
// role:
// }
// }
46 changes: 46 additions & 0 deletions candle_demo/api-server/src/args.rs
@@ -7,4 +7,50 @@ pub struct Args {
pub address: String,
#[arg(short, long, default_value_t = 1)]
pub workers: usize,
#[arg(name = "cache", short, long, default_value = ".")]
pub cache_path: String,

/// Run on CPU rather than on GPU.
#[arg(long)]
pub cpu: bool,

/// Display the token for the specified prompt.
#[arg(long)]
pub verbose_prompt: bool,

/// The temperature used to generate samples.
#[arg(long)]
pub temperature: Option<f64>,

/// Nucleus sampling probability cutoff.
#[arg(long)]
pub top_p: Option<f64>,

/// The seed to use when generating random samples.
#[arg(long)]
pub seed: Option<u64>,

/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 5000)]
pub sample_len: usize,

#[arg(long)]
pub model_id: Option<String>,

#[arg(long)]
pub revision: Option<String>,

#[arg(long)]
pub weight_file: Option<String>,

#[arg(long)]
pub tokenizer: Option<String>,

/// Penalty to be applied for repeating tokens; 1.0 means no penalty.
#[arg(long, default_value_t = 1.1)]
pub repeat_penalty: f32,

/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
pub repeat_last_n: usize,
}
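
The fields above extend the existing clap `Args` struct that `main.rs` parses at startup. A small, hypothetical test sketch of how the new flags and their defaults parse, assuming `address` is a positional argument as the README invocation `./target/release/api-server 0.0.0.0:3000` suggests:

```rust
// Hypothetical test, not part of this PR: checks the defaults declared above.
#[cfg(test)]
mod tests {
    use super::*;
    use clap::Parser;

    #[test]
    fn defaults_parse() {
        // `address` is assumed to be positional here (see the README usage).
        let args = Args::try_parse_from(["api-server", "0.0.0.0:3000"]).unwrap();
        assert_eq!(args.workers, 1);
        assert_eq!(args.cache_path, ".");
        assert_eq!(args.sample_len, 5000);
        assert_eq!(args.repeat_penalty, 1.1);
        assert_eq!(args.repeat_last_n, 64);
        assert!(args.temperature.is_none());
        assert!(!args.cpu);
    }
}
```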
104 changes: 99 additions & 5 deletions candle_demo/api-server/src/main.rs
@@ -1,19 +1,113 @@
mod api;
mod args;
mod server;
mod model;
mod server;
mod streamer;
use axum::{routing, Router};
use candle_core::DType;
use candle_nn::VarBuilder;
use clap::Parser;
use owo_colors::OwoColorize;
use codegeex4::codegeex4::*;
use codegeex4::TextGenerationApiServer;
use hf_hub::{Repo, RepoType};
use owo_colors::{self, OwoColorize};
use rand::Rng;
use server::chat;
use std::sync::{Arc, Mutex};
use tokenizers::Tokenizer;
use tower_http::cors::{AllowOrigin, CorsLayer};

pub struct Data {
pub pipeline: Mutex<TextGenerationApiServer>,
}

#[tokio::main]
async fn main() {
let args = args::Args::parse();
let server = server::Server::new(args.clone());
println!(
"{} Server Binding On {} with {} workers",
"[INFO]".green(),
&args.address.purple(),
&args.workers.purple()
);
server.run().await;

let mut seed: u64 = 0;
if let Some(_seed) = args.seed {
seed = _seed;
} else {
let mut rng = rand::thread_rng();
seed = rng.gen();
}
println!("Using Seed {}", seed.red());
let api = hf_hub::api::sync::ApiBuilder::from_cache(hf_hub::Cache::new(args.cache_path.into()))
.build()
.unwrap();

let model_id = match args.model_id {
Some(model_id) => model_id.to_string(),
None => "THUDM/codegeex4-all-9b".to_string(),
};
let revision = match args.revision {
Some(rev) => rev.to_string(),
None => "main".to_string(),
};
let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
let tokenizer_filename = match args.tokenizer {
Some(file) => std::path::PathBuf::from(file),
None => api
.model("THUDM/codegeex4-all-9b".to_string())
.get("tokenizer.json")
.unwrap(),
};
let filenames = match args.weight_file {
Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
None => {
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json").unwrap()
}
};
let tokenizer = Tokenizer::from_file(tokenizer_filename).expect("Tokenizer Error");
let start = std::time::Instant::now();
let config = Config::codegeex4();
let device = candle_examples::device(args.cpu).unwrap();
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
println!("DType is {:?}", dtype.yellow());
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device).unwrap() };
let model = Model::new(&config, vb).unwrap();

println!("模型加载完毕 {:?}", start.elapsed().as_secs().green());

let pipeline = TextGenerationApiServer::new(
model,
tokenizer,
seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
args.verbose_prompt,
&device,
dtype,
);
let server_data = Arc::new(Data {
pipeline: Mutex::new(pipeline),
});

let allow_origin = AllowOrigin::any();
let allow_methods = tower_http::cors::AllowMethods::any();
let allow_headers = tower_http::cors::AllowHeaders::any();
let cors_layer = CorsLayer::new()
.allow_methods(allow_methods)
.allow_headers(allow_headers)
.allow_origin(allow_origin);
let chat = Router::new()
// .route("/v1/chat/completions", routing::post(raw))
.route("/v1/chat/completions", routing::post(chat))
.layer(cors_layer)
.with_state(server_data);
// .with_state(Arc::new(server_data));
let listener = tokio::net::TcpListener::bind(args.address).await.unwrap();
axum::serve(listener, chat).await.unwrap();
}
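
The new `server.rs` and `streamer.rs` modules referenced above are not shown in this diff. As a rough illustration only, a handler with the shape `main.rs` expects (POST `/v1/chat/completions`, state `Arc<Data>`) could bridge the blocking generation pipeline to an SSE response through the `flume` and `futures` dependencies added in `Cargo.toml`. The sketch below is hypothetical and does not reproduce the PR's actual implementation:

```rust
// Hypothetical sketch of a streaming chat handler; the PR's real server.rs and
// streamer.rs are not shown in this diff, so the names and the request schema
// here are illustrative only.
use std::convert::Infallible;
use std::sync::Arc;

use axum::{
    extract::State,
    response::sse::{Event, Sse},
    Json,
};
use futures::stream::{Stream, StreamExt};
use serde::Deserialize;

use crate::Data;

#[derive(Deserialize)]
pub struct ChatMessage {
    pub role: String,
    pub content: String,
}

#[derive(Deserialize)]
pub struct ChatRequest {
    // Minimal, assumed schema; the real request type lives in server.rs.
    pub messages: Vec<ChatMessage>,
}

pub async fn chat(
    State(data): State<Arc<Data>>,
    Json(_request): Json<ChatRequest>,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
    // The generation side pushes decoded tokens into a flume channel...
    let (tx, rx) = flume::unbounded::<String>();

    // ...from a blocking task, so the locked pipeline never stalls the async runtime.
    tokio::task::spawn_blocking(move || {
        let _pipeline = data.pipeline.lock().unwrap();
        // The real code would run the pipeline's generation loop here and
        // forward each token through `tx`.
        let _ = tx.send("hello".to_string());
    });

    // ...and the HTTP side turns the receiver into a futures Stream of SSE events.
    let stream = rx.into_stream().map(|token| Ok(Event::default().data(token)));
    Sse::new(stream)
}
```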
7 changes: 1 addition & 6 deletions candle_demo/api-server/src/model.rs
@@ -1,6 +1 @@
use codegeex4::codegeex4::Config;
use crate::api::ChatCompletionRequest;
fn stream_chat(request: ChatCompletionRequest) {
let default_config = codegeex4::codegeex4::Config::codegeex4();

}
