Skip to content

Commit

Permalink
Merge pull request #8 from Butch78/feat/concurrent-generate-embeddings
Browse files Browse the repository at this point in the history
feat: Concurrent OpenAI Embeddings
  • Loading branch information
santiagomed authored Nov 7, 2023
2 parents d399132 + 6ed629a commit d3b42c7
Showing 1 changed file with 51 additions and 12 deletions.
63 changes: 51 additions & 12 deletions orca/src/llm/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ pub struct Response {
choices: Vec<Choice>,
}

#[derive(Serialize, Deserialize, Debug)]
#[derive(Serialize, Deserialize, Debug, Default, Clone)]
pub struct OpenAIEmbeddingResponse {
object: String,
model: String,
Expand All @@ -64,7 +64,7 @@ impl Display for OpenAIEmbeddingResponse {
}
}

#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Default)]
pub struct Embedding {
pub index: u32,
pub object: String,
Expand All @@ -87,7 +87,7 @@ impl Display for Response {
}
}

#[derive(Serialize, Deserialize, Debug)]
#[derive(Serialize, Deserialize, Debug, Default, Clone)]
pub struct Usage {
prompt_tokens: i32,
completion_tokens: Option<i32>,
Expand Down Expand Up @@ -273,25 +273,56 @@ impl EmbeddingTrait for OpenAI {
Ok(res.into())
}

/// TODO: Concurrent
async fn generate_embeddings(&self, prompt: Vec<Box<dyn Prompt>>) -> Result<EmbeddingResponse> {
let mut embeddings = Vec::new();
for prompt in prompt {
async fn generate_embeddings(&self, prompts: Vec<Box<dyn Prompt>>) -> Result<EmbeddingResponse> {
let num_prompts = prompts.len();
let mut embeddings = vec![OpenAIEmbeddingResponse::default(); num_prompts];

let (sender, mut receiver) = tokio::sync::mpsc::channel(num_prompts);

for (i, prompt) in prompts.into_iter().enumerate() {
let sender = sender.clone();
let client = self.client.clone();
let req = self.generate_embedding_request(&prompt.to_string())?;
let res = self.client.execute(req).await?;
let res = res.json::<OpenAIEmbeddingResponse>().await?;
embeddings.push(res);

tokio::spawn(async move {
let result: Result<OpenAIEmbeddingResponse, String> = async {
let res = client.execute(req).await.map_err(|e| format!("Request Failed: {}", e.to_string()))?;
let response = res
.json::<OpenAIEmbeddingResponse>()
.await
.map_err(|e| format!("Mapping Error: {}", e.to_string()))?;
Ok(response)
}
.await;

// Send back the result (success or error) via the channel.
sender.send((i, result)).await.expect("Failed to send over channel");
});
}

drop(sender);

while let Some((i, result)) = receiver.recv().await {
match result {
Ok(response) => {
embeddings[i] = response;
}
Err(e) => {
return Err(anyhow::anyhow!("Failed to generate embeddings: {}", e));
}
}
}

Ok(EmbeddingResponse::OpenAI(embeddings))
}
}

#[cfg(test)]
mod test {
use super::*;
use crate::prompt;
use crate::prompt::TemplateEngine;
use crate::template;
use crate::{prompt, prompts};
use std::collections::HashMap;

#[tokio::test]
Expand Down Expand Up @@ -322,10 +353,18 @@ mod test {
}

#[tokio::test]
async fn test_embeddings() {
async fn test_embedding() {
let client = OpenAI::new();
let content = prompt!("This is a test");
let res = client.generate_embedding(content).await.unwrap();
assert!(res.to_vec2().unwrap().len() > 0);
}

#[tokio::test]
async fn test_embeddings() {
    // Batch embedding round-trip: three prompts embedded concurrently.
    let client = OpenAI::new();
    let batch = prompts!("This is a test", "This is another test", "This is a third test");
    let res = client.generate_embeddings(batch).await.unwrap();
    assert!(res.to_vec2().unwrap().len() > 0);
}
}

0 comments on commit d3b42c7

Please sign in to comment.