diff --git a/examples/fastembed.rs b/examples/fastembed.rs index df9c55c4..3c06dab1 100644 --- a/examples/fastembed.rs +++ b/examples/fastembed.rs @@ -27,7 +27,7 @@ async fn main() -> Result<(), Box> { .to_owned(); indexing::Pipeline::from_loader(FileLoader::new(".").with_extensions(&["rs"])) - .then_in_batch(10, Embed::new(FastEmbed::builder().batch_size(10).build()?)) + .then_in_batch(Embed::new(FastEmbed::builder().batch_size(10).build()?)) .then_store_with( Qdrant::try_from_url(qdrant_url)? .batch_size(50) diff --git a/examples/fluvio.rs b/examples/fluvio.rs index a1518ff0..36b9578a 100644 --- a/examples/fluvio.rs +++ b/examples/fluvio.rs @@ -48,7 +48,7 @@ async fn main() -> Result<(), Box> { .unwrap(); indexing::Pipeline::from_loader(loader) - .then_in_batch(10, Embed::new(FastEmbed::try_default().unwrap())) + .then_in_batch(Embed::new(FastEmbed::try_default().unwrap()).with_batch_size(10)) .then_store_with( Qdrant::builder() .batch_size(50) diff --git a/examples/hybrid_search.rs b/examples/hybrid_search.rs index b5d96ded..f7c513df 100644 --- a/examples/hybrid_search.rs +++ b/examples/hybrid_search.rs @@ -23,14 +23,8 @@ async fn main() -> Result<(), Box> { // Ensure all batching is consistent let batch_size = 64; - let fastembed_sparse = FastEmbed::try_default_sparse() - .unwrap() - .with_batch_size(batch_size) - .to_owned(); - let fastembed = FastEmbed::try_default() - .unwrap() - .with_batch_size(batch_size) - .to_owned(); + let fastembed_sparse = FastEmbed::try_default_sparse().unwrap().to_owned(); + let fastembed = FastEmbed::try_default().unwrap().to_owned(); // Set up openai with the mini model, which is great for indexing let openai = openai::OpenAI::builder() @@ -57,10 +51,9 @@ async fn main() -> Result<(), Box> { // Generate metadata on the code chunks to increase our chances of finding the right code .then(MetadataQACode::from_client(openai.clone()).build().unwrap()) .then_in_batch( - batch_size, - 
transformers::SparseEmbed::new(fastembed_sparse.clone()), + transformers::SparseEmbed::new(fastembed_sparse.clone()).with_batch_size(batch_size), ) - .then_in_batch(batch_size, transformers::Embed::new(fastembed.clone())) + .then_in_batch(transformers::Embed::new(fastembed.clone()).with_batch_size(batch_size)) .then_store_with(qdrant.clone()) .run() .await?; diff --git a/examples/index_codebase.rs b/examples/index_codebase.rs index 0fd7871e..6287c1cd 100644 --- a/examples/index_codebase.rs +++ b/examples/index_codebase.rs @@ -47,7 +47,7 @@ async fn main() -> Result<(), Box> { "rust", 10..2048, )?) - .then_in_batch(10, Embed::new(openai_client.clone())) + .then_in_batch(Embed::new(openai_client.clone()).with_batch_size(10)) .then_store_with( Qdrant::builder() .batch_size(50) diff --git a/examples/index_codebase_reduced_context.rs b/examples/index_codebase_reduced_context.rs index 05ceb1af..5125b883 100644 --- a/examples/index_codebase_reduced_context.rs +++ b/examples/index_codebase_reduced_context.rs @@ -61,7 +61,7 @@ async fn main() -> Result<(), Box> { .then(indexing::transformers::CompressCodeOutline::new( openai_client.clone(), )) - .then_in_batch(10, Embed::new(openai_client.clone())) + .then_in_batch(Embed::new(openai_client.clone()).with_batch_size(10)) .then_store_with( Qdrant::builder() .batch_size(50) diff --git a/examples/index_groq.rs b/examples/index_groq.rs index 95c8adad..279da2fc 100644 --- a/examples/index_groq.rs +++ b/examples/index_groq.rs @@ -35,7 +35,7 @@ async fn main() -> Result<(), Box> { indexing::Pipeline::from_loader(FileLoader::new("README.md")) .then_chunk(ChunkMarkdown::from_chunk_range(10..2048)) .then(MetadataQAText::new(groq_client.clone())) - .then_in_batch(10, Embed::new(fastembed)) + .then_in_batch(Embed::new(fastembed).with_batch_size(10)) .then_store_with(memory_store.clone()) .run() .await?; diff --git a/examples/index_markdown_lots_of_metadata.rs b/examples/index_markdown_lots_of_metadata.rs index d90ef0f9..ab1624d1 100644 
--- a/examples/index_markdown_lots_of_metadata.rs +++ b/examples/index_markdown_lots_of_metadata.rs @@ -40,7 +40,7 @@ async fn main() -> Result<(), Box> { .then(MetadataSummary::new(openai_client.clone())) .then(MetadataTitle::new(openai_client.clone())) .then(MetadataKeywords::new(openai_client.clone())) - .then_in_batch(10, Embed::new(openai_client.clone())) + .then_in_batch(Embed::new(openai_client.clone()).with_batch_size(10)) .log_all() .filter_errors() .then_store_with( diff --git a/examples/index_ollama.rs b/examples/index_ollama.rs index 62bf3d8d..3e0f68e7 100644 --- a/examples/index_ollama.rs +++ b/examples/index_ollama.rs @@ -35,7 +35,7 @@ async fn main() -> Result<(), Box> { indexing::Pipeline::from_loader(FileLoader::new("README.md")) .then_chunk(ChunkMarkdown::from_chunk_range(10..2048)) .then(MetadataQAText::new(ollama_client.clone())) - .then_in_batch(10, Embed::new(fastembed)) + .then_in_batch(Embed::new(fastembed).with_batch_size(10)) .then_store_with(memory_store.clone()) .run() .await?; diff --git a/examples/lancedb.rs b/examples/lancedb.rs index 026a7404..b7d5cd97 100644 --- a/examples/lancedb.rs +++ b/examples/lancedb.rs @@ -40,7 +40,7 @@ async fn main() -> Result<(), Box> { indexing::Pipeline::from_loader(FileLoader::new("README.md")) .then_chunk(ChunkMarkdown::from_chunk_range(10..2048)) .then(MetadataQAText::new(openai_client.clone())) - .then_in_batch(10, Embed::new(openai_client.clone())) + .then_in_batch(Embed::new(openai_client.clone()).with_batch_size(10)) .then_store_with(lancedb.clone()) .run() .await?; diff --git a/examples/query_pipeline.rs b/examples/query_pipeline.rs index cf268ae8..cc5aabda 100644 --- a/examples/query_pipeline.rs +++ b/examples/query_pipeline.rs @@ -26,7 +26,7 @@ async fn main() -> Result<(), Box> { indexing::Pipeline::from_loader(FileLoader::new("README.md")) .then_chunk(ChunkMarkdown::from_chunk_range(10..2048)) .then(MetadataQAText::new(openai_client.clone())) - .then_in_batch(10, Embed::new(openai_client.clone())) + 
.then_in_batch(Embed::new(openai_client.clone()).with_batch_size(10)) .then_store_with(qdrant.clone()) .run() .await?; diff --git a/examples/store_multiple_vectors.rs b/examples/store_multiple_vectors.rs index a9a00167..3d7eacb7 100644 --- a/examples/store_multiple_vectors.rs +++ b/examples/store_multiple_vectors.rs @@ -46,7 +46,7 @@ async fn main() -> Result<(), Box> { .then(MetadataSummary::new(openai_client.clone())) .then(MetadataTitle::new(openai_client.clone())) .then(MetadataKeywords::new(openai_client.clone())) - .then_in_batch(10, Embed::new(openai_client.clone())) + .then_in_batch(Embed::new(openai_client.clone()).with_batch_size(10)) .log_all() .filter_errors() .then_store_with( diff --git a/swiftide-core/src/indexing_traits.rs b/swiftide-core/src/indexing_traits.rs index 9b41a908..a43c4215 100644 --- a/swiftide-core/src/indexing_traits.rs +++ b/swiftide-core/src/indexing_traits.rs @@ -83,6 +83,11 @@ pub trait BatchableTransformer: Send + Sync { let name = std::any::type_name::(); name.split("::").last().unwrap_or(name) } + + /// Overrides the default batch size of the pipeline + fn batch_size(&self) -> Option { + None + } } #[async_trait] diff --git a/swiftide-indexing/src/pipeline.rs b/swiftide-indexing/src/pipeline.rs index 3945a808..d1a9988e 100644 --- a/swiftide-indexing/src/pipeline.rs +++ b/swiftide-indexing/src/pipeline.rs @@ -11,6 +11,9 @@ use std::{sync::Arc, time::Duration}; use swiftide_core::indexing::{EmbedMode, IndexingStream, Node}; +/// The default batch size for batch processing. +const DEFAULT_BATCH_SIZE: usize = 256; + /// A pipeline for indexing files, adding metadata, chunking, transforming, embedding, and then storing them. /// /// The `Pipeline` struct orchestrates the entire file indexing process. 
It is designed to be flexible and @@ -27,6 +30,7 @@ pub struct Pipeline { storage: Vec>, concurrency: usize, indexing_defaults: IndexingDefaults, + batch_size: usize, } impl Default for Pipeline { @@ -37,6 +41,7 @@ impl Default for Pipeline { storage: Vec::default(), concurrency: num_cpus::get(), indexing_defaults: IndexingDefaults::default(), + batch_size: DEFAULT_BATCH_SIZE, } } } @@ -203,11 +208,10 @@ impl Pipeline { /// Adds a batch transformer to the pipeline. /// - /// Closures can also be provided as batch transformers. + /// If the transformer has a batch size set, the batch size from the transformer is used, otherwise the pipeline default batch size ([`DEFAULT_BATCH_SIZE`]). /// /// # Arguments /// - /// * `batch_size` - The size of the batches to be processed. /// * `transformer` - A transformer that implements the `BatchableTransformer` trait. /// /// # Returns @@ -216,7 +220,6 @@ impl Pipeline { #[must_use] pub fn then_in_batch( mut self, - batch_size: usize, mut transformer: impl BatchableTransformer + WithBatchIndexingDefaults + 'static, ) -> Self { let concurrency = transformer.concurrency().unwrap_or(self.concurrency); @@ -226,7 +229,7 @@ impl Pipeline { let transformer = Arc::new(transformer); self.stream = self .stream - .try_chunks(batch_size) + .try_chunks(transformer.batch_size().unwrap_or(self.batch_size)) .map_ok(move |nodes| { let transformer = Arc::clone(&transformer); let span = tracing::trace_span!("then_in_batch", nodes = ?nodes ); @@ -406,6 +409,7 @@ impl Pipeline { storage: self.storage.clone(), concurrency: self.concurrency, indexing_defaults: self.indexing_defaults.clone(), + batch_size: self.batch_size, }; let right_pipeline = Self { @@ -413,6 +417,7 @@ impl Pipeline { storage: self.storage.clone(), concurrency: self.concurrency, indexing_defaults: self.indexing_defaults.clone(), + batch_size: self.batch_size, }; (left_pipeline, right_pipeline) @@ -606,6 +611,7 @@ mod tests { .returning(|nodes| 
IndexingStream::iter(nodes.into_iter().map(Ok))); batch_transformer.expect_concurrency().returning(|| None); batch_transformer.expect_name().returning(|| "transformer"); + batch_transformer.expect_batch_size().returning(|| None); chunker .expect_transform_node() @@ -635,7 +641,7 @@ mod tests { let pipeline = Pipeline::from_loader(loader) .then(transformer) - .then_in_batch(1, batch_transformer) + .then_in_batch(batch_transformer) .then_chunk(chunker) .then_store_with(storage); @@ -750,7 +756,7 @@ mod tests { .returning(|| vec![Ok(Node::default())].into()); let pipeline = Pipeline::from_loader(loader) - .then_in_batch(10, batch_transformer) + .then_in_batch(batch_transformer) .then_store_with(storage.clone()); pipeline.run().await.unwrap(); @@ -884,10 +890,7 @@ mod tests { let pipeline = Pipeline::from_loader(Box::new(loader) as Box) .then(Box::new(transformer) as Box) - .then_in_batch( - 1, - Box::new(batch_transformer) as Box, - ) + .then_in_batch(Box::new(batch_transformer) as Box) .then_chunk(Box::new(chunker) as Box) .then_store_with(Box::new(storage) as Box); pipeline.run().await.unwrap(); diff --git a/swiftide-indexing/src/transformers/embed.rs b/swiftide-indexing/src/transformers/embed.rs index c97c96a9..506aa6b4 100644 --- a/swiftide-indexing/src/transformers/embed.rs +++ b/swiftide-indexing/src/transformers/embed.rs @@ -14,12 +14,14 @@ use swiftide_core::{ pub struct Embed { embed_model: Arc, concurrency: Option, + batch_size: Option, } impl std::fmt::Debug for Embed { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Embed") .field("concurrency", &self.concurrency) + .field("batch_size", &self.batch_size) .finish() } } @@ -38,6 +40,7 @@ impl Embed { Self { embed_model: Arc::new(model), concurrency: None, + batch_size: None, } } @@ -46,6 +49,21 @@ impl Embed { self.concurrency = Some(concurrency); self } + + /// Sets the batch size for the transformer. 
+ /// If the batch size is not set, the transformer will use the default batch size set by the pipeline + /// # Parameters + /// + /// * `batch_size` - The batch size to use for the transformer. + /// + /// # Returns + /// + /// A new instance of `Embed`. + #[must_use] + pub fn with_batch_size(mut self, batch_size: usize) -> Self { + self.batch_size = Some(batch_size); + self + } } impl WithBatchIndexingDefaults for Embed {} diff --git a/swiftide-indexing/src/transformers/sparse_embed.rs b/swiftide-indexing/src/transformers/sparse_embed.rs index 59aa37de..cbe7256e 100644 --- a/swiftide-indexing/src/transformers/sparse_embed.rs +++ b/swiftide-indexing/src/transformers/sparse_embed.rs @@ -14,6 +14,7 @@ use swiftide_core::{ pub struct SparseEmbed { embed_model: Arc, concurrency: Option, + batch_size: Option, } impl std::fmt::Debug for SparseEmbed { @@ -38,6 +39,7 @@ impl SparseEmbed { Self { embed_model: Arc::new(model), concurrency: None, + batch_size: None, } } @@ -46,6 +48,21 @@ impl SparseEmbed { self.concurrency = Some(concurrency); self } + + /// Sets the batch size for the transformer. + /// If the batch size is not set, the transformer will use the default batch size set by the pipeline + /// # Parameters + /// + /// * `batch_size` - The batch size to use for the transformer. + /// + /// # Returns + /// + /// A new instance of `SparseEmbed`. 
+ #[must_use] + pub fn with_batch_size(mut self, batch_size: usize) -> Self { + self.batch_size = Some(batch_size); + self + } } impl WithBatchIndexingDefaults for SparseEmbed {} @@ -114,6 +131,10 @@ impl BatchableTransformer for SparseEmbed { fn concurrency(&self) -> Option { self.concurrency } + + fn batch_size(&self) -> Option { + self.batch_size + } } #[cfg(test)] diff --git a/swiftide-integrations/src/fastembed/mod.rs b/swiftide-integrations/src/fastembed/mod.rs index 8200227a..66190582 100644 --- a/swiftide-integrations/src/fastembed/mod.rs +++ b/swiftide-integrations/src/fastembed/mod.rs @@ -100,11 +100,6 @@ impl FastEmbed { .build() } - pub fn with_batch_size(&mut self, batch_size: usize) -> &mut Self { - self.batch_size = Some(batch_size); - self - } - pub fn builder() -> FastEmbedBuilder { FastEmbedBuilder::default() } diff --git a/swiftide/src/lib.rs b/swiftide/src/lib.rs index 92c983c1..d7a397e3 100644 --- a/swiftide/src/lib.rs +++ b/swiftide/src/lib.rs @@ -41,7 +41,7 @@ //! Pipeline::from_loader(FileLoader::new(".").with_extensions(&["md"])) //! .then_chunk(ChunkMarkdown::from_chunk_range(10..512)) //! .then(MetadataQAText::new(openai_client.clone())) -//! .then_in_batch(10, Embed::new(openai_client.clone())) +//! .then_in_batch(Embed::new(openai_client.clone()).with_batch_size(10)) //! .then_store_with( //! Qdrant::try_from_url(qdrant_url)? //! 
.batch_size(50) @@ -165,7 +165,7 @@ pub mod indexing { /// indexing::Pipeline::from_loader(FileLoader::new("README.md")) /// .then_chunk(ChunkMarkdown::from_chunk_range(10..2048)) /// .then(MetadataQAText::new(openai_client.clone())) -/// .then_in_batch(10, Embed::new(openai_client.clone())) +/// .then_in_batch(Embed::new(openai_client.clone()).with_batch_size(10)) /// .then_store_with(qdrant.clone()) /// .run() /// .await?; diff --git a/swiftide/tests/indexing_pipeline.rs b/swiftide/tests/indexing_pipeline.rs index 1898f110..204fe457 100644 --- a/swiftide/tests/indexing_pipeline.rs +++ b/swiftide/tests/indexing_pipeline.rs @@ -54,7 +54,7 @@ async fn test_indexing_pipeline() { .then_chunk(transformers::ChunkCode::try_for_language("rust").unwrap()) .then(transformers::MetadataQACode::default()) .filter_cached(integrations::redis::Redis::try_from_url(&redis_url, "prefix").unwrap()) - .then_in_batch(1, transformers::Embed::new(openai_client.clone())) + .then_in_batch(transformers::Embed::new(openai_client.clone()).with_batch_size(1)) .log_nodes() .then_store_with( integrations::qdrant::Qdrant::try_from_url(&qdrant_url) @@ -184,7 +184,7 @@ async fn test_named_vectors() { .then_chunk(transformers::ChunkCode::try_for_language("rust").unwrap()) .then(transformers::MetadataQACode::new(openai_client.clone())) .filter_cached(integrations::redis::Redis::try_from_url(&redis_url, "prefix").unwrap()) - .then_in_batch(10, transformers::Embed::new(openai_client.clone())) + .then_in_batch(transformers::Embed::new(openai_client.clone()).with_batch_size(10)) .then_store_with( integrations::qdrant::Qdrant::try_from_url(&qdrant_url) .unwrap() diff --git a/swiftide/tests/lancedb.rs b/swiftide/tests/lancedb.rs index c91c4b17..5464b5fd 100644 --- a/swiftide/tests/lancedb.rs +++ b/swiftide/tests/lancedb.rs @@ -52,7 +52,7 @@ async fn test_lancedb() { .insert("filter".to_string(), "true".to_string()); Ok(node) }) - .then_in_batch(20, transformers::Embed::new(fastembed.clone())) + 
.then_in_batch(transformers::Embed::new(fastembed.clone()).with_batch_size(20)) .log_nodes() .then_store_with(lancedb.clone()) .run() diff --git a/swiftide/tests/query_pipeline.rs b/swiftide/tests/query_pipeline.rs index 139ddb01..46be3d2d 100644 --- a/swiftide/tests/query_pipeline.rs +++ b/swiftide/tests/query_pipeline.rs @@ -38,7 +38,7 @@ async fn test_query_pipeline() { loaders::FileLoader::new(tempdir.path()).with_extensions(&["rs"]), ) .then_chunk(transformers::ChunkCode::try_for_language("rust").unwrap()) - .then_in_batch(1, transformers::Embed::new(fastembed.clone())) + .then_in_batch(transformers::Embed::new(fastembed.clone()).with_batch_size(1)) .then_store_with(qdrant_client.clone()) .run() .await @@ -89,14 +89,8 @@ async fn test_hybrid_search_qdrant() { .build() .unwrap(); - let fastembed_sparse = FastEmbed::try_default_sparse() - .unwrap() - .with_batch_size(batch_size) - .to_owned(); - let fastembed = FastEmbed::try_default() - .unwrap() - .with_batch_size(batch_size) - .to_owned(); + let fastembed_sparse = FastEmbed::try_default_sparse().unwrap().to_owned(); + let fastembed = FastEmbed::try_default().unwrap().to_owned(); println!("Qdrant URL: {qdrant_url}"); @@ -104,10 +98,9 @@ async fn test_hybrid_search_qdrant() { loaders::FileLoader::new(tempdir.path()).with_extensions(&["rs"]), ) .then_chunk(transformers::ChunkCode::try_for_language("rust").unwrap()) - .then_in_batch(batch_size, transformers::Embed::new(fastembed.clone())) + .then_in_batch(transformers::Embed::new(fastembed.clone()).with_batch_size(batch_size)) .then_in_batch( - batch_size, - transformers::SparseEmbed::new(fastembed_sparse.clone()), + transformers::SparseEmbed::new(fastembed_sparse.clone()).with_batch_size(batch_size), ) .then_store_with(qdrant_client.clone()) .run() diff --git a/swiftide/tests/sparse_embeddings_and_hybrid_search.rs b/swiftide/tests/sparse_embeddings_and_hybrid_search.rs index d7324801..fd789fa9 100644 --- a/swiftide/tests/sparse_embeddings_and_hybrid_search.rs 
+++ b/swiftide/tests/sparse_embeddings_and_hybrid_search.rs @@ -48,8 +48,8 @@ async fn test_sparse_indexing_pipeline() { let result = Pipeline::from_loader(loaders::FileLoader::new(tempdir.path()).with_extensions(&["rs"])) .then_chunk(transformers::ChunkCode::try_for_language("rust").unwrap()) - .then_in_batch(20, transformers::SparseEmbed::new(fastembed_sparse)) - .then_in_batch(20, transformers::Embed::new(fastembed)) + .then_in_batch(transformers::SparseEmbed::new(fastembed_sparse).with_batch_size(20)) + .then_in_batch(transformers::Embed::new(fastembed).with_batch_size(20)) .log_nodes() .then_store_with( integrations::qdrant::Qdrant::try_from_url(&qdrant_url)