Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: gpt-4o supports reading images #164

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions rig-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ rig-derive = { version = "0.1.0", path = "./rig-core-derive", optional = true }
glob = "0.3.1"
lopdf = { version = "0.34.0", optional = true }
rayon = { version = "1.10.0", optional = true}
base64 = "0.22.1"

[dev-dependencies]
anyhow = "1.0.75"
Expand Down
32 changes: 32 additions & 0 deletions rig-core/examples/read_image.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
use rig::{completion::Prompt, providers::openai};
use base64::{engine::general_purpose::STANDARD, Engine};
use reqwest;
use std::error::Error;

pub async fn download_image_as_base64(image_url: &str) -> Result<String, Box<dyn Error>> {
let response = reqwest::get(image_url).await?;
let image_data = response.bytes().await?;
let base64_string = STANDARD.encode(&image_data);
let data_uri = format!("data:{};base64,{}", "image/jpeg", base64_string);
Ok(data_uri)
}

#[tokio::main]
async fn main() {
// Create OpenAI client and model
let openai_client = openai::Client::from_env();
let image_base64 = download_image_as_base64("https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg").await.expect("Failed to convert image to base64");
let gpt4o = openai_client
.agent("gpt-4o")
.preamble("You are a helpful assistant.")
.image_urls(vec![image_base64])
.build();

// Prompt the model and print its response
let response = gpt4o
.prompt("What is in this image?")
.await
.expect("Failed to prompt GPT-4o");

println!("GPT-4o: {response}");
}
15 changes: 14 additions & 1 deletion rig-core/src/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,8 @@ pub struct Agent<M: CompletionModel> {
dynamic_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
/// Actual tool implementations
pub tools: ToolSet,
/// List of image URLs to be included in completion requests
image_urls: Option<Vec<String>>,
}

impl<M: CompletionModel> Completion<M> for Agent<M> {
Expand Down Expand Up @@ -241,7 +243,8 @@ impl<M: CompletionModel> Completion<M> for Agent<M> {
.tools([static_tools.clone(), dynamic_tools].concat())
.temperature_opt(self.temperature)
.max_tokens_opt(self.max_tokens)
.additional_params_opt(self.additional_params.clone()))
.additional_params_opt(self.additional_params.clone())
.image_urls_opt(self.image_urls.clone()))
}
}

Expand Down Expand Up @@ -314,6 +317,8 @@ pub struct AgentBuilder<M: CompletionModel> {
temperature: Option<f64>,
/// Actual tool implementations
tools: ToolSet,
/// List of image URLs to be added to the completion request
image_urls: Option<Vec<String>>,
}

impl<M: CompletionModel> AgentBuilder<M> {
Expand All @@ -329,6 +334,7 @@ impl<M: CompletionModel> AgentBuilder<M> {
dynamic_context: vec![],
dynamic_tools: vec![],
tools: ToolSet::default(),
image_urls: None,
}
}

Expand Down Expand Up @@ -409,6 +415,12 @@ impl<M: CompletionModel> AgentBuilder<M> {
self
}

/// Add image URLs to the agent
pub fn image_urls(mut self, urls: Vec<String>) -> Self {
self.image_urls = Some(urls);
self
}

/// Build the agent
pub fn build(self) -> Agent<M> {
Agent {
Expand All @@ -422,6 +434,7 @@ impl<M: CompletionModel> AgentBuilder<M> {
dynamic_context: self.dynamic_context,
dynamic_tools: self.dynamic_tools,
tools: self.tools,
image_urls: self.image_urls,
}
}
}
27 changes: 27 additions & 0 deletions rig-core/src/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,9 @@ pub struct CompletionRequest {
pub max_tokens: Option<u64>,
/// Additional provider-specific parameters to be sent to the completion model provider
pub additional_params: Option<serde_json::Value>,

/// The image urls to be sent to the completion model provider
pub image_urls: Option<Vec<String>>,
}

impl CompletionRequest {
Expand Down Expand Up @@ -337,6 +340,7 @@ pub struct CompletionRequestBuilder<M: CompletionModel> {
temperature: Option<f64>,
max_tokens: Option<u64>,
additional_params: Option<serde_json::Value>,
image_urls: Option<Vec<String>>,
}

impl<M: CompletionModel> CompletionRequestBuilder<M> {
Expand All @@ -351,6 +355,7 @@ impl<M: CompletionModel> CompletionRequestBuilder<M> {
temperature: None,
max_tokens: None,
additional_params: None,
image_urls: None,
}
}

Expand Down Expand Up @@ -452,6 +457,26 @@ impl<M: CompletionModel> CompletionRequestBuilder<M> {
self
}

/// Adds an image URL to the completion request.
pub fn image_url(mut self, url: String) -> Self {
match &mut self.image_urls {
Some(urls) => urls.push(url),
None => self.image_urls = Some(vec![url]),
}
self
}

/// Adds a list of image URLs to the completion request.
pub fn image_urls(self, urls: Vec<String>) -> Self {
urls.into_iter().fold(self, |builder, url| builder.image_url(url))
}

/// Sets the image URLs for the completion request.
pub fn image_urls_opt(mut self, urls: Option<Vec<String>>) -> Self {
self.image_urls = urls;
self
}

/// Builds the completion request.
pub fn build(self) -> CompletionRequest {
CompletionRequest {
Expand All @@ -463,6 +488,7 @@ impl<M: CompletionModel> CompletionRequestBuilder<M> {
temperature: self.temperature,
max_tokens: self.max_tokens,
additional_params: self.additional_params,
image_urls: self.image_urls,
}
}

Expand Down Expand Up @@ -533,6 +559,7 @@ mod tests {
temperature: None,
max_tokens: None,
additional_params: None,
image_urls: None,
};

let expected = concat!(
Expand Down
110 changes: 93 additions & 17 deletions rig-core/src/providers/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionRe
match value.choices.as_slice() {
[Choice {
message:
Message {
Message {
tool_calls: Some(calls),
..
},
Expand All @@ -397,13 +397,18 @@ impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionRe
}
[Choice {
message:
Message {
Message {
content: Some(content),
..
},
..
}, ..] => Ok(completion::CompletionResponse {
choice: completion::ModelChoice::Message(content.to_string()),
choice: completion::ModelChoice::Message(
content.iter()
.filter_map(|item| item.text.clone())
.collect::<Vec<_>>()
.join("")
),
raw_response: value,
}),
_ => Err(CompletionError::ResponseError(
Expand All @@ -421,14 +426,56 @@ pub struct Choice {
pub finish_reason: String,
}

#[derive(Debug, Deserialize)]
#[derive(Debug, Deserialize, Serialize)]
pub struct ContentItem {
#[serde(rename = "type")]
pub content_type: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub text: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub image_url: Option<ImageUrl>,
}

#[derive(Debug, Deserialize, Serialize)]
pub struct ImageUrl {
pub url: String,
}

#[derive(Debug, Deserialize, Serialize)]
pub struct Message {
pub role: String,
pub content: Option<String>,
#[serde(default)]
#[serde(deserialize_with = "deserialize_content")]
pub content: Option<Vec<ContentItem>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>,
}

#[derive(Debug, Deserialize)]
// Add this function to handle both string and array content formats
fn deserialize_content<'de, D>(deserializer: D) -> Result<Option<Vec<ContentItem>>, D::Error>
where
D: serde::Deserializer<'de>,
{
#[derive(Deserialize)]
#[serde(untagged)]
enum ContentWrapper {
String(String),
Array(Vec<ContentItem>),
}

let content = Option::<ContentWrapper>::deserialize(deserializer)?;
match content {
Some(ContentWrapper::String(s)) => Ok(Some(vec![ContentItem {
content_type: "text".to_string(),
text: Some(s),
image_url: None,
}])),
Some(ContentWrapper::Array(items)) => Ok(Some(items)),
None => Ok(None),
}
}

#[derive(Debug, Deserialize, Serialize)]
pub struct ToolCall {
pub id: String,
pub r#type: String,
Expand All @@ -450,7 +497,7 @@ impl From<completion::ToolDefinition> for ToolDefinition {
}
}

#[derive(Debug, Deserialize)]
#[derive(Debug, Deserialize, Serialize)]
pub struct Function {
pub name: String,
pub arguments: String,
Expand All @@ -477,28 +524,57 @@ impl completion::CompletionModel for CompletionModel {

async fn completion(
&self,
mut completion_request: CompletionRequest,
completion_request: CompletionRequest,
) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
// Add preamble to chat history (if available)
let mut full_history = if let Some(preamble) = &completion_request.preamble {
vec![completion::Message {
vec![Message {
role: "system".into(),
content: preamble.clone(),
content: Some(vec![ContentItem {
content_type: "text".to_string(),
text: Some(preamble.clone()),
image_url: None,
}]),
tool_calls: None,
}]
} else {
vec![]
};

// Extend existing chat history
full_history.append(&mut completion_request.chat_history);

// Add context documents to chat history
let prompt_with_context = completion_request.prompt_with_context();
full_history.extend(completion_request.chat_history.clone().into_iter().map(|msg| Message {
role: msg.role,
content: Some(vec![ContentItem {
content_type: "text".to_string(),
text: Some(msg.content),
image_url: None,
}]),
tool_calls: None,
}));

// Create final message content
let mut content = vec![ContentItem {
content_type: "text".to_string(),
text: Some(completion_request.prompt_with_context()),
image_url: None,
}];

// Add image URLs if present
if let Some(urls) = completion_request.image_urls {
for url in urls {
content.push(ContentItem {
content_type: "image_url".to_string(),
text: None,
image_url: Some(ImageUrl { url }),
});
}
}

// Add context documents to chat history
full_history.push(completion::Message {
// Add final message
full_history.push(Message {
role: "user".into(),
content: prompt_with_context,
content: Some(content),
tool_calls: None,
});

let request = if completion_request.tools.is_empty() {
Expand Down