feat: Concurrent OpenAI Embeddings #8

Merged
1 change: 1 addition & 0 deletions orca/Cargo.toml
@@ -36,6 +36,7 @@ candle-nn = { git = "https://github.com/huggingface/candle" }
 tracing-chrome = "0.7.1"
 tracing-subscriber = "0.3.17"
 log = "0.4.20"
+crossbeam-channel = "0.5.8"

 [dev-dependencies]
 base64 = "0.21.4"
40 changes: 32 additions & 8 deletions orca/src/llm/openai.rs
@@ -5,6 +5,7 @@ use crate::{
     prompt::{chat::Message, Prompt},
 };
 use anyhow::Result;
+use crossbeam_channel::bounded;
 use reqwest::Client;
 use serde::{Deserialize, Serialize};

@@ -273,25 +274,40 @@ impl EmbeddingTrait for OpenAI {
         Ok(res.into())
     }

-    /// TODO: Concurrent
-    async fn generate_embeddings(&self, prompt: Vec<Box<dyn Prompt>>) -> Result<EmbeddingResponse> {
+    async fn generate_embeddings(&self, prompts: Vec<Box<dyn Prompt>>) -> Result<EmbeddingResponse> {
         let mut embeddings = Vec::new();
-        for prompt in prompt {
+        let (sender, receiver) = bounded(prompts.len());
+
+        let num_prompts = prompts.len();
+
+        for (i, prompt) in prompts.into_iter().enumerate() {
+            let sender = sender.clone();
+            let client = self.client.clone();
             let req = self.generate_embedding_request(&prompt.to_string())?;
-            let res = self.client.execute(req).await?;
-            let res = res.json::<OpenAIEmbeddingResponse>().await?;
-            embeddings.push(res);
+
+            tokio::spawn(async move {
+                let res = client.execute(req).await.map_err(|e| e.to_string())?;
+                let response = res.json::<OpenAIEmbeddingResponse>().await.map_err(|e| e.to_string())?;
+                sender.send((i, response)).unwrap();

Owner: This could look something like sender.send((i, response)).await; using tokio channels.
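
A minimal sketch of that suggestion (not part of the PR), assuming tokio::sync::mpsc in place of crossbeam; the surrounding names (num_prompts, prompts, self.client, OpenAIEmbeddingResponse) mirror the function above:

use tokio::sync::mpsc;

// Hypothetical tokio-channel variant of the loop above.
// tokio's bounded channel requires a non-zero buffer size.
let (sender, mut receiver) = mpsc::channel(num_prompts.max(1));

for (i, prompt) in prompts.into_iter().enumerate() {
    let sender = sender.clone();
    let client = self.client.clone();
    let req = self.generate_embedding_request(&prompt.to_string())?;
    tokio::spawn(async move {
        let res = client.execute(req).await.map_err(|e| e.to_string())?;
        let response = res.json::<OpenAIEmbeddingResponse>().await.map_err(|e| e.to_string())?;
        // mpsc::Sender::send is async, so the task awaits instead of blocking a thread.
        sender.send((i, response)).await.map_err(|e| e.to_string())?;
        Ok::<_, String>(())
    });
}
// Drop the original sender so recv() returns None once every clone is gone.
drop(sender);

let mut results = Vec::with_capacity(num_prompts);
while let Some((i, response)) = receiver.recv().await {
    results.push((i, response)); // collect indexed results, sort by i afterwards
}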

Owner (edit): Actually, the current method of error handling within the spawned tasks can be problematic. It might be better to send back a Result type over the channel, so the receiver can decide how to handle errors:

tokio::spawn(async move {
    let result = async {
        let res = client.execute(req).await.map_err(|e| EmbeddingError::RequestError(e.to_string()))?;
        let response = res.json::<OpenAIEmbeddingResponse>().await.map_err(|e| EmbeddingError::ResponseError(e.to_string()))?;
        Ok(response)
    }.await;

    // Send back the result (success or error) via the channel.
    sender.send((i, result)).await.expect("Failed to send over channel");
});
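
On the receiving side that could look something like the following sketch (also not part of the PR; EmbeddingError is the made-up type from the example above, assumed to implement Debug):

// Hypothetical receiver loop: surface task errors instead of panicking.
let mut results = Vec::with_capacity(num_prompts);
for _ in 0..num_prompts {
    let (i, result) = receiver.recv().await.expect("channel closed early");
    match result {
        Ok(response) => results.push((i, response)),
        Err(e) => return Err(anyhow::anyhow!("embedding request {i} failed: {e:?}")),
    }
}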

Contributor (author): I just added simple format! errors, but I was wondering: did you want to create some thiserror type for the EmbeddingError, like in your example?

Owner: format! should be fine for now! I just made up some generic type there for example purposes. Tbh Rust error handling still confuses me, so I just try to use anyhow for everything. In the future my plan is to clean up Orca's error handling.
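
For illustration, a thiserror version of that made-up EmbeddingError could look like this sketch (the PR itself sticks with plain format! strings):

use thiserror::Error;

// Hypothetical error type matching the names in the example above.
#[derive(Debug, Error)]
pub enum EmbeddingError {
    #[error("embedding request failed: {0}")]
    RequestError(String),
    #[error("failed to parse embedding response: {0}")]
    ResponseError(String),
}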

+                Ok::<_, String>(())
+            });
         }

+        // Tasks complete in arbitrary order; collect indexed results,
+        // then sort them back into prompt order.
+        let mut results: Vec<(usize, OpenAIEmbeddingResponse)> = Vec::with_capacity(num_prompts);
+        for _ in 0..num_prompts {
+            results.push(receiver.recv().unwrap());
+        }
+        results.sort_by_key(|&(i, _)| i);
+        embeddings.extend(results.into_iter().map(|(_, res)| res));

         Ok(EmbeddingResponse::OpenAI(embeddings))
     }
 }

 #[cfg(test)]
 mod test {
     use super::*;
-    use crate::prompt;
     use crate::prompt::TemplateEngine;
     use crate::template;
+    use crate::{prompt, prompts};
     use std::collections::HashMap;

     #[tokio::test]
@@ -322,10 +338,18 @@ mod test {
     }

     #[tokio::test]
-    async fn test_embeddings() {
+    async fn test_embedding() {
         let client = OpenAI::new();
         let content = prompt!("This is a test");
         let res = client.generate_embedding(content).await.unwrap();
         assert!(res.to_vec2().unwrap().len() > 0);
     }
+
+    #[tokio::test]
+    async fn test_embeddings() {
+        let client = OpenAI::new();
+        let content = prompts!("This is a test", "This is another test", "This is a third test");
+        let res = client.generate_embeddings(content).await.unwrap();
+        assert!(res.to_vec2().unwrap().len() > 0);
+    }
 }