Skip to content

Commit

Permalink
feat: ernie support function calling (#631)
Browse files — browse the repository at this point in the history
  • Loading branch information
sigoden authored Jun 22, 2024
1 parent 1fd5c58 commit 250e0eb
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 9 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ AIChat is an all-in-one AI CLI tool featuring chat REPL, RAG, function calling,
- Bedrock: Llama-3/Claude-3.5/Claude-3/Mistral (paid, vision)
- Cloudflare (free, vision, embedding)
- Replicate (paid)
- Ernie (paid)
- Ernie (paid, embedding, rerank, function-calling)
- Qianwen: Qwen (paid, vision, embedding, function-calling)
- Moonshot (paid, function-calling)
- Deepseek (paid)
Expand Down
2 changes: 2 additions & 0 deletions models.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -501,10 +501,12 @@
max_input_tokens: 8192
input_price: 16.8
output_price: 16.8
supports_function_calling: true
- name: ernie-3.5-8k-0613
max_input_tokens: 8192
input_price: 1.68
output_price: 1.68
supports_function_calling: true
- name: ernie-speed-128k
max_input_tokens: 128000
input_price: 0
Expand Down
70 changes: 62 additions & 8 deletions src/client/ernie.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::access_token::*;
use super::*;

use anyhow::{anyhow, Context, Result};
use anyhow::{anyhow, bail, Context, Result};
use async_trait::async_trait;
use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize;
Expand Down Expand Up @@ -93,7 +93,7 @@ impl ErnieClient {
&self.model.name(),
);

debug!("Ernie Re Rerank: {url} {body}");
debug!("Ernie Rerank Request: {url} {body}");

let builder = client.post(url).json(&body);

Expand Down Expand Up @@ -179,7 +179,17 @@ async fn chat_completions_streaming(
let handle = |message: SseMmessage| -> Result<bool> {
let data: Value = serde_json::from_str(&message.data)?;
debug!("stream-data: {data}");
if let Some(text) = data["result"].as_str() {
if let Some(function) = data["function_call"].as_object() {
if let (Some(name), Some(arguments)) = (
function.get("name").and_then(|v| v.as_str()),
function.get("arguments").and_then(|v| v.as_str()),
) {
let arguments: Value = arguments.parse().with_context(|| {
format!("Tool call '{name}' is invalid: arguments must be in valid JSON format")
})?;
handler.tool_call(ToolCall::new(name.to_string(), arguments, None))?;
}
} else if let Some(text) = data["result"].as_str() {
handler.text(text)?;
}
Ok(false)
Expand Down Expand Up @@ -224,12 +234,37 @@ fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Valu
mut messages,
temperature,
top_p,
functions: _,
functions,
stream,
} = data;

patch_system_message(&mut messages);

let messages: Vec<Value> = messages
.into_iter()
.flat_map(|message| {
let Message { role, content } = message;
match content {
MessageContent::ToolResults((tool_results, _)) => {
let mut list = vec![];
for tool_result in tool_results {
list.push(json!({
"role": "assistant",
"content": format!("Action: {}\nAction Input: {}", tool_result.call.name, tool_result.call.arguments)
}));
list.push(json!({
"role": "user",
"content": tool_result.output.to_string(),
}))

}
list
}
_ => vec![json!({ "role": role, "content": content })],
}
})
.collect();

let mut body = json!({
"messages": messages,
});
Expand All @@ -248,16 +283,35 @@ fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Valu
body["stream"] = true.into();
}

if let Some(functions) = functions {
body["functions"] = json!(functions);
}

body
}

fn extract_chat_completions_text(data: &Value) -> Result<ChatCompletionsOutput> {
let text = data["result"]
.as_str()
.ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
let text = data["result"].as_str().unwrap_or_default();

let mut tool_calls = vec![];
if let Some(call) = data["function_call"].as_object() {
if let (Some(name), Some(arguments)) = (
call.get("name").and_then(|v| v.as_str()),
call.get("arguments").and_then(|v| v.as_str()),
) {
let arguments: Value = arguments.parse().with_context(|| {
format!("Tool call '{name}' is invalid: arguments must be in valid JSON format")
})?;
tool_calls.push(ToolCall::new(name.to_string(), arguments, None));
}
}

if text.is_empty() && tool_calls.is_empty() {
bail!("Invalid response data: {data}");
}
let output = ChatCompletionsOutput {
text: text.to_string(),
tool_calls: vec![],
tool_calls,
id: data["id"].as_str().map(|v| v.to_string()),
input_tokens: data["usage"]["prompt_tokens"].as_u64(),
output_tokens: data["usage"]["completion_tokens"].as_u64(),
Expand Down

0 comments on commit 250e0eb

Please sign in to comment.