
Commit

parallelize bert embeddings
santiagomed committed Nov 4, 2023
1 parent 8efaf5c commit 07b1302
Showing 5 changed files with 103 additions and 113 deletions.
examples/pdf/src/main.rs (102 changes: 44 additions & 58 deletions)
@@ -1,23 +1,14 @@
-#![allow(unused_imports)]
-#![allow(unused_variables)]
-#![allow(dead_code)]
-use std::collections::HashMap;
-
-use anyhow::Result;
 use clap::Parser;
-use orca::llm::bert::Bert;
-use orca::llm::openai::OpenAI;
-use orca::llm::quantized::Quantized;
-use orca::llm::Embedding;
-use orca::pipeline::simple::LLMPipeline;
-use orca::pipeline::Pipeline;
-use orca::prompt::context::Context;
-use orca::qdrant::Qdrant;
-use orca::qdrant::Value;
-use orca::record::pdf;
-use orca::record::pdf::Pdf;
-use orca::record::Spin;
-use orca::{prompt, prompts};
+use orca::{
+    llm::{bert::Bert, quantized::Quantized, Embedding},
+    pipeline::simple::LLMPipeline,
+    pipeline::Pipeline,
+    prompt,
+    prompt::context::Context,
+    prompts,
+    qdrant::Qdrant,
+    record::{pdf::Pdf, Spin},
+};
 use serde_json::json;
 
 #[derive(Parser, Debug)]
@@ -27,53 +18,35 @@ struct Args {
     /// The path to the PDF file to index
     file: String,
 
-    #[clap(long)]
-    /// The name of the collection to create
-    /// (default: the name of the file)
-    collection: Option<String>,
-
     #[clap(long)]
     /// The prompt to use to query the index
     prompt: String,
 }
 
 #[tokio::main]
-async fn main() -> Result<()> {
+async fn main() {
     let args = Args::parse();
 
     // init logger
     env_logger::init();
 
-    let collection = if let Some(col) = args.collection {
-        col
-    } else {
-        args.file.split("/").last().unwrap().split(".").next().unwrap().to_string()
-    };
+    let pdf_records = Pdf::from_file(&args.file, false).spin().unwrap().split(399);
+    let bert = Bert::new().build_model_and_tokenizer().await.unwrap();
 
-    let pdf_records = Pdf::from_file(&args.file, false).spin()?.split(399);
-    let bert = Bert::new().build_model_and_tokenizer().await?;
+    let collection = std::path::Path::new(&args.file)
+        .file_stem()
+        .and_then(|name| name.to_str())
+        .unwrap_or("default_collection")
+        .to_string();
 
-    let qdrant = Qdrant::new("localhost", 6334);
+    let qdrant = Qdrant::new("http://localhost:6334");
     if qdrant.create_collection(&collection, 384).await.is_ok() {
-        let embeddings = bert.generate_embeddings(prompts!(&pdf_records)).await?;
-        qdrant.insert_many(&collection, embeddings.to_vec2()?, pdf_records).await?;
+        let embeddings = bert.generate_embeddings(prompts!(&pdf_records)).await.unwrap();
+        qdrant.insert_many(&collection, embeddings.to_vec2().unwrap(), pdf_records).await.unwrap();
     }
 
-    let query_embedding = bert.generate_embedding(prompt!(args.prompt)).await?;
-    let result = qdrant.search(&collection, query_embedding.to_vec()?.clone(), 5, None).await?;
-
-    let context = json!({
-        "user_prompt": args.prompt,
-        "payloads": result
-            .iter()
-            .filter_map(|found_point| {
-                found_point.payload.as_ref().map(|payload| {
-                    // Assuming you want to convert the whole payload to a JSON string
-                    serde_json::to_string(payload).unwrap_or_else(|_| "{}".to_string())
-                })
-            })
-            .collect::<Vec<String>>()
-    });
+    let query_embedding = bert.generate_embedding(prompt!(args.prompt)).await.unwrap();
+    let result = qdrant.search(&collection, query_embedding.to_vec().unwrap().clone(), 5, None).await.unwrap();
 
     let prompt_for_model = r#"
     {{#chat}}
@@ -93,17 +66,30 @@ async fn main() -> Result<()> {
     {{/chat}}
     "#;
 
-    let openai = Quantized::new()
+    let context = json!({
+        "user_prompt": args.prompt,
+        "payloads": result
+            .iter()
+            .filter_map(|found_point| {
+                found_point.payload.as_ref().map(|payload| {
+                    // Assuming you want to convert the whole payload to a JSON string
+                    serde_json::to_string(payload).unwrap_or_else(|_| "{}".to_string())
+                })
+            })
+            .collect::<Vec<String>>()
+    });
+
+    let mistral = Quantized::new()
         .with_model(orca::llm::quantized::Model::Mistral7bInstruct)
         .with_sample_len(7500)
-        .load_model_from_path("../../models/mistral-7b-instruct-v0.1.Q4_K_S.gguf")?
-        .build_model()?;
-    let mut pipe = LLMPipeline::new(&openai).with_template("query", prompt_for_model);
-    pipe.load_context(&Context::new(context)?).await;
+        .load_model_from_path("../../models/mistral-7b-instruct-v0.1.Q4_K_S.gguf")
+        .unwrap()
+        .build_model()
+        .unwrap();
+    let mut pipe = LLMPipeline::new(&mistral).with_template("query", prompt_for_model);
+    pipe.load_context(&Context::new(context).unwrap()).await;
 
-    let response = pipe.execute("query").await?;
+    let response = pipe.execute("query").await.unwrap();
 
     println!("Response: {}", response.content());
-
-    Ok(())
 }
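
A side note on the collection-name change above: the --collection flag is gone, so the name now always derives from the file, and the old args.file.split("/") chain, which assumed "/" separators and could yield an empty or odd name for unexpected input, is replaced by std::path::Path::file_stem with an explicit fallback instead of an unwrap. A minimal standalone sketch of the same logic, pulled into a free function for illustration (the helper name is ours, not part of the example):

use std::path::Path;

fn collection_name(file: &str) -> String {
    Path::new(file)
        .file_stem()                      // "report" for "docs/report.pdf"
        .and_then(|stem| stem.to_str())   // None if the stem is not valid UTF-8
        .unwrap_or("default_collection")  // fall back rather than panic
        .to_string()
}

fn main() {
    assert_eq!(collection_name("docs/report.pdf"), "report");
    assert_eq!(collection_name(""), "default_collection");
}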
orca/Cargo.toml (2 changes: 2 additions & 0 deletions)
@@ -36,6 +36,8 @@ candle-nn = { git = "https://github.com/huggingface/candle" }
 tracing-chrome = "0.7.1"
 tracing-subscriber = "0.3.17"
 log = "0.4.20"
+rayon = "1.8.0"
+env_logger = "0.10.0"
 
 [dev-dependencies]
 base64 = "0.21.4"
orca/src/llm/bert.rs (77 changes: 40 additions & 37 deletions)
@@ -9,7 +9,8 @@ use candle_core::Tensor;
 use candle_nn::VarBuilder;
 use candle_transformers::models::bert::{BertModel, Config, DTYPE};
 use hf_hub::{api::tokio::Api, Cache, Repo, RepoType};
-use std::sync::Arc;
+use rayon::prelude::*;
+use std::sync::{Arc, Mutex};
 use tokenizers::{PaddingParams, Tokenizer};
 use tokio::sync::RwLock;
 
@@ -173,9 +174,12 @@ impl Embedding for Bert {
         let tokenizer = tokenizer.with_padding(None).with_truncation(None).map_err(E::msg)?;
         let tokens = tokenizer.encode(prompt, true).map_err(E::msg)?.get_ids().to_vec();
         let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
+        log::info!("token_ids shape: {:?}", token_ids.shape());
         let token_type_ids = token_ids.zeros_like()?;
-        log::info!("running inference {:?}", token_ids.shape());
+        let start = std::time::Instant::now();
         let embedding = model.forward(&token_ids, &token_type_ids)?;
+        log::info!("embedding shape: {:?}", embedding.shape());
+        log::info!("Embedding took {:?} to generate", start.elapsed());
         Ok(EmbeddingResponse::Bert(embedding))
     }
@@ -197,8 +201,9 @@ impl Embedding for Bert {
             None
         };
 
-        let model = self.model.as_ref().unwrap().clone();
-        let mut tokenizer = self.tokenizer.as_ref().unwrap().write().await;
+        let model: Arc<BertModel> = self.model.as_ref().unwrap().clone();
+        let mut tokenizer: tokio::sync::RwLockWriteGuard<'_, Tokenizer> =
+            self.tokenizer.as_ref().unwrap().write().await;
         let device = &model.device;
 
         if let Some(pp) = tokenizer.get_padding_mut() {
@@ -216,48 +221,46 @@ impl Embedding for Bert {
             .map_err(E::msg)?;
         let token_ids = tokens
             .iter()
-            .map(|tokens| {
+            .enumerate()
+            .map(|(i, tokens)| {
                 let tokens = tokens.get_ids().to_vec();
-                Ok(Tensor::new(tokens.as_slice(), device)?)
+                let tensor = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?;
+                Ok((i, tensor))
             })
             .collect::<Result<Vec<_>>>()?;
 
-        let token_ids = Tensor::stack(&token_ids, 0)?;
-        let token_type_ids = token_ids.zeros_like()?;
-        log::info!("running inference on batch {:?}", token_ids.shape());
-        let embeddings = model.forward(&token_ids, &token_type_ids)?;
-        log::info!("generated embeddings {:?}", embeddings.shape());
-
-        Ok(EmbeddingResponse::Bert(embeddings))
+        let embeddings = vec![Tensor::ones((2, 3), candle_core::DType::F32, &device)?; token_ids.len()];
+        // Wrap the embeddings vector in an Arc<Mutex<_>> for thread-safe access
+        let embeddings_arc = Arc::new(Mutex::new(embeddings));
+
+        // Use rayon to compute embeddings in parallel
+        log::info!("Computing embeddings");
+        let start = std::time::Instant::now();
+        token_ids.par_iter().try_for_each_with(embeddings_arc.clone(), |embeddings_arc, (i, token_ids)| {
+            let token_type_ids = token_ids.zeros_like()?;
+            let embedding = model.forward(token_ids, &token_type_ids)?.squeeze(0)?;
+
+            // Lock the mutex and write the embedding to the correct index
+            let mut embeddings = embeddings_arc.lock().map_err(|e| anyhow!("Mutex error: {}", e))?;
+            embeddings[*i] = embedding;
+
+            Ok::<(), anyhow::Error>(())
+        })?;
+        log::info!("Done computing embeddings");
+        log::info!("Embeddings took {:?} to generate", start.elapsed());
+
+        // Retrieve the final ordered embeddings
+        let embeddings_arc = Arc::try_unwrap(embeddings_arc)
+            .map_err(|_| anyhow!("Arc unwrap failed"))?
+            .into_inner()
+            .map_err(|e| anyhow!("Mutex error: {}", e))?;
+
+        let stacked_embeddings = Tensor::stack(&embeddings_arc, 0)?;
+
+        Ok(EmbeddingResponse::Bert(stacked_embeddings))
     }
 }
-/*
-let tokenizer = tokenizer.with_padding(None).with_truncation(None).map_err(E::msg)?;
-
-let tasks: Vec<_> = prompts
-    .par_iter() // Using Rayon's parallel iterator
-    .filter_map(|p| {
-        let model = model.clone();
-        tokenizer
-            .encode(p.to_string(), true)
-            .ok()
-            .and_then(|t| t.get_ids().to_vec().into())
-            .and_then(|tokens| Tensor::new(&tokens[..], device).ok()?.unsqueeze(0).ok())
-            .and_then(|token_ids| {
-                let token_type_ids = token_ids.zeros_like().ok()?;
-                let start = std::time::Instant::now();
-                let tensor = Some(model.forward(&token_ids, &token_type_ids));
-                log::info!("Embeddings took {:?} to generate", start.elapsed());
-                tensor
-            })
-    })
-    .collect();
-let mut embeddings = Vec::new();
-for task in tasks {
-    embeddings.push(task?);
-}
-*/
 #[cfg(test)]
 mod test {
     use super::*;
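
A note on the parallelization pattern in generate_embeddings above: each prompt keeps its index via .enumerate(), the output Vec is pre-filled with placeholder tensors, and every rayon worker runs its own forward pass and writes the result into its slot under a short-lived Mutex lock, so the final Tensor::stack preserves prompt order no matter which thread finishes first. A standalone sketch of the same idea with integers standing in for tensors (names are illustrative, not orca APIs; assumes rayon 1.8):

use rayon::prelude::*;
use std::sync::{Arc, Mutex};

fn main() {
    let inputs: Vec<(usize, i64)> = (0..8).map(|i| (i, i as i64)).collect();

    // One pre-filled output slot per input, so workers can write by index.
    let results = Arc::new(Mutex::new(vec![0i64; inputs.len()]));

    inputs.par_iter().for_each(|(i, x)| {
        let y = x * x; // stand-in for model.forward(...)
        results.lock().unwrap()[*i] = y; // lock only long enough to store
    });

    let results = Arc::try_unwrap(results).unwrap().into_inner().unwrap();
    assert_eq!(results, vec![0, 1, 4, 9, 16, 25, 36, 49]);
}

The same ordered result could also be collected lock-free with inputs.par_iter().map(...).collect(), since rayon's indexed collect preserves input order; the slot-writing form above simply mirrors what the commit does.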
orca/src/prompt/mod.rs (6 changes: 3 additions & 3 deletions)
@@ -341,10 +341,10 @@ macro_rules! prompt {
 #[macro_export]
 /// takes in a vector or a series of prompts
 macro_rules! prompts {
-    ($records:expr) => {{
-        $records
+    ($e:expr) => {{
+        $e
             .into_iter()
-            .map(|record| Box::new(record.clone()) as Box<dyn orca::prompt::Prompt>)
+            .map(|x| Box::new(x.clone()) as Box<dyn orca::prompt::Prompt>)
             .collect::<Vec<Box<dyn orca::prompt::Prompt>>>()
     }};
     ($($e:expr),* $(,)?) => {
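
The change above only renames the macro variables ($records to $e, |record| to |x|); the expansion is unchanged. For reference, a usage sketch of both arms as the PDF example exercises them (not compiled here; the second arm's body sits outside this hunk):

let batch = prompts!(&pdf_records);     // first arm: an IntoIterator of clonable records
let pair = prompts!("first", "second"); // second arm: a comma-separated series of prompts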
orca/src/qdrant.rs (29 changes: 14 additions & 15 deletions)
@@ -100,10 +100,10 @@ impl Qdrant {
     /// ```
     /// use orca::qdrant::Qdrant;
     ///
-    /// let qdrant = Qdrant::new("127.0.0.1", 6333);
+    /// let client = Qdrant::new("http://localhost:6334");
     /// ```
-    pub fn new(host: &str, port: u16) -> Self {
-        let config = QdrantClientConfig::from_url(&format!("http://{}:{}", host, port));
+    pub fn new(url: &str) -> Self {
+        let config = QdrantClientConfig::from_url(url);
         let client = QdrantClient::new(Some(config)).unwrap();
         Qdrant { client }
     }
@@ -119,7 +119,7 @@ impl Qdrant {
     /// # use orca::qdrant::Qdrant;
     /// # #[tokio::main]
     /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
-    /// let client = Qdrant::new("127.0.0.1", 6333);
+    /// let client = Qdrant::new("http://localhost:6334");
     /// let collection_name = "test_collection";
     /// let vector_size = 128;
     /// client.create_collection(collection_name, vector_size).await?;
@@ -153,7 +153,7 @@ impl Qdrant {
     /// # use orca::qdrant::Qdrant;
     /// # #[tokio::main]
     /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
-    /// # let client = Qdrant::new("localhost", 6333);
+    /// # let client = Qdrant::new("http://localhost:6334");
     /// let collection_name = "test_collection";
     /// client.delete_collection(collection_name).await?;
     /// # Ok(())
@@ -181,11 +181,11 @@ impl Qdrant {
     /// # }
     /// # #[tokio::main]
     /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
-    /// let qdrant = Qdrant::new("localhost", 6333);
+    /// let client = Qdrant::new("http://localhost:6334");
     /// let collection_name = "my_collection";
     /// let vector = vec![0.1, 0.2, 0.3];
     /// let payload = MyPayload { name: "John".to_string(), age: 30 };
-    /// qdrant.insert(collection_name, vector, payload).await?;
+    /// client.insert(collection_name, vector, payload).await?;
     /// # Ok(())
     /// # }
     /// ```
@@ -220,7 +220,7 @@ impl Qdrant {
     /// #
     /// # #[tokio::main]
    /// # async fn main() -> Result<(), Box<dyn Error>> {
-    /// # let client = Qdrant::new("localhost", 6333);
+    /// # let client = Qdrant::new("http://localhost:6334");
     /// let vectors = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
     /// let payloads = vec!["payload1".to_string(), "payload2".to_string()];
     /// client.insert_many("collection_name", vectors, payloads).await?;
@@ -265,7 +265,7 @@ impl Qdrant {
     /// # use orca::qdrant::{Qdrant, Condition};
     /// # #[tokio::main]
     /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
-    /// let client = Qdrant::new("localhost", 6333);
+    /// let client = Qdrant::new("http://localhost:6334");
     /// let conditions = vec![Condition::Matches(
     ///     "age".into(),
     ///     30.into(),
@@ -332,8 +332,7 @@ mod tests {
     use rand::Rng;
     use serde_json::json;
 
-    const TEST_HOST: &str = "localhost";
-    const TEST_PORT: u16 = 6334;
+    const URL: &str = "http://localhost:6334";
 
     fn generate_unique_collection_name() -> String {
         let rng = rand::thread_rng();
@@ -342,13 +341,13 @@ mod tests {
     }
 
     async fn teardown(collection_name: &str) {
-        let qdrant = Qdrant::new(TEST_HOST, TEST_PORT);
+        let qdrant = Qdrant::new(URL);
         let _ = qdrant.delete_collection(collection_name).await;
     }
 
     #[tokio::test]
     async fn test_create_collection() {
-        let qdrant = Qdrant::new(TEST_HOST, TEST_PORT);
+        let qdrant = Qdrant::new(URL);
         let unique_collection_name = generate_unique_collection_name();
 
         let result = qdrant.create_collection(&unique_collection_name, 128).await;
@@ -359,7 +358,7 @@ mod tests {
 
     #[tokio::test]
     async fn test_insert_point() {
-        let qdrant = Qdrant::new(TEST_HOST, TEST_PORT);
+        let qdrant = Qdrant::new(URL);
         let unique_collection_name = generate_unique_collection_name();
 
         qdrant.create_collection(&unique_collection_name, 3).await.unwrap();
@@ -375,7 +374,7 @@ mod tests {
 
     #[tokio::test]
     async fn test_search_points() {
-        let qdrant = Qdrant::new(TEST_HOST, TEST_PORT);
+        let qdrant = Qdrant::new(URL);
         let unique_collection_name = generate_unique_collection_name();
 
         qdrant.create_collection(&unique_collection_name, 3).await.unwrap();
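
A practical note on the API change in this file: Qdrant::new now takes a single URL rather than a (host, port) pair, and every example now points at port 6334, Qdrant's gRPC port (the one this client speaks), where several of the old doc comments showed the REST port 6333. A minimal usage sketch, assuming a local Qdrant instance with gRPC exposed and an enclosing async fn that returns a Result:

let qdrant = Qdrant::new("http://localhost:6334");
qdrant.create_collection("demo_collection", 384).await?;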
