From 6ed629a5953db324e242c03834b96866e108aeab Mon Sep 17 00:00:00 2001
From: Santiago Medina
Date: Mon, 6 Nov 2023 08:31:11 -0800
Subject: [PATCH] openai embeddings init vector with size

---
 orca/src/llm/openai.rs | 17 +++++++++--------
 1 file changed, 9 insertions(+), 8 deletions(-)

diff --git a/orca/src/llm/openai.rs b/orca/src/llm/openai.rs
index 1d20966..7bb37f9 100644
--- a/orca/src/llm/openai.rs
+++ b/orca/src/llm/openai.rs
@@ -39,7 +39,7 @@ pub struct Response {
     choices: Vec<Choice>,
 }
 
-#[derive(Serialize, Deserialize, Debug)]
+#[derive(Serialize, Deserialize, Debug, Default, Clone)]
 pub struct OpenAIEmbeddingResponse {
     object: String,
     model: String,
@@ -64,7 +64,7 @@ impl Display for OpenAIEmbeddingResponse {
     }
 }
 
-#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
+#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Default)]
 pub struct Embedding {
     pub index: u32,
     pub object: String,
@@ -87,7 +87,7 @@ impl Display for Response {
     }
 }
 
-#[derive(Serialize, Deserialize, Debug)]
+#[derive(Serialize, Deserialize, Debug, Default, Clone)]
 pub struct Usage {
     prompt_tokens: i32,
     completion_tokens: Option<i32>,
@@ -274,11 +274,10 @@ impl EmbeddingTrait for OpenAI {
     }
 
     async fn generate_embeddings(&self, prompts: Vec<Box<dyn Prompt>>) -> Result<EmbeddingResponse> {
-        let mut embeddings = Vec::with_capacity(prompts.len());
-
-        let (sender, mut receiver) = tokio::sync::mpsc::channel(prompts.len());
-
         let num_prompts = prompts.len();
+        let mut embeddings = vec![OpenAIEmbeddingResponse::default(); num_prompts];
+
+        let (sender, mut receiver) = tokio::sync::mpsc::channel(num_prompts);
 
         for (i, prompt) in prompts.into_iter().enumerate() {
             let sender = sender.clone();
@@ -301,10 +300,12 @@ impl EmbeddingTrait for OpenAI {
             });
         }
 
+        drop(sender);
+
         while let Some((i, result)) = receiver.recv().await {
             match result {
                 Ok(response) => {
-                    embeddings.push(response);
+                    embeddings[i] = response;
                 }
                 Err(e) => {
                     return Err(anyhow::anyhow!("Failed to generate embeddings: {}", e));