diff --git a/Cargo.lock b/Cargo.lock index 6c86bb4c..37c8c3b3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4809,6 +4809,7 @@ version = "0.6.0" dependencies = [ "anyhow", "assert_fs", + "base64 0.22.1", "futures", "glob", "lopdf", diff --git a/rig-core/Cargo.toml b/rig-core/Cargo.toml index afb3c1d9..cf86bec7 100644 --- a/rig-core/Cargo.toml +++ b/rig-core/Cargo.toml @@ -27,6 +27,7 @@ rig-derive = { version = "0.1.0", path = "./rig-core-derive", optional = true } glob = "0.3.1" lopdf = { version = "0.34.0", optional = true } rayon = { version = "1.10.0", optional = true} +base64 = "0.22.1" [dev-dependencies] anyhow = "1.0.75" diff --git a/rig-core/examples/read_image.rs b/rig-core/examples/read_image.rs new file mode 100644 index 00000000..845574af --- /dev/null +++ b/rig-core/examples/read_image.rs @@ -0,0 +1,32 @@ +use rig::{completion::Prompt, providers::openai}; +use base64::{engine::general_purpose::STANDARD, Engine}; +use reqwest; +use std::error::Error; + +pub async fn download_image_as_base64(image_url: &str) -> Result> { + let response = reqwest::get(image_url).await?; + let image_data = response.bytes().await?; + let base64_string = STANDARD.encode(&image_data); + let data_uri = format!("data:{};base64,{}", "image/jpeg", base64_string); + Ok(data_uri) +} + +#[tokio::main] +async fn main() { + // Create OpenAI client and model + let openai_client = openai::Client::from_env(); + let image_base64 = download_image_as_base64("https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg").await.expect("Failed to convert image to base64"); + let gpt4o = openai_client + .agent("gpt-4o") + .preamble("You are a helpful assistant.") + .image_urls(vec![image_base64]) + .build(); + + // Prompt the model and print its response + let response = gpt4o + .prompt("What is in this image?") + .await + .expect("Failed to prompt GPT-4o"); + + println!("GPT-4o: {response}"); +} diff --git a/rig-core/src/agent.rs b/rig-core/src/agent.rs index 5cbac1cc..d3c1d5a6 100644 --- a/rig-core/src/agent.rs +++ b/rig-core/src/agent.rs @@ -160,6 +160,8 @@ pub struct Agent { dynamic_tools: Vec<(usize, Box)>, /// Actual tool implementations pub tools: ToolSet, + /// List of image URLs to be included in completion requests + image_urls: Option>, } impl Completion for Agent { @@ -241,7 +243,8 @@ impl Completion for Agent { .tools([static_tools.clone(), dynamic_tools].concat()) .temperature_opt(self.temperature) .max_tokens_opt(self.max_tokens) - .additional_params_opt(self.additional_params.clone())) + .additional_params_opt(self.additional_params.clone()) + .image_urls_opt(self.image_urls.clone())) } } @@ -314,6 +317,8 @@ pub struct AgentBuilder { temperature: Option, /// Actual tool implementations tools: ToolSet, + /// List of image URLs to be added to the completion request + image_urls: Option>, } impl AgentBuilder { @@ -329,6 +334,7 @@ impl AgentBuilder { dynamic_context: vec![], dynamic_tools: vec![], tools: ToolSet::default(), + image_urls: None, } } @@ -409,6 +415,12 @@ impl AgentBuilder { self } + /// Add image URLs to the agent + pub fn image_urls(mut self, urls: Vec) -> Self { + self.image_urls = Some(urls); + self + } + /// Build the agent pub fn build(self) -> Agent { Agent { @@ -422,6 +434,7 @@ impl AgentBuilder { dynamic_context: self.dynamic_context, dynamic_tools: self.dynamic_tools, tools: self.tools, + image_urls: self.image_urls, } } } diff --git a/rig-core/src/completion.rs b/rig-core/src/completion.rs index f13f316b..23cef3ac 100644 --- a/rig-core/src/completion.rs +++ b/rig-core/src/completion.rs @@ -263,6 +263,9 @@ pub struct CompletionRequest { pub max_tokens: Option, /// Additional provider-specific parameters to be sent to the completion model provider pub additional_params: Option, + + /// The image urls to be sent to the completion model provider + pub image_urls: Option>, } impl CompletionRequest { @@ -337,6 +340,7 @@ pub struct CompletionRequestBuilder { temperature: Option, max_tokens: Option, additional_params: Option, + image_urls: Option>, } impl CompletionRequestBuilder { @@ -351,6 +355,7 @@ impl CompletionRequestBuilder { temperature: None, max_tokens: None, additional_params: None, + image_urls: None, } } @@ -452,6 +457,26 @@ impl CompletionRequestBuilder { self } + /// Adds an image URL to the completion request. + pub fn image_url(mut self, url: String) -> Self { + match &mut self.image_urls { + Some(urls) => urls.push(url), + None => self.image_urls = Some(vec![url]), + } + self + } + + /// Adds a list of image URLs to the completion request. + pub fn image_urls(self, urls: Vec) -> Self { + urls.into_iter().fold(self, |builder, url| builder.image_url(url)) + } + + /// Sets the image URLs for the completion request. + pub fn image_urls_opt(mut self, urls: Option>) -> Self { + self.image_urls = urls; + self + } + /// Builds the completion request. pub fn build(self) -> CompletionRequest { CompletionRequest { @@ -463,6 +488,7 @@ impl CompletionRequestBuilder { temperature: self.temperature, max_tokens: self.max_tokens, additional_params: self.additional_params, + image_urls: self.image_urls, } } @@ -533,6 +559,7 @@ mod tests { temperature: None, max_tokens: None, additional_params: None, + image_urls: None, }; let expected = concat!( diff --git a/rig-core/src/providers/openai.rs b/rig-core/src/providers/openai.rs index 27cb3b08..c269ca5a 100644 --- a/rig-core/src/providers/openai.rs +++ b/rig-core/src/providers/openai.rs @@ -377,7 +377,7 @@ impl TryFrom for completion::CompletionResponse for completion::CompletionResponse Ok(completion::CompletionResponse { - choice: completion::ModelChoice::Message(content.to_string()), + choice: completion::ModelChoice::Message( + content.iter() + .filter_map(|item| item.text.clone()) + .collect::>() + .join("") + ), raw_response: value, }), _ => Err(CompletionError::ResponseError( @@ -421,14 +426,56 @@ pub struct Choice { pub finish_reason: String, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Deserialize, Serialize)] +pub struct ContentItem { + #[serde(rename = "type")] + pub content_type: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub text: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub image_url: Option, +} + +#[derive(Debug, Deserialize, Serialize)] +pub struct ImageUrl { + pub url: String, +} + +#[derive(Debug, Deserialize, Serialize)] pub struct Message { pub role: String, - pub content: Option, + #[serde(default)] + #[serde(deserialize_with = "deserialize_content")] + pub content: Option>, + #[serde(skip_serializing_if = "Option::is_none")] pub tool_calls: Option>, } -#[derive(Debug, Deserialize)] +// Add this function to handle both string and array content formats +fn deserialize_content<'de, D>(deserializer: D) -> Result>, D::Error> +where + D: serde::Deserializer<'de>, +{ + #[derive(Deserialize)] + #[serde(untagged)] + enum ContentWrapper { + String(String), + Array(Vec), + } + + let content = Option::::deserialize(deserializer)?; + match content { + Some(ContentWrapper::String(s)) => Ok(Some(vec![ContentItem { + content_type: "text".to_string(), + text: Some(s), + image_url: None, + }])), + Some(ContentWrapper::Array(items)) => Ok(Some(items)), + None => Ok(None), + } +} + +#[derive(Debug, Deserialize, Serialize)] pub struct ToolCall { pub id: String, pub r#type: String, @@ -450,7 +497,7 @@ impl From for ToolDefinition { } } -#[derive(Debug, Deserialize)] +#[derive(Debug, Deserialize, Serialize)] pub struct Function { pub name: String, pub arguments: String, @@ -477,28 +524,57 @@ impl completion::CompletionModel for CompletionModel { async fn completion( &self, - mut completion_request: CompletionRequest, + completion_request: CompletionRequest, ) -> Result, CompletionError> { // Add preamble to chat history (if available) let mut full_history = if let Some(preamble) = &completion_request.preamble { - vec![completion::Message { + vec![Message { role: "system".into(), - content: preamble.clone(), + content: Some(vec![ContentItem { + content_type: "text".to_string(), + text: Some(preamble.clone()), + image_url: None, + }]), + tool_calls: None, }] } else { vec![] }; // Extend existing chat history - full_history.append(&mut completion_request.chat_history); - - // Add context documents to chat history - let prompt_with_context = completion_request.prompt_with_context(); + full_history.extend(completion_request.chat_history.clone().into_iter().map(|msg| Message { + role: msg.role, + content: Some(vec![ContentItem { + content_type: "text".to_string(), + text: Some(msg.content), + image_url: None, + }]), + tool_calls: None, + })); + + // Create final message content + let mut content = vec![ContentItem { + content_type: "text".to_string(), + text: Some(completion_request.prompt_with_context()), + image_url: None, + }]; + + // Add image URLs if present + if let Some(urls) = completion_request.image_urls { + for url in urls { + content.push(ContentItem { + content_type: "image_url".to_string(), + text: None, + image_url: Some(ImageUrl { url }), + }); + } + } - // Add context documents to chat history - full_history.push(completion::Message { + // Add final message + full_history.push(Message { role: "user".into(), - content: prompt_with_context, + content: Some(content), + tool_calls: None, }); let request = if completion_request.tools.is_empty() {