Adds support to OpenAI chat APIs #17

Merged 1 commit on Mar 2, 2023
9 changes: 9 additions & 0 deletions Cargo.toml
@@ -75,6 +75,15 @@ description = "A client for the OpenAI API using the Clap library"
category = "command-line-utilities"
keywords = ["openai", "api", "client", "clap"]

+[[example]]
+name = "clap-chat"
+path = "examples/clap/chat.rs"
+
+[package.metadata.example.clap-chat]
+name = "OpenAI Chat Clap Client"
+description = "A client for the OpenAI API using the Clap library"
+category = "command-line-utilities"
+keywords = ["openai", "api", "client", "clap"]

[dev-dependencies]
clap = { version = "4.0.29", features = ["derive"] }
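The `[[example]]` table is what lets Cargo resolve the new example by name instead of by path, so the binary added in this PR can be run directly. A usage sketch (same invocation style as the README below):

```bash
# Cargo maps the [[example]] name to examples/clap/chat.rs and builds it on demand.
cargo run --example clap-chat "What is the role of a Jedi Knight?"
```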
15 changes: 9 additions & 6 deletions README.md
@@ -104,23 +104,26 @@ Feel free to modify the configuration file to suit your needs.

There are some examples available in the `examples` folder:


This example calls the `text2img` function and displays the API response without any parsing:

```bash
cargo run --example text2img
```

-An example demonstrating how to use the clap library to create a command line interface for our stable diffusion client:
+This example uses Stable Diffusion to generate an image based on an input prompt:

```bash
-cargo run --example clap "A golden gorilla with a baseball hat"
+cargo run --example clap-img "example.jpg" "A golden gorilla with a baseball hat"
```

+This example uses the OpenAI Completions Edit feature to replace parts of a given text:

```bash
-cargo run --example clap-img "example.jpg" "A golden gorilla with a baseball hat"
+cargo run --example clap-edits "This gorilla has a golden fur" "Replace the gorilla fur color with red"
```

+This one uses the OpenAI ChatGPT API feature to generate a response to a given prompt:

```bash
-cargo run --example clap-edits "This gorilla has a golden fur" "Replace the gorilla fur color with red"
+cargo run --example clap-chat "What is the role of a Jedi Knight?"
```
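All of the clap examples read their credentials from the environment via `dotenvy` (see `examples/clap/chat.rs` below). A minimal `.env` sketch — the key is a placeholder, and the base URL assumes the public OpenAI endpoint, since the client appends `/v1/chat/completions` to it:

```bash
# .env — loaded by dotenv().ok() in the examples (placeholder values)
OPENAI_API_KEY=sk-your-key-here
OPENAI_API_URL=https://api.openai.com
```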
50 changes: 50 additions & 0 deletions examples/clap/chat.rs
@@ -0,0 +1,50 @@
use clap::Parser;
use dotenvy::dotenv;
use sortium_ai_client::openai::chats::{ChatClient, ChatInput, ChatMessage, ChatRole};

use std::env;

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    prompt: String,
}

#[tokio::main]
async fn main() {
    dotenv().ok();

    let api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY must be set");
    let api_url = env::var("OPENAI_API_URL").expect("OPENAI_API_URL must be set");

    let args = Args::parse();
    println!("Prompt: {:#?}", args.prompt);

    // Build the request: a system message setting the persona, plus the user's prompt.
    let params = ChatInput {
        messages: vec![
            ChatMessage {
                role: ChatRole::System,
                content: String::from("You are a Jedi Master called Yoda, and as such, you must answer all questions in the same way as the Jedi Master Yoda."),
            },
            ChatMessage {
                role: ChatRole::User,
                content: args.prompt,
            },
        ],
        ..Default::default()
    };

    println!("Params: {:?}", params);

    let chat_client = ChatClient::new(api_key, api_url);

    let result = chat_client.generate_chat(params).await;
    match result {
        Ok(res) => println!("{:#?}", res),
        Err(err) => println!("{:?}", err),
    }
}
123 changes: 123 additions & 0 deletions src/openai/chats.rs
@@ -0,0 +1,123 @@
use reqwest::Client;
use serde::{self, Deserialize, Serialize};

/// The role of a chat message (system, user, or assistant); each variant
/// serializes to its lowercase string.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ChatRole {
    System,
    User,
    Assistant,
}

/// A single chat message: a role plus its text content.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ChatMessage {
    pub role: ChatRole,
    pub content: String,
}

/// Request body for the chat completions endpoint; `None` fields are
/// omitted from the serialized JSON.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ChatInput {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<serde_json::Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
}

impl Default for ChatInput {
    fn default() -> Self {
        ChatInput {
            model: "gpt-3.5-turbo".to_owned(),
            messages: vec![],
            temperature: Some(0.7),
            max_tokens: Some(256),
            top_p: Some(1),
            frequency_penalty: Some(0.0),
            presence_penalty: Some(0.0),
            stop: None,
            n: None,
            stream: None,
            logprobs: None,
            logit_bias: None,
            user: None,
        }
    }
}

#[derive(Serialize, Deserialize, Debug, Default, Clone)]
pub struct Choice {
    pub index: Option<i32>,
    pub message: Option<ChatMessage>,
    pub finish_reason: Option<String>,
}

#[derive(Serialize, Deserialize, Debug, Default, Clone)]
pub struct Usage {
    pub prompt_tokens: Option<i32>,
    pub completion_tokens: Option<i32>,
    pub total_tokens: Option<i32>,
}

#[derive(Serialize, Deserialize, Debug, Default, Clone)]
pub struct ChatResponse {
    pub id: Option<String>,
    pub object: Option<String>,
    pub created: Option<i32>,
    pub choices: Option<Vec<Choice>>,
    pub usage: Option<Usage>,
}

pub struct ChatClient {
    client: Client,
    api_key: String,
    api_url: String,
}

impl ChatClient {
    pub fn new(api_key: String, api_url: String) -> Self {
        ChatClient {
            client: Client::new(),
            api_key,
            api_url,
        }
    }

    pub async fn generate_chat(&self, input: ChatInput) -> Result<ChatResponse, reqwest::Error> {
        let response = self
            .client
            .post(format!("{}{}", &self.api_url, "/v1/chat/completions"))
            .header("Content-Type", "application/json")
            .header("Authorization", format!("Bearer {}", self.api_key))
            .json(&input)
            .send()
            .await;
        match response {
            Ok(res) => match res.json::<ChatResponse>().await {
                Ok(json) => Ok(json),
                Err(json_err) => Err(json_err),
            },
            Err(err) => Err(err),
        }
    }
}
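Because every optional field carries `skip_serializing_if`, `None` values never reach the wire, while the `Default` impl pins the tuned defaults. A minimal sketch, not part of the PR, of what a default-backed request body serializes to (assumes `serde_json`, which the crate already pulls in for `logit_bias`):

```rust
// Sketch: inspect the JSON that generate_chat would POST for a default ChatInput.
use sortium_ai_client::openai::chats::{ChatInput, ChatMessage, ChatRole};

fn main() {
    let input = ChatInput {
        messages: vec![ChatMessage {
            role: ChatRole::User,
            content: "Hello".to_owned(),
        }],
        ..Default::default()
    };

    // None-valued fields (stop, n, stream, logprobs, logit_bias, user) are
    // skipped entirely; the output is:
    // {"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"Hello"}],
    //  "temperature":0.7,"max_tokens":256,"top_p":1,
    //  "frequency_penalty":0.0,"presence_penalty":0.0}
    println!("{}", serde_json::to_string(&input).unwrap());
}
```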
3 changes: 2 additions & 1 deletion src/openai/mod.rs
@@ -1,3 +1,4 @@
+pub mod chats;
pub mod completions;
+pub mod edits;
pub mod embeddings;
-pub mod edits;
24 changes: 16 additions & 8 deletions tests/completions_test.rs
@@ -1,6 +1,5 @@
#[cfg(test)]
mod tests {
-    use mockito::mock;
    use sortium_ai_client::openai::completions::{
        Choice, CompletionsClient, CompletionsInput, CompletionsResponse,
    };
@@ -20,13 +19,17 @@ mod tests {
println!("{:?}", response_json);
println!("json: {}", json);

let mock = mock("POST", "/v1/completions")
let mut mock_server = mockito::Server::new_async().await;

let mock = mock_server
.mock("POST", "/v1/completions")
.with_status(200)
.with_header("Content-Type", "application/json")
.with_body(json)
.create();
.create_async()
.await;

let api_url = mockito::server_url();
let api_url = mock_server.url();

let api_key = "test_key".to_owned();
let client = CompletionsClient::new(api_key, api_url);
@@ -43,14 +46,19 @@ mod tests {
        let choices = response.choices.unwrap();
        assert_eq!(choices[0].text, Some("value".into()));

-        mock.assert();
+        mock.assert_async().await;
    }

    #[tokio::test]
    async fn test_completions_failure() {
-        let mock = mock("POST", "/v1/completions").with_status(400).create();
+        let mut mock_server = mockito::Server::new_async().await;
+        let mock = mock_server
+            .mock("POST", "/v1/completions")
+            .with_status(400)
+            .create_async()
+            .await;

-        let api_url = mockito::server_url();
+        let api_url = mock_server.url();

        let api_key = "test_key".to_owned();
        let client = CompletionsClient::new(api_key, api_url);
@@ -62,6 +70,6 @@

        assert!(result.is_err());

-        mock.assert();
+        mock.assert_async().await;
    }
}
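The PR ships `chats.rs` without a companion test. A sketch of how the same async `mockito` pattern could cover `ChatClient` — the module and test names are illustrative, not part of this PR:

```rust
// Sketch only: mirrors the migrated completions tests above for the new chats module.
#[cfg(test)]
mod chat_tests {
    use sortium_ai_client::openai::chats::{ChatClient, ChatInput};

    #[tokio::test]
    async fn test_chat_failure() {
        let mut mock_server = mockito::Server::new_async().await;
        let mock = mock_server
            .mock("POST", "/v1/chat/completions")
            .with_status(400)
            .create_async()
            .await;

        let client = ChatClient::new("test_key".to_owned(), mock_server.url());

        // An empty 400 body cannot deserialize into ChatResponse, so
        // generate_chat should surface an Err.
        let result = client.generate_chat(ChatInput::default()).await;
        assert!(result.is_err());

        mock.assert_async().await;
    }
}
```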