feat(indexing)!: Removed duplication of batch_size. Pipeline owns the default ba… #336

Merged: 7 commits, Sep 26, 2024
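
This PR is a breaking change to `then_in_batch`: the `batch_size` argument is removed, the `Pipeline` now owns a single default batch size, and a batchable transformer may override it per step via `with_batch_size`. A minimal before/after sketch of the migration, using only calls that appear in this diff (`pipeline` stands in for a `Pipeline` under construction):

// Before: every call site passed the batch size to the pipeline step
pipeline.then_in_batch(10, Embed::new(openai_client.clone()))

// After: the transformer optionally carries its own batch size;
// omit with_batch_size to fall back to the pipeline default (256)
pipeline.then_in_batch(Embed::new(openai_client.clone()).with_batch_size(10))
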
2 changes: 1 addition & 1 deletion examples/fastembed.rs
@@ -27,7 +27,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.to_owned();

indexing::Pipeline::from_loader(FileLoader::new(".").with_extensions(&["rs"]))
- .then_in_batch(10, Embed::new(FastEmbed::builder().batch_size(10).build()?))
+ .then_in_batch(Embed::new(FastEmbed::builder().batch_size(10).build()?))
.then_store_with(
Qdrant::try_from_url(qdrant_url)?
.batch_size(50)
2 changes: 1 addition & 1 deletion examples/fluvio.rs
@@ -48,7 +48,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.unwrap();

indexing::Pipeline::from_loader(loader)
- .then_in_batch(10, Embed::new(FastEmbed::try_default().unwrap()))
+ .then_in_batch(Embed::new(FastEmbed::try_default().unwrap()).with_batch_size(10))
.then_store_with(
Qdrant::builder()
.batch_size(50)
15 changes: 4 additions & 11 deletions examples/hybrid_search.rs
@@ -23,14 +23,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
// Ensure all batching is consistent
let batch_size = 64;

- let fastembed_sparse = FastEmbed::try_default_sparse()
-     .unwrap()
-     .with_batch_size(batch_size)
-     .to_owned();
- let fastembed = FastEmbed::try_default()
-     .unwrap()
-     .with_batch_size(batch_size)
-     .to_owned();
+ let fastembed_sparse = FastEmbed::try_default_sparse().unwrap().to_owned();
+ let fastembed = FastEmbed::try_default().unwrap().to_owned();

// Set up openai with the mini model, which is great for indexing
let openai = openai::OpenAI::builder()
@@ -57,10 +57,9 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
// Generate metadata on the code chunks to increase our chances of finding the right code
.then(MetadataQACode::from_client(openai.clone()).build().unwrap())
.then_in_batch(
-     batch_size,
-     transformers::SparseEmbed::new(fastembed_sparse.clone()),
+     transformers::SparseEmbed::new(fastembed_sparse.clone()).with_batch_size(batch_size),
)
- .then_in_batch(batch_size, transformers::Embed::new(fastembed.clone()))
+ .then_in_batch(transformers::Embed::new(fastembed.clone()).with_batch_size(batch_size))
.then_store_with(qdrant.clone())
.run()
.await?;
2 changes: 1 addition & 1 deletion examples/index_codebase.rs
@@ -47,7 +47,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
"rust",
10..2048,
)?)
- .then_in_batch(10, Embed::new(openai_client.clone()))
+ .then_in_batch(Embed::new(openai_client.clone()).with_batch_size(10))
.then_store_with(
Qdrant::builder()
.batch_size(50)
2 changes: 1 addition & 1 deletion examples/index_codebase_reduced_context.rs
@@ -61,7 +61,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.then(indexing::transformers::CompressCodeOutline::new(
openai_client.clone(),
))
- .then_in_batch(10, Embed::new(openai_client.clone()))
+ .then_in_batch(Embed::new(openai_client.clone()).with_batch_size(10))
.then_store_with(
Qdrant::builder()
.batch_size(50)
2 changes: 1 addition & 1 deletion examples/index_groq.rs
@@ -35,7 +35,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
indexing::Pipeline::from_loader(FileLoader::new("README.md"))
.then_chunk(ChunkMarkdown::from_chunk_range(10..2048))
.then(MetadataQAText::new(groq_client.clone()))
- .then_in_batch(10, Embed::new(fastembed))
+ .then_in_batch(Embed::new(fastembed).with_batch_size(10))
.then_store_with(memory_store.clone())
.run()
.await?;
2 changes: 1 addition & 1 deletion examples/index_markdown_lots_of_metadata.rs
@@ -40,7 +40,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.then(MetadataSummary::new(openai_client.clone()))
.then(MetadataTitle::new(openai_client.clone()))
.then(MetadataKeywords::new(openai_client.clone()))
- .then_in_batch(10, Embed::new(openai_client.clone()))
+ .then_in_batch(Embed::new(openai_client.clone()))
.log_all()
.filter_errors()
.then_store_with(
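
Note that this example drops the explicit batch size entirely instead of calling `with_batch_size(10)`: the embed step now falls back to the pipeline-owned default (`DEFAULT_BATCH_SIZE = 256`, introduced in `swiftide-indexing/src/pipeline.rs` below). The two call styles used across this diff:

// Explicit per-transformer override: batches of 10
.then_in_batch(Embed::new(openai_client.clone()).with_batch_size(10))

// No override: the pipeline default of 256 applies
.then_in_batch(Embed::new(openai_client.clone()))
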
2 changes: 1 addition & 1 deletion examples/index_ollama.rs
@@ -35,7 +35,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
indexing::Pipeline::from_loader(FileLoader::new("README.md"))
.then_chunk(ChunkMarkdown::from_chunk_range(10..2048))
.then(MetadataQAText::new(ollama_client.clone()))
- .then_in_batch(10, Embed::new(fastembed))
+ .then_in_batch(Embed::new(fastembed).with_batch_size(10))
.then_store_with(memory_store.clone())
.run()
.await?;
2 changes: 1 addition & 1 deletion examples/lancedb.rs
@@ -40,7 +40,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
indexing::Pipeline::from_loader(FileLoader::new("README.md"))
.then_chunk(ChunkMarkdown::from_chunk_range(10..2048))
.then(MetadataQAText::new(openai_client.clone()))
- .then_in_batch(10, Embed::new(openai_client.clone()))
+ .then_in_batch(Embed::new(openai_client.clone()).with_batch_size(10))
.then_store_with(lancedb.clone())
.run()
.await?;
2 changes: 1 addition & 1 deletion examples/query_pipeline.rs
@@ -26,7 +26,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
indexing::Pipeline::from_loader(FileLoader::new("README.md"))
.then_chunk(ChunkMarkdown::from_chunk_range(10..2048))
.then(MetadataQAText::new(openai_client.clone()))
- .then_in_batch(10, Embed::new(openai_client.clone()))
+ .then_in_batch(Embed::new(openai_client.clone()).with_batch_size(10))
.then_store_with(qdrant.clone())
.run()
.await?;
2 changes: 1 addition & 1 deletion examples/store_multiple_vectors.rs
@@ -46,7 +46,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.then(MetadataSummary::new(openai_client.clone()))
.then(MetadataTitle::new(openai_client.clone()))
.then(MetadataKeywords::new(openai_client.clone()))
- .then_in_batch(10, Embed::new(openai_client.clone()))
+ .then_in_batch(Embed::new(openai_client.clone()).with_batch_size(10))
.log_all()
.filter_errors()
.then_store_with(
5 changes: 5 additions & 0 deletions swiftide-core/src/indexing_traits.rs
@@ -83,6 +83,11 @@ pub trait BatchableTransformer: Send + Sync {
let name = std::any::type_name::<Self>();
name.split("::").last().unwrap_or(name)
}

+ /// Overrides the default batch size of the pipeline
+ fn batch_size(&self) -> Option<usize> {
+     None
+ }
}

#[async_trait]
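
Because the new `batch_size` hook has a default body returning `None`, existing `BatchableTransformer` implementations keep compiling unchanged and simply inherit the pipeline default; a transformer opts in by returning `Some(n)`. A self-contained sketch of the pattern in plain Rust (a simplified stand-in, not the actual swiftide trait):

// A defaulted trait method lets implementors opt in to a per-transformer
// batch size without breaking existing implementations.
trait Batchable {
    fn batch_size(&self) -> Option<usize> {
        None // inherit the pipeline default
    }
}

struct UsesDefault;
impl Batchable for UsesDefault {}

struct OverridesBatchSize;
impl Batchable for OverridesBatchSize {
    fn batch_size(&self) -> Option<usize> {
        Some(10)
    }
}

// Mirrors the resolution order in pipeline.rs:
// transformer.batch_size().unwrap_or(self.batch_size)
fn resolve(t: &dyn Batchable, pipeline_default: usize) -> usize {
    t.batch_size().unwrap_or(pipeline_default)
}

fn main() {
    assert_eq!(resolve(&UsesDefault, 256), 256);
    assert_eq!(resolve(&OverridesBatchSize, 256), 10);
}
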
21 changes: 12 additions & 9 deletions swiftide-indexing/src/pipeline.rs
@@ -11,6 +11,9 @@ use std::{sync::Arc, time::Duration};

use swiftide_core::indexing::{EmbedMode, IndexingStream, Node};

+ /// The default batch size for batch processing.
+ const DEFAULT_BATCH_SIZE: usize = 256;
+
/// A pipeline for indexing files, adding metadata, chunking, transforming, embedding, and then storing them.
///
/// The `Pipeline` struct orchestrates the entire file indexing process. It is designed to be flexible and
@@ -27,6 +30,7 @@ pub struct Pipeline {
storage: Vec<Arc<dyn Persist>>,
concurrency: usize,
indexing_defaults: IndexingDefaults,
+ batch_size: usize,
}

impl Default for Pipeline {
@@ -37,6 +41,7 @@ impl Default for Pipeline {
storage: Vec::default(),
concurrency: num_cpus::get(),
indexing_defaults: IndexingDefaults::default(),
+ batch_size: DEFAULT_BATCH_SIZE,
}
}
}
@@ -207,7 +212,6 @@ impl Pipeline {
///
/// # Arguments
///
- /// * `batch_size` - The size of the batches to be processed.
/// * `transformer` - A transformer that implements the `BatchableTransformer` trait.
///
/// # Returns
@@ -216,7 +220,6 @@
#[must_use]
pub fn then_in_batch(
mut self,
- batch_size: usize,
mut transformer: impl BatchableTransformer + WithBatchIndexingDefaults + 'static,
) -> Self {
let concurrency = transformer.concurrency().unwrap_or(self.concurrency);
@@ -226,7 +229,7 @@
let transformer = Arc::new(transformer);
self.stream = self
.stream
- .try_chunks(batch_size)
+ .try_chunks(transformer.batch_size().unwrap_or(self.batch_size))
.map_ok(move |nodes| {
let transformer = Arc::clone(&transformer);
let span = tracing::trace_span!("then_in_batch", nodes = ?nodes );
@@ -406,13 +409,15 @@ impl Pipeline {
storage: self.storage.clone(),
concurrency: self.concurrency,
indexing_defaults: self.indexing_defaults.clone(),
+ batch_size: self.batch_size,
};

let right_pipeline = Self {
stream: right_rx.into(),
storage: self.storage.clone(),
concurrency: self.concurrency,
indexing_defaults: self.indexing_defaults.clone(),
+ batch_size: self.batch_size,
};

(left_pipeline, right_pipeline)
@@ -606,6 +611,7 @@ mod tests {
.returning(|nodes| IndexingStream::iter(nodes.into_iter().map(Ok)));
batch_transformer.expect_concurrency().returning(|| None);
batch_transformer.expect_name().returning(|| "transformer");
+ batch_transformer.expect_batch_size().returning(|| None);

chunker
.expect_transform_node()
@@ -635,7 +641,7 @@

let pipeline = Pipeline::from_loader(loader)
.then(transformer)
- .then_in_batch(1, batch_transformer)
+ .then_in_batch(batch_transformer)
.then_chunk(chunker)
.then_store_with(storage);

@@ -750,7 +756,7 @@
.returning(|| vec![Ok(Node::default())].into());

let pipeline = Pipeline::from_loader(loader)
- .then_in_batch(10, batch_transformer)
+ .then_in_batch(batch_transformer)
.then_store_with(storage.clone());
pipeline.run().await.unwrap();

@@ -884,10 +890,7 @@

let pipeline = Pipeline::from_loader(Box::new(loader) as Box<dyn Loader>)
.then(Box::new(transformer) as Box<dyn Transformer>)
- .then_in_batch(
-     1,
-     Box::new(batch_transformer) as Box<dyn BatchableTransformer>,
- )
+ .then_in_batch(Box::new(batch_transformer) as Box<dyn BatchableTransformer>)
.then_chunk(Box::new(chunker) as Box<dyn ChunkerTransformer>)
.then_store_with(Box::new(storage) as Box<dyn Persist>);
pipeline.run().await.unwrap();
18 changes: 18 additions & 0 deletions swiftide-indexing/src/transformers/embed.rs
@@ -14,12 +14,14 @@ use swiftide_core::{
pub struct Embed {
embed_model: Arc<dyn EmbeddingModel>,
concurrency: Option<usize>,
+ batch_size: Option<usize>,
}

impl std::fmt::Debug for Embed {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Embed")
.field("concurrency", &self.concurrency)
.field("batch_size", &self.batch_size)
.finish()
}
}
@@ -38,6 +40,7 @@ impl Embed {
Self {
embed_model: Arc::new(model),
concurrency: None,
+ batch_size: None,
}
}

@@ -46,6 +49,21 @@
self.concurrency = Some(concurrency);
self
}

+ /// Sets the batch size for the transformer.
+ /// If the batch size is not set, the transformer will use the default batch size set by the pipeline.
+ /// # Parameters
+ ///
+ /// * `batch_size` - The batch size to use for the transformer.
+ ///
+ /// # Returns
+ ///
+ /// A new instance of `Embed`.
+ #[must_use]
+ pub fn with_batch_size(mut self, batch_size: usize) -> Self {
+     self.batch_size = Some(batch_size);
+     self
+ }
}

impl WithBatchIndexingDefaults for Embed {}
21 changes: 21 additions & 0 deletions swiftide-indexing/src/transformers/sparse_embed.rs
@@ -14,6 +14,7 @@ use swiftide_core::{
pub struct SparseEmbed {
embed_model: Arc<dyn SparseEmbeddingModel>,
concurrency: Option<usize>,
+ batch_size: Option<usize>,
}

impl std::fmt::Debug for SparseEmbed {
@@ -38,6 +39,7 @@ impl SparseEmbed {
Self {
embed_model: Arc::new(model),
concurrency: None,
+ batch_size: None,
}
}

@@ -46,6 +48,21 @@
self.concurrency = Some(concurrency);
self
}

+ /// Sets the batch size for the transformer.
+ /// If the batch size is not set, the transformer will use the default batch size set by the pipeline.
+ /// # Parameters
+ ///
+ /// * `batch_size` - The batch size to use for the transformer.
+ ///
+ /// # Returns
+ ///
+ /// A new instance of `SparseEmbed`.
+ #[must_use]
+ pub fn with_batch_size(mut self, batch_size: usize) -> Self {
+     self.batch_size = Some(batch_size);
+     self
+ }
}

impl WithBatchIndexingDefaults for SparseEmbed {}
@@ -114,6 +131,10 @@ impl BatchableTransformer for SparseEmbed {
fn concurrency(&self) -> Option<usize> {
self.concurrency
}

+ fn batch_size(&self) -> Option<usize> {
+     self.batch_size
+ }
}

#[cfg(test)]
5 changes: 0 additions & 5 deletions swiftide-integrations/src/fastembed/mod.rs
@@ -95,11 +95,6 @@ impl FastEmbed {
.build()
}

- pub fn with_batch_size(&mut self, batch_size: usize) -> &mut Self {
-     self.batch_size = Some(batch_size);
-     self
- }

pub fn builder() -> FastEmbedBuilder {
FastEmbedBuilder::default()
}
4 changes: 2 additions & 2 deletions swiftide/src/lib.rs
@@ -41,7 +41,7 @@
//! Pipeline::from_loader(FileLoader::new(".").with_extensions(&["md"]))
//! .then_chunk(ChunkMarkdown::from_chunk_range(10..512))
//! .then(MetadataQAText::new(openai_client.clone()))
- //!     .then_in_batch(10, Embed::new(openai_client.clone()))
+ //!     .then_in_batch(Embed::new(openai_client.clone()).with_batch_size(10))
//! .then_store_with(
//! Qdrant::try_from_url(qdrant_url)?
//! .batch_size(50)
@@ -165,7 +165,7 @@ pub mod indexing {
/// indexing::Pipeline::from_loader(FileLoader::new("README.md"))
/// .then_chunk(ChunkMarkdown::from_chunk_range(10..2048))
/// .then(MetadataQAText::new(openai_client.clone()))
- ///     .then_in_batch(10, Embed::new(openai_client.clone()))
+ ///     .then_in_batch(Embed::new(openai_client.clone()).with_batch_size(10))
/// .then_store_with(qdrant.clone())
/// .run()
/// .await?;
4 changes: 2 additions & 2 deletions swiftide/tests/indexing_pipeline.rs
@@ -54,7 +54,7 @@ async fn test_indexing_pipeline() {
.then_chunk(transformers::ChunkCode::try_for_language("rust").unwrap())
.then(transformers::MetadataQACode::default())
.filter_cached(integrations::redis::Redis::try_from_url(&redis_url, "prefix").unwrap())
- .then_in_batch(1, transformers::Embed::new(openai_client.clone()))
+ .then_in_batch(transformers::Embed::new(openai_client.clone()).with_batch_size(1))
.log_nodes()
.then_store_with(
integrations::qdrant::Qdrant::try_from_url(&qdrant_url)
@@ -184,7 +184,7 @@ async fn test_named_vectors() {
.then_chunk(transformers::ChunkCode::try_for_language("rust").unwrap())
.then(transformers::MetadataQACode::new(openai_client.clone()))
.filter_cached(integrations::redis::Redis::try_from_url(&redis_url, "prefix").unwrap())
- .then_in_batch(10, transformers::Embed::new(openai_client.clone()))
+ .then_in_batch(transformers::Embed::new(openai_client.clone()).with_batch_size(10))
.then_store_with(
integrations::qdrant::Qdrant::try_from_url(&qdrant_url)
.unwrap()
2 changes: 1 addition & 1 deletion swiftide/tests/lancedb.rs
@@ -52,7 +52,7 @@ async fn test_lancedb() {
.insert("filter".to_string(), "true".to_string());
Ok(node)
})
- .then_in_batch(20, transformers::Embed::new(fastembed.clone()))
+ .then_in_batch(transformers::Embed::new(fastembed.clone()).with_batch_size(20))
.log_nodes()
.then_store_with(lancedb.clone())
.run()