Review notes. This text precedes the first `diff --git` header and is not part of any hunk.

- generate_embeddings: the channel is now created with at least one slot, because
  `tokio::sync::mpsc::channel` panics on a capacity of 0 (an empty prompt list).
- generate_embeddings: spawned tasks no longer `expect` on `send`. Once the caller has
  returned early on an error the receiver is gone, so a failed send is expected there.
- generate_embeddings: the redundant `e.to_string()` inside `format!` and the inner
  `async { .. }.await` wrapper are replaced by a plain `match`.
- Hunk 4's new-side length and the new-side start lines of the two following hunks were
  recomputed for the changed number of added lines. The post-image hash on the `index`
  line was not recomputed.
- NOTE(review): generic arguments appear to have been stripped from this copy of the
  patch (`Vec,`, `Option,`, `Vec>`, `-> Result`, `json::()`). Those tokens are kept
  verbatim on the lines taken from the original patch and must be restored from the
  repository before the patch can apply or compile.

diff --git a/orca/src/llm/openai.rs b/orca/src/llm/openai.rs
index cf60a64..7bb37f9 100644
--- a/orca/src/llm/openai.rs
+++ b/orca/src/llm/openai.rs
@@ -39,7 +39,7 @@ pub struct Response {
     choices: Vec,
 }
 
-#[derive(Serialize, Deserialize, Debug)]
+#[derive(Serialize, Deserialize, Debug, Default, Clone)]
 pub struct OpenAIEmbeddingResponse {
     object: String,
     model: String,
@@ -64,7 +64,7 @@ impl Display for OpenAIEmbeddingResponse {
     }
 }
 
-#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
+#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Default)]
 pub struct Embedding {
     pub index: u32,
     pub object: String,
@@ -87,7 +87,7 @@ impl Display for Response {
     }
 }
 
-#[derive(Serialize, Deserialize, Debug)]
+#[derive(Serialize, Deserialize, Debug, Default, Clone)]
 pub struct Usage {
     prompt_tokens: i32,
     completion_tokens: Option,
@@ -273,15 +273,54 @@ impl EmbeddingTrait for OpenAI {
         Ok(res.into())
     }
 
-    /// TODO: Concurrent
-    async fn generate_embeddings(&self, prompt: Vec>) -> Result {
-        let mut embeddings = Vec::new();
-        for prompt in prompt {
+    /// Generates one embedding per prompt, issuing all requests concurrently.
+    ///
+    /// Every prompt is sent as its own request on a spawned Tokio task; results are
+    /// written back by prompt index, so the output order matches the input order.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if a request cannot be built, or as soon as any request or
+    /// response decoding fails. Requests still in flight at that point are not
+    /// cancelled; their results are discarded.
+    async fn generate_embeddings(&self, prompts: Vec>) -> Result {
+        let num_prompts = prompts.len();
+        let mut embeddings = vec![OpenAIEmbeddingResponse::default(); num_prompts];
+
+        // `mpsc::channel` panics on a capacity of 0, so keep at least one slot even
+        // for an empty batch. A slot per prompt means no task ever waits to send.
+        let (sender, mut receiver) = tokio::sync::mpsc::channel(num_prompts.max(1));
+
+        for (i, prompt) in prompts.into_iter().enumerate() {
+            let sender = sender.clone();
+            let client = self.client.clone();
             let req = self.generate_embedding_request(&prompt.to_string())?;
-            let res = self.client.execute(req).await?;
-            let res = res.json::().await?;
-            embeddings.push(res);
+
+            tokio::spawn(async move {
+                let result = match client.execute(req).await {
+                    Ok(res) => res
+                        .json::<OpenAIEmbeddingResponse>()
+                        .await
+                        .map_err(|e| format!("Mapping Error: {}", e)),
+                    Err(e) => Err(format!("Request Failed: {}", e)),
+                };
+
+                // The receiver is dropped once the caller returns early on an error;
+                // a failed send is expected then and must not panic this task.
+                let _ = sender.send((i, result)).await;
+            });
+        }
+
+        // Drop the original sender so `recv` yields `None` once every task is done.
+        drop(sender);
+
+        while let Some((i, result)) = receiver.recv().await {
+            match result {
+                Ok(response) => embeddings[i] = response,
+                Err(e) => return Err(anyhow::anyhow!("Failed to generate embeddings: {}", e)),
+            }
         }
+
         Ok(EmbeddingResponse::OpenAI(embeddings))
     }
 }
@@ -289,9 +328,9 @@ impl EmbeddingTrait for OpenAI {
 #[cfg(test)]
 mod test {
     use super::*;
-    use crate::prompt;
     use crate::prompt::TemplateEngine;
     use crate::template;
+    use crate::{prompt, prompts};
     use std::collections::HashMap;
 
     #[tokio::test]
@@ -322,10 +361,18 @@ mod test {
     }
 
     #[tokio::test]
-    async fn test_embeddings() {
+    async fn test_embedding() {
         let client = OpenAI::new();
         let content = prompt!("This is a test");
         let res = client.generate_embedding(content).await.unwrap();
         assert!(res.to_vec2().unwrap().len() > 0);
     }
+
+    #[tokio::test]
+    async fn test_embeddings() {
+        let client = OpenAI::new();
+        let content = prompts!("This is a test", "This is another test", "This is a third test");
+        let res = client.generate_embeddings(content).await.unwrap();
+        assert!(res.to_vec2().unwrap().len() > 0);
+    }
 }