Skip to content

Commit

Permalink
openai embeddings init vector with size
Browse files Browse the repository at this point in the history
  • Loading branch information
santiagomed committed Nov 6, 2023
1 parent e214d7c commit 6ed629a
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions orca/src/llm/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ pub struct Response {
choices: Vec<Choice>,
}

#[derive(Serialize, Deserialize, Debug)]
#[derive(Serialize, Deserialize, Debug, Default, Clone)]
pub struct OpenAIEmbeddingResponse {
object: String,
model: String,
Expand All @@ -64,7 +64,7 @@ impl Display for OpenAIEmbeddingResponse {
}
}

#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Default)]
pub struct Embedding {
pub index: u32,
pub object: String,
Expand All @@ -87,7 +87,7 @@ impl Display for Response {
}
}

#[derive(Serialize, Deserialize, Debug)]
#[derive(Serialize, Deserialize, Debug, Default, Clone)]
pub struct Usage {
prompt_tokens: i32,
completion_tokens: Option<i32>,
Expand Down Expand Up @@ -274,11 +274,10 @@ impl EmbeddingTrait for OpenAI {
}

async fn generate_embeddings(&self, prompts: Vec<Box<dyn Prompt>>) -> Result<EmbeddingResponse> {
let mut embeddings = Vec::with_capacity(prompts.len());

let (sender, mut receiver) = tokio::sync::mpsc::channel(prompts.len());

let num_prompts = prompts.len();
let mut embeddings = vec![OpenAIEmbeddingResponse::default(); num_prompts];

let (sender, mut receiver) = tokio::sync::mpsc::channel(num_prompts);

for (i, prompt) in prompts.into_iter().enumerate() {
let sender = sender.clone();
Expand All @@ -301,10 +300,12 @@ impl EmbeddingTrait for OpenAI {
});
}

drop(sender);

while let Some((i, result)) = receiver.recv().await {
match result {
Ok(response) => {
embeddings.push(response);
embeddings[i] = response;
}
Err(e) => {
return Err(anyhow::anyhow!("Failed to generate embeddings: {}", e));
Expand Down

0 comments on commit 6ed629a

Please sign in to comment.