feat: Concurrent OpenAI Embeddings #8

Merged
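
This PR replaces the sequential embedding loop in generate_embeddings with a fan-out/fan-in pattern: each prompt's request is spawned as its own tokio task, and every task reports back over an mpsc channel with its result tagged by index, so the output vector stays aligned with the input order even though tasks finish out of order. The sketch below distills that pattern in a self-contained, runnable form; fake_embed is a hypothetical stand-in for the real client.execute(req) plus JSON-decoding step and is not part of the PR.

    use tokio::sync::mpsc;

    // Hypothetical stand-in for the real request: client.execute(req)
    // followed by JSON decoding into OpenAIEmbeddingResponse.
    async fn fake_embed(input: String) -> Result<Vec<f32>, String> {
        Ok(vec![input.len() as f32])
    }

    async fn embed_all(inputs: Vec<String>) -> Result<Vec<Vec<f32>>, String> {
        let n = inputs.len();
        let mut out = vec![Vec::new(); n];
        // max(1): a bounded tokio channel panics on zero capacity.
        let (tx, mut rx) = mpsc::channel(n.max(1));

        for (i, input) in inputs.into_iter().enumerate() {
            let tx = tx.clone();
            tokio::spawn(async move {
                // Tag each result with its index so completion order
                // does not have to match input order.
                let _ = tx.send((i, fake_embed(input).await)).await;
            });
        }
        // Close the last sender so recv() returns None once all tasks report.
        drop(tx);

        while let Some((i, result)) = rx.recv().await {
            out[i] = result?; // First error aborts, as in the PR.
        }
        Ok(out)
    }

    #[tokio::main]
    async fn main() -> Result<(), String> {
        let results = embed_all(vec!["a".into(), "bb".into(), "ccc".into()]).await?;
        assert_eq!(results.len(), 3);
        Ok(())
    }

The index tag is what lets results arrive in any order while out stays aligned with inputs, and drop(tx) is what terminates the receive loop once the spawned tasks are done.
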
orca/src/llm/openai.rs (63 changes: 51 additions & 12 deletions)
@@ -39,7 +39,7 @@ pub struct Response {
     choices: Vec<Choice>,
 }

-#[derive(Serialize, Deserialize, Debug)]
+#[derive(Serialize, Deserialize, Debug, Default, Clone)]
 pub struct OpenAIEmbeddingResponse {
     object: String,
     model: String,
@@ -64,7 +64,7 @@ impl Display for OpenAIEmbeddingResponse {
     }
 }

-#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
+#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Default)]
 pub struct Embedding {
     pub index: u32,
     pub object: String,
@@ -87,7 +87,7 @@ impl Display for Response {
     }
 }

-#[derive(Serialize, Deserialize, Debug)]
+#[derive(Serialize, Deserialize, Debug, Default, Clone)]
 pub struct Usage {
     prompt_tokens: i32,
     completion_tokens: Option<i32>,
@@ -273,25 +273,56 @@ impl EmbeddingTrait for OpenAI {
         Ok(res.into())
     }

-    /// TODO: Concurrent
-    async fn generate_embeddings(&self, prompt: Vec<Box<dyn Prompt>>) -> Result<EmbeddingResponse> {
-        let mut embeddings = Vec::new();
-        for prompt in prompt {
+    async fn generate_embeddings(&self, prompts: Vec<Box<dyn Prompt>>) -> Result<EmbeddingResponse> {
+        let num_prompts = prompts.len();
+        let mut embeddings = vec![OpenAIEmbeddingResponse::default(); num_prompts];
+
+        let (sender, mut receiver) = tokio::sync::mpsc::channel(num_prompts);
+
+        for (i, prompt) in prompts.into_iter().enumerate() {
+            let sender = sender.clone();
+            let client = self.client.clone();
             let req = self.generate_embedding_request(&prompt.to_string())?;
-            let res = self.client.execute(req).await?;
-            let res = res.json::<OpenAIEmbeddingResponse>().await?;
-            embeddings.push(res);
+
+            tokio::spawn(async move {
+                let result: Result<OpenAIEmbeddingResponse, String> = async {
+                    let res = client.execute(req).await.map_err(|e| format!("Request Failed: {}", e.to_string()))?;
+                    let response = res
+                        .json::<OpenAIEmbeddingResponse>()
+                        .await
+                        .map_err(|e| format!("Mapping Error: {}", e.to_string()))?;
+                    Ok(response)
+                }
+                .await;
+
+                // Send back the result (success or error) via the channel.
+                sender.send((i, result)).await.expect("Failed to send over channel");
+            });
         }
+
+        drop(sender);
+
+        while let Some((i, result)) = receiver.recv().await {
+            match result {
+                Ok(response) => {
+                    embeddings[i] = response;
+                }
+                Err(e) => {
+                    return Err(anyhow::anyhow!("Failed to generate embeddings: {}", e));
+                }
+            }
+        }
+
         Ok(EmbeddingResponse::OpenAI(embeddings))
     }
 }

 #[cfg(test)]
 mod test {
     use super::*;
-    use crate::prompt;
     use crate::prompt::TemplateEngine;
     use crate::template;
+    use crate::{prompt, prompts};
     use std::collections::HashMap;

     #[tokio::test]
@@ -322,10 +353,18 @@ mod test {
     }

     #[tokio::test]
-    async fn test_embeddings() {
+    async fn test_embedding() {
         let client = OpenAI::new();
         let content = prompt!("This is a test");
         let res = client.generate_embedding(content).await.unwrap();
         assert!(res.to_vec2().unwrap().len() > 0);
     }
+
+    #[tokio::test]
+    async fn test_embeddings() {
+        let client = OpenAI::new();
+        let content = prompts!("This is a test", "This is another test", "This is a third test");
+        let res = client.generate_embeddings(content).await.unwrap();
+        assert!(res.to_vec2().unwrap().len() > 0);
+    }
 }
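
Two review-style notes on the merged code. First, the channel is sized to num_prompts, and tokio::sync::mpsc::channel panics on zero capacity, so an empty prompt list would panic. Second, the early return on the first error drops the receiver while other tasks are still in flight; those tasks then panic on the .expect around send (the panics are confined to the spawned tasks, but the work is wasted). An alternative fan-in shape using tokio::task::JoinSet avoids the explicit channel and aborts in-flight tasks when the set is dropped on an early return. This is a hedged sketch for comparison, not what the PR uses; it reuses the hypothetical fake_embed from the sketch above:

    use tokio::task::JoinSet;

    // Same fan-out/fan-in, with JoinSet handling the fan-in instead of
    // an mpsc channel. Dropping `set` (e.g. on an early `?` return)
    // aborts any tasks that are still running.
    async fn embed_all_joinset(inputs: Vec<String>) -> Result<Vec<Vec<f32>>, String> {
        let mut out = vec![Vec::new(); inputs.len()];
        let mut set = JoinSet::new();

        for (i, input) in inputs.into_iter().enumerate() {
            set.spawn(async move { (i, fake_embed(input).await) });
        }
        while let Some(joined) = set.join_next().await {
            let (i, result) = joined.map_err(|e| format!("Join error: {}", e))?;
            out[i] = result?;
        }
        Ok(out)
    }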