diff --git a/swiftide/src/ingestion/ingestion_pipeline.rs b/swiftide/src/ingestion/ingestion_pipeline.rs index dc98ad27..7e1639c3 100644 --- a/swiftide/src/ingestion/ingestion_pipeline.rs +++ b/swiftide/src/ingestion/ingestion_pipeline.rs @@ -108,6 +108,7 @@ impl IngestionPipeline { /// An instance of `IngestionPipeline` with the updated stream that applies the transformer to each node. pub fn then(mut self, transformer: impl Transformer + 'static) -> Self { let transformer = Arc::new(transformer); + let concurrency = transformer.concurrency().unwrap_or(self.concurrency); self.stream = self .stream .map_ok(move |node| { @@ -118,7 +119,7 @@ impl IngestionPipeline { ) .map_err(anyhow::Error::from) }) - .try_buffer_unordered(self.concurrency) + .try_buffer_unordered(concurrency) .map(|x| x.and_then(|x| x)) .boxed(); @@ -141,6 +142,7 @@ impl IngestionPipeline { transformer: impl BatchableTransformer + 'static, ) -> Self { let transformer = Arc::new(transformer); + let concurrency = transformer.concurrency().unwrap_or(self.concurrency); self.stream = self .stream .try_chunks(batch_size) @@ -154,8 +156,8 @@ impl IngestionPipeline { .map_err(anyhow::Error::from) }) .err_into::() - .try_buffer_unordered(self.concurrency) - .try_flatten_unordered(self.concurrency) + .try_buffer_unordered(concurrency) + .try_flatten_unordered(concurrency) .boxed(); self } @@ -171,6 +173,7 @@ impl IngestionPipeline { /// An instance of `IngestionPipeline` with the updated stream that applies the chunker transformer to each node. pub fn then_chunk(mut self, chunker: impl ChunkerTransformer + 'static) -> Self { let chunker = Arc::new(chunker); + let concurrency = chunker.concurrency().unwrap_or(self.concurrency); self.stream = self .stream .map_ok(move |node| { @@ -181,8 +184,8 @@ impl IngestionPipeline { ) .map_err(anyhow::Error::from) }) - .try_buffer_unordered(self.concurrency) - .try_flatten_unordered(self.concurrency) + .try_buffer_unordered(concurrency) + .try_flatten_unordered(concurrency) .boxed(); self @@ -215,7 +218,7 @@ impl IngestionPipeline { }) .err_into::() .try_buffer_unordered(self.concurrency) - .try_flatten() + .try_flatten_unordered(self.concurrency) .boxed(); } else { self.stream = self @@ -306,12 +309,14 @@ mod tests { node.chunk = "transformed".to_string(); Ok(node) }); + transformer.expect_concurrency().returning(|| None); batch_transformer .expect_batch_transform() .times(1) .in_sequence(&mut seq) .returning(|nodes| Box::pin(stream::iter(nodes.into_iter().map(Ok)))); + batch_transformer.expect_concurrency().returning(|| None); chunker .expect_transform_node() @@ -326,6 +331,7 @@ mod tests { } Box::pin(stream::iter(nodes)) }); + chunker.expect_concurrency().returning(|| None); storage.expect_setup().returning(|| Ok(())); storage.expect_batch_size().returning(|| None); diff --git a/swiftide/src/traits.rs b/swiftide/src/traits.rs index 4d076e84..50c2101d 100644 --- a/swiftide/src/traits.rs +++ b/swiftide/src/traits.rs @@ -13,16 +13,29 @@ use mockall::{automock, predicate::*}; /// Transforms single nodes into single nodes pub trait Transformer: Send + Sync + Debug { async fn transform_node(&self, node: IngestionNode) -> Result; + + /// Overrides the default concurrency of the pipeline + fn concurrency(&self) -> Option { + None + } } #[cfg_attr(test, automock)] #[async_trait] /// Transforms batched single nodes into streams of nodes pub trait BatchableTransformer: Send + Sync + Debug { + /// Defines the batch size for the transformer fn batch_size(&self) -> Option { None } + + /// Transforms a batch of nodes into a stream of nodes async fn batch_transform(&self, nodes: Vec) -> IngestionStream; + + /// Overrides the default concurrency of the pipeline + fn concurrency(&self) -> Option { + None + } } /// Starting point of a stream @@ -36,6 +49,11 @@ pub trait Loader { /// Turns one node into many nodes pub trait ChunkerTransformer: Send + Sync + Debug { async fn transform_node(&self, node: IngestionNode) -> IngestionStream; + + /// Overrides the default concurrency of the pipeline + fn concurrency(&self) -> Option { + None + } } #[cfg_attr(test, automock)] diff --git a/swiftide/src/transformers/chunk_code.rs b/swiftide/src/transformers/chunk_code.rs index 55ac48cc..bafbce2a 100644 --- a/swiftide/src/transformers/chunk_code.rs +++ b/swiftide/src/transformers/chunk_code.rs @@ -14,6 +14,7 @@ use crate::{ #[derive(Debug)] pub struct ChunkCode { chunker: CodeSplitter, + concurrency: Option, } impl ChunkCode { @@ -30,6 +31,7 @@ impl ChunkCode { pub fn try_for_language(lang: impl TryInto) -> Result { Ok(Self { chunker: CodeSplitter::builder().try_language(lang)?.build()?, + concurrency: None, }) } @@ -54,8 +56,14 @@ impl ChunkCode { .chunk_size(chunk_size) .build() .expect("Failed to build code splitter"), + concurrency: None, }) } + + pub fn with_concurrency(mut self, concurrency: usize) -> Self { + self.concurrency = Some(concurrency); + self + } } #[async_trait] @@ -87,4 +95,8 @@ impl ChunkerTransformer for ChunkCode { return stream::iter(vec![Err(split_result.unwrap_err())]).boxed(); } } + + fn concurrency(&self) -> Option { + self.concurrency + } } diff --git a/swiftide/src/transformers/chunk_markdown.rs b/swiftide/src/transformers/chunk_markdown.rs index f74fcdfc..3ada90af 100644 --- a/swiftide/src/transformers/chunk_markdown.rs +++ b/swiftide/src/transformers/chunk_markdown.rs @@ -1,25 +1,35 @@ use crate::{ingestion::IngestionNode, ingestion::IngestionStream, ChunkerTransformer}; use async_trait::async_trait; +use derive_builder::Builder; use futures_util::{stream, StreamExt}; use text_splitter::{Characters, MarkdownSplitter}; -#[derive(Debug)] +#[derive(Debug, Builder)] +#[builder(pattern = "owned")] pub struct ChunkMarkdown { chunker: MarkdownSplitter, + #[builder(default)] + concurrency: Option, } impl ChunkMarkdown { pub fn with_max_characters(max_characters: usize) -> Self { Self { chunker: MarkdownSplitter::new(max_characters), + concurrency: None, } } pub fn with_chunk_range(range: std::ops::Range) -> Self { Self { chunker: MarkdownSplitter::new(range), + concurrency: None, } } + + pub fn builder() -> ChunkMarkdownBuilder { + ChunkMarkdownBuilder::default() + } } #[async_trait] @@ -40,4 +50,8 @@ impl ChunkerTransformer for ChunkMarkdown { })) .boxed() } + + fn concurrency(&self) -> Option { + self.concurrency + } } diff --git a/swiftide/src/transformers/metadata_qa_code.rs b/swiftide/src/transformers/metadata_qa_code.rs index 492e1727..c73c1b15 100644 --- a/swiftide/src/transformers/metadata_qa_code.rs +++ b/swiftide/src/transformers/metadata_qa_code.rs @@ -13,6 +13,7 @@ pub struct MetadataQACode { client: Arc, prompt: String, num_questions: usize, + concurrency: Option, } impl MetadataQACode { @@ -30,8 +31,14 @@ impl MetadataQACode { client: Arc::new(client), prompt: default_prompt(), num_questions: 5, + concurrency: None, } } + + pub fn with_concurrency(mut self, concurrency: usize) -> Self { + self.concurrency = Some(concurrency); + self + } } /// Returns the default prompt template for generating questions and answers. @@ -110,4 +117,8 @@ impl Transformer for MetadataQACode { Ok(node) } + + fn concurrency(&self) -> Option { + self.concurrency + } } diff --git a/swiftide/src/transformers/metadata_qa_text.rs b/swiftide/src/transformers/metadata_qa_text.rs index 5fa70a72..b0872ded 100644 --- a/swiftide/src/transformers/metadata_qa_text.rs +++ b/swiftide/src/transformers/metadata_qa_text.rs @@ -18,6 +18,7 @@ pub struct MetadataQAText { client: Arc, prompt: String, num_questions: usize, + concurrency: Option, } impl MetadataQAText { @@ -35,8 +36,14 @@ impl MetadataQAText { client: Arc::new(client), prompt: default_prompt(), num_questions: 5, + concurrency: None, } } + + pub fn with_concurrency(mut self, concurrency: usize) -> Self { + self.concurrency = Some(concurrency); + self + } } /// Generates the default prompt template for generating questions and answers. @@ -111,4 +118,8 @@ impl Transformer for MetadataQAText { Ok(node) } + + fn concurrency(&self) -> Option { + self.concurrency + } } diff --git a/swiftide/src/transformers/openai_embed.rs b/swiftide/src/transformers/openai_embed.rs index ea7e8453..b0094195 100644 --- a/swiftide/src/transformers/openai_embed.rs +++ b/swiftide/src/transformers/openai_embed.rs @@ -16,6 +16,7 @@ use futures_util::{stream, StreamExt}; #[derive(Debug)] pub struct OpenAIEmbed { client: Arc, + concurrency: Option, } impl OpenAIEmbed { @@ -31,8 +32,14 @@ impl OpenAIEmbed { pub fn new(client: OpenAI) -> Self { Self { client: Arc::new(client), + concurrency: None, } } + + pub fn with_concurrency(mut self, concurrency: usize) -> Self { + self.concurrency = Some(concurrency); + self + } } #[async_trait] @@ -73,4 +80,8 @@ impl BatchableTransformer for OpenAIEmbed { ) .boxed() } + + fn concurrency(&self) -> Option { + self.concurrency + } } diff --git a/swiftide/tests/ingestion_pipeline.rs b/swiftide/tests/ingestion_pipeline.rs index 5775c2ad..b57d867b 100644 --- a/swiftide/tests/ingestion_pipeline.rs +++ b/swiftide/tests/ingestion_pipeline.rs @@ -126,9 +126,9 @@ async fn test_ingestion_pipeline() { // ); let qdrant_url = "http://localhost:6334"; - dbg!(&qdrant_url); - // dbg!(qdrant.stderr_to_vec().await.map(String::from_utf8).unwrap()); - // dbg!(qdrant.stdout_to_vec().await.map(String::from_utf8).unwrap()); + // Cleanup the collection before running the pipeline + let qdrant = QdrantClient::from_url(qdrant_url).build().unwrap(); + let _ = qdrant.delete_collection("swiftide-test").await; let result = IngestionPipeline::from_loader(FileLoader::new(tempdir.path()).with_extensions(&["rs"]))