Skip to content

Commit

Permalink
feat: Removed duplication of batch_size. Pipeline owns the default ba…
Browse files Browse the repository at this point in the history
…tch size value and Embed/SparseEmbed are able to modify it. Fixes #223
  • Loading branch information
devsprint committed Sep 25, 2024
1 parent 1926aa9 commit 850ac89
Show file tree
Hide file tree
Showing 21 changed files with 77 additions and 54 deletions.
2 changes: 1 addition & 1 deletion examples/fastembed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.to_owned();

indexing::Pipeline::from_loader(FileLoader::new(".").with_extensions(&["rs"]))
.then_in_batch(10, Embed::new(FastEmbed::builder().batch_size(10).build()?))
.then_in_batch(Embed::new(FastEmbed::builder().batch_size(10).build()?))
.then_store_with(
Qdrant::try_from_url(qdrant_url)?
.batch_size(50)
Expand Down
2 changes: 1 addition & 1 deletion examples/fluvio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.unwrap();

indexing::Pipeline::from_loader(loader)
.then_in_batch(10, Embed::new(FastEmbed::try_default().unwrap()))
.then_in_batch(Embed::new(FastEmbed::try_default().unwrap()).with_batch_size(10))
.then_store_with(
Qdrant::builder()
.batch_size(50)
Expand Down
15 changes: 4 additions & 11 deletions examples/hybrid_search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
// Ensure all batching is consistent
let batch_size = 64;

let fastembed_sparse = FastEmbed::try_default_sparse()
.unwrap()
.with_batch_size(batch_size)
.to_owned();
let fastembed = FastEmbed::try_default()
.unwrap()
.with_batch_size(batch_size)
.to_owned();
let fastembed_sparse = FastEmbed::try_default_sparse().unwrap().to_owned();
let fastembed = FastEmbed::try_default().unwrap().to_owned();

// Set up openai with the mini model, which is great for indexing
let openai = openai::OpenAI::builder()
Expand All @@ -57,10 +51,9 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
// Generate metadata on the code chunks to increase our chances of finding the right code
.then(MetadataQACode::from_client(openai.clone()).build().unwrap())
.then_in_batch(
batch_size,
transformers::SparseEmbed::new(fastembed_sparse.clone()),
transformers::SparseEmbed::new(fastembed_sparse.clone()).with_batch_size(batch_size),
)
.then_in_batch(batch_size, transformers::Embed::new(fastembed.clone()))
.then_in_batch(transformers::Embed::new(fastembed.clone()).with_batch_size(batch_size))
.then_store_with(qdrant.clone())
.run()
.await?;
Expand Down
2 changes: 1 addition & 1 deletion examples/index_codebase.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
"rust",
10..2048,
)?)
.then_in_batch(10, Embed::new(openai_client.clone()))
.then_in_batch(Embed::new(openai_client.clone()).with_batch_size(10))
.then_store_with(
Qdrant::builder()
.batch_size(50)
Expand Down
2 changes: 1 addition & 1 deletion examples/index_codebase_reduced_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.then(indexing::transformers::CompressCodeOutline::new(
openai_client.clone(),
))
.then_in_batch(10, Embed::new(openai_client.clone()))
.then_in_batch(Embed::new(openai_client.clone()).with_batch_size(10))
.then_store_with(
Qdrant::builder()
.batch_size(50)
Expand Down
2 changes: 1 addition & 1 deletion examples/index_groq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
indexing::Pipeline::from_loader(FileLoader::new("README.md"))
.then_chunk(ChunkMarkdown::from_chunk_range(10..2048))
.then(MetadataQAText::new(groq_client.clone()))
.then_in_batch(10, Embed::new(fastembed))
.then_in_batch(Embed::new(fastembed).with_batch_size(10))
.then_store_with(memory_store.clone())
.run()
.await?;
Expand Down
2 changes: 1 addition & 1 deletion examples/index_markdown_lots_of_metadata.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.then(MetadataSummary::new(openai_client.clone()))
.then(MetadataTitle::new(openai_client.clone()))
.then(MetadataKeywords::new(openai_client.clone()))
.then_in_batch(10, Embed::new(openai_client.clone()))
.then_in_batch(Embed::new(openai_client.clone()))
.log_all()
.filter_errors()
.then_store_with(
Expand Down
2 changes: 1 addition & 1 deletion examples/index_ollama.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
indexing::Pipeline::from_loader(FileLoader::new("README.md"))
.then_chunk(ChunkMarkdown::from_chunk_range(10..2048))
.then(MetadataQAText::new(ollama_client.clone()))
.then_in_batch(10, Embed::new(fastembed))
.then_in_batch(Embed::new(fastembed).with_batch_size(10))
.then_store_with(memory_store.clone())
.run()
.await?;
Expand Down
2 changes: 1 addition & 1 deletion examples/lancedb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
indexing::Pipeline::from_loader(FileLoader::new("README.md"))
.then_chunk(ChunkMarkdown::from_chunk_range(10..2048))
.then(MetadataQAText::new(openai_client.clone()))
.then_in_batch(10, Embed::new(openai_client.clone()))
.then_in_batch(Embed::new(openai_client.clone()).with_batch_size(10))
.then_store_with(lancedb.clone())
.run()
.await?;
Expand Down
2 changes: 1 addition & 1 deletion examples/query_pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
indexing::Pipeline::from_loader(FileLoader::new("README.md"))
.then_chunk(ChunkMarkdown::from_chunk_range(10..2048))
.then(MetadataQAText::new(openai_client.clone()))
.then_in_batch(10, Embed::new(openai_client.clone()))
.then_in_batch(Embed::new(openai_client.clone()).with_batch_size(10))
.then_store_with(qdrant.clone())
.run()
.await?;
Expand Down
2 changes: 1 addition & 1 deletion examples/store_multiple_vectors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.then(MetadataSummary::new(openai_client.clone()))
.then(MetadataTitle::new(openai_client.clone()))
.then(MetadataKeywords::new(openai_client.clone()))
.then_in_batch(10, Embed::new(openai_client.clone()))
.then_in_batch(Embed::new(openai_client.clone()).with_batch_size(10))
.log_all()
.filter_errors()
.then_store_with(
Expand Down
5 changes: 5 additions & 0 deletions swiftide-core/src/indexing_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,11 @@ pub trait BatchableTransformer: Send + Sync {
let name = std::any::type_name::<Self>();
name.split("::").last().unwrap_or(name)
}

/// Overrides the default batch size of the pipeline.
///
/// The default implementation returns `None`; `Pipeline::then_in_batch` then falls back to
/// the pipeline's own `batch_size`. Return `Some(n)` to have the pipeline chunk the stream
/// in batches of `n` for this transformer instead.
fn batch_size(&self) -> Option<usize> {
None
}
}

#[async_trait]
Expand Down
18 changes: 9 additions & 9 deletions swiftide-indexing/src/pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ pub struct Pipeline {
storage: Vec<Arc<dyn Persist>>,
concurrency: usize,
indexing_defaults: IndexingDefaults,
batch_size: usize,
}

impl Default for Pipeline {
Expand All @@ -37,6 +38,7 @@ impl Default for Pipeline {
storage: Vec::default(),
concurrency: num_cpus::get(),
indexing_defaults: IndexingDefaults::default(),
batch_size: 256, //TODO: make this configurable
}
}
}
Expand Down Expand Up @@ -207,7 +209,6 @@ impl Pipeline {
///
/// # Arguments
///
/// * `batch_size` - The size of the batches to be processed.
/// * `transformer` - A transformer that implements the `BatchableTransformer` trait.
///
/// # Returns
Expand All @@ -216,7 +217,6 @@ impl Pipeline {
#[must_use]
pub fn then_in_batch(
mut self,
batch_size: usize,
mut transformer: impl BatchableTransformer + WithBatchIndexingDefaults + 'static,
) -> Self {
let concurrency = transformer.concurrency().unwrap_or(self.concurrency);
Expand All @@ -226,7 +226,7 @@ impl Pipeline {
let transformer = Arc::new(transformer);
self.stream = self
.stream
.try_chunks(batch_size)
.try_chunks(transformer.batch_size().unwrap_or(self.batch_size))
.map_ok(move |nodes| {
let transformer = Arc::clone(&transformer);
let span = tracing::trace_span!("then_in_batch", nodes = ?nodes );
Expand Down Expand Up @@ -406,13 +406,15 @@ impl Pipeline {
storage: self.storage.clone(),
concurrency: self.concurrency,
indexing_defaults: self.indexing_defaults.clone(),
batch_size: self.batch_size,
};

let right_pipeline = Self {
stream: right_rx.into(),
storage: self.storage.clone(),
concurrency: self.concurrency,
indexing_defaults: self.indexing_defaults.clone(),
batch_size: self.batch_size,
};

(left_pipeline, right_pipeline)
Expand Down Expand Up @@ -606,6 +608,7 @@ mod tests {
.returning(|nodes| IndexingStream::iter(nodes.into_iter().map(Ok)));
batch_transformer.expect_concurrency().returning(|| None);
batch_transformer.expect_name().returning(|| "transformer");
batch_transformer.expect_batch_size().returning(|| None);

chunker
.expect_transform_node()
Expand Down Expand Up @@ -635,7 +638,7 @@ mod tests {

let pipeline = Pipeline::from_loader(loader)
.then(transformer)
.then_in_batch(1, batch_transformer)
.then_in_batch(batch_transformer)
.then_chunk(chunker)
.then_store_with(storage);

Expand Down Expand Up @@ -750,7 +753,7 @@ mod tests {
.returning(|| vec![Ok(Node::default())].into());

let pipeline = Pipeline::from_loader(loader)
.then_in_batch(10, batch_transformer)
.then_in_batch(batch_transformer)
.then_store_with(storage.clone());
pipeline.run().await.unwrap();

Expand Down Expand Up @@ -884,10 +887,7 @@ mod tests {

let pipeline = Pipeline::from_loader(Box::new(loader) as Box<dyn Loader>)
.then(Box::new(transformer) as Box<dyn Transformer>)
.then_in_batch(
1,
Box::new(batch_transformer) as Box<dyn BatchableTransformer>,
)
.then_in_batch(Box::new(batch_transformer) as Box<dyn BatchableTransformer>)
.then_chunk(Box::new(chunker) as Box<dyn ChunkerTransformer>)
.then_store_with(Box::new(storage) as Box<dyn Persist>);
pipeline.run().await.unwrap();
Expand Down
17 changes: 17 additions & 0 deletions swiftide-indexing/src/transformers/embed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@ use swiftide_core::{
pub struct Embed {
/// Embedding model, shared behind an `Arc`.
embed_model: Arc<dyn EmbeddingModel>,
/// Optional concurrency setting; `None` when constructed via `new`.
concurrency: Option<usize>,
/// Optional batch size; `None` when constructed via `new`, `Some` after `with_batch_size`.
batch_size: Option<usize>,
}

impl std::fmt::Debug for Embed {
    /// Writes a diagnostic representation of `Embed` that lists only the `concurrency` and
    /// `batch_size` settings; the embedding model itself is not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut repr = f.debug_struct("Embed");
        repr.field("concurrency", &self.concurrency);
        repr.field("batch_size", &self.batch_size);
        repr.finish()
    }
}
Expand All @@ -38,6 +40,7 @@ impl Embed {
Self {
embed_model: Arc::new(model),
concurrency: None,
batch_size: None,
}
}

Expand All @@ -46,6 +49,20 @@ impl Embed {
self.concurrency = Some(concurrency);
self
}

/// Sets the batch size for the transformer.
///
/// If the batch size is not set, the transformer will use the default batch size set by the
/// pipeline.
///
/// # Parameters
///
/// * `batch_size` - The batch size to use for the transformer.
///   NOTE(review): the pipeline forwards this value to `try_chunks`; confirm how a value
///   of `0` should be handled.
///
/// # Returns
///
/// The same `Embed`, with the batch size set.
#[must_use]
pub fn with_batch_size(mut self, batch_size: usize) -> Self {
    // NOTE(review): the pipeline reads this value through `BatchableTransformer::batch_size`,
    // whose default returns `None`. Confirm that `Embed`'s `BatchableTransformer` impl
    // overrides `batch_size()` to return `self.batch_size`; otherwise this setter has no
    // effect on batching.
    self.batch_size = Some(batch_size);
    self
}
}

impl WithBatchIndexingDefaults for Embed {}
Expand Down
20 changes: 20 additions & 0 deletions swiftide-indexing/src/transformers/sparse_embed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use swiftide_core::{
pub struct SparseEmbed {
/// Sparse embedding model, shared behind an `Arc`.
embed_model: Arc<dyn SparseEmbeddingModel>,
/// Optional concurrency setting; `None` when constructed via `new`.
concurrency: Option<usize>,
/// Optional batch size; `None` when constructed via `new`, `Some` after `with_batch_size`.
batch_size: Option<usize>,
}

impl std::fmt::Debug for SparseEmbed {
Expand All @@ -38,6 +39,7 @@ impl SparseEmbed {
Self {
embed_model: Arc::new(model),
concurrency: None,
batch_size: None,
}
}

Expand All @@ -46,6 +48,20 @@ impl SparseEmbed {
self.concurrency = Some(concurrency);
self
}

/// Sets the batch size for the transformer.
///
/// If the batch size is not set, the transformer will use the default batch size set by the
/// pipeline.
///
/// # Parameters
///
/// * `batch_size` - The batch size to use for the transformer.
///   NOTE(review): the pipeline forwards this value to `try_chunks`; confirm how a value
///   of `0` should be handled.
///
/// # Returns
///
/// The same `SparseEmbed`, with the batch size set.
#[must_use]
pub fn with_batch_size(mut self, batch_size: usize) -> Self {
    self.batch_size = Some(batch_size);
    self
}
}

impl WithBatchIndexingDefaults for SparseEmbed {}
Expand Down Expand Up @@ -114,6 +130,10 @@ impl BatchableTransformer for SparseEmbed {
fn concurrency(&self) -> Option<usize> {
self.concurrency
}

/// Returns the batch size configured on this transformer, or `None` if none was set.
fn batch_size(&self) -> Option<usize> {
self.batch_size
}
}

#[cfg(test)]
Expand Down
5 changes: 0 additions & 5 deletions swiftide-integrations/src/fastembed/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,6 @@ impl FastEmbed {
.build()
}

pub fn with_batch_size(&mut self, batch_size: usize) -> &mut Self {
self.batch_size = Some(batch_size);
self
}

/// Returns a `FastEmbedBuilder` created via its `Default` implementation.
pub fn builder() -> FastEmbedBuilder {
FastEmbedBuilder::default()
}
Expand Down
4 changes: 2 additions & 2 deletions swiftide/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
//! Pipeline::from_loader(FileLoader::new(".").with_extensions(&["md"]))
//! .then_chunk(ChunkMarkdown::from_chunk_range(10..512))
//! .then(MetadataQAText::new(openai_client.clone()))
//! .then_in_batch(10, Embed::new(openai_client.clone()))
//! .then_in_batch(Embed::new(openai_client.clone()).with_batch_size(10))
//! .then_store_with(
//! Qdrant::try_from_url(qdrant_url)?
//! .batch_size(50)
Expand Down Expand Up @@ -165,7 +165,7 @@ pub mod indexing {
/// indexing::Pipeline::from_loader(FileLoader::new("README.md"))
/// .then_chunk(ChunkMarkdown::from_chunk_range(10..2048))
/// .then(MetadataQAText::new(openai_client.clone()))
/// .then_in_batch(10, Embed::new(openai_client.clone()))
/// .then_in_batch(Embed::new(openai_client.clone()).with_batch_size(10))
/// .then_store_with(qdrant.clone())
/// .run()
/// .await?;
Expand Down
4 changes: 2 additions & 2 deletions swiftide/tests/indexing_pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ async fn test_indexing_pipeline() {
.then_chunk(transformers::ChunkCode::try_for_language("rust").unwrap())
.then(transformers::MetadataQACode::default())
.filter_cached(integrations::redis::Redis::try_from_url(&redis_url, "prefix").unwrap())
.then_in_batch(1, transformers::Embed::new(openai_client.clone()))
.then_in_batch(transformers::Embed::new(openai_client.clone()).with_batch_size(1))
.log_nodes()
.then_store_with(
integrations::qdrant::Qdrant::try_from_url(&qdrant_url)
Expand Down Expand Up @@ -184,7 +184,7 @@ async fn test_named_vectors() {
.then_chunk(transformers::ChunkCode::try_for_language("rust").unwrap())
.then(transformers::MetadataQACode::new(openai_client.clone()))
.filter_cached(integrations::redis::Redis::try_from_url(&redis_url, "prefix").unwrap())
.then_in_batch(10, transformers::Embed::new(openai_client.clone()))
.then_in_batch(transformers::Embed::new(openai_client.clone()).with_batch_size(10))
.then_store_with(
integrations::qdrant::Qdrant::try_from_url(&qdrant_url)
.unwrap()
Expand Down
2 changes: 1 addition & 1 deletion swiftide/tests/lancedb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ async fn test_lancedb() {
.insert("filter".to_string(), "true".to_string());
Ok(node)
})
.then_in_batch(20, transformers::Embed::new(fastembed.clone()))
.then_in_batch(transformers::Embed::new(fastembed.clone()).with_batch_size(20))
.log_nodes()
.then_store_with(lancedb.clone())
.run()
Expand Down
Loading

0 comments on commit 850ac89

Please sign in to comment.