Skip to content

Commit

Permalink
feat: configurable concurrency for transformers and chunkers (#47)
Browse files Browse the repository at this point in the history
  • Loading branch information
timonv authored Jun 14, 2024
1 parent cd055f1 commit fa74939
Show file tree
Hide file tree
Showing 8 changed files with 93 additions and 10 deletions.
18 changes: 12 additions & 6 deletions swiftide/src/ingestion/ingestion_pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ impl IngestionPipeline {
/// An instance of `IngestionPipeline` with the updated stream that applies the transformer to each node.
pub fn then(mut self, transformer: impl Transformer + 'static) -> Self {
let transformer = Arc::new(transformer);
let concurrency = transformer.concurrency().unwrap_or(self.concurrency);
self.stream = self
.stream
.map_ok(move |node| {
Expand All @@ -118,7 +119,7 @@ impl IngestionPipeline {
)
.map_err(anyhow::Error::from)
})
.try_buffer_unordered(self.concurrency)
.try_buffer_unordered(concurrency)
.map(|x| x.and_then(|x| x))
.boxed();

Expand All @@ -141,6 +142,7 @@ impl IngestionPipeline {
transformer: impl BatchableTransformer + 'static,
) -> Self {
let transformer = Arc::new(transformer);
let concurrency = transformer.concurrency().unwrap_or(self.concurrency);
self.stream = self
.stream
.try_chunks(batch_size)
Expand All @@ -154,8 +156,8 @@ impl IngestionPipeline {
.map_err(anyhow::Error::from)
})
.err_into::<anyhow::Error>()
.try_buffer_unordered(self.concurrency)
.try_flatten_unordered(self.concurrency)
.try_buffer_unordered(concurrency)
.try_flatten_unordered(concurrency)
.boxed();
self
}
Expand All @@ -171,6 +173,7 @@ impl IngestionPipeline {
/// An instance of `IngestionPipeline` with the updated stream that applies the chunker transformer to each node.
pub fn then_chunk(mut self, chunker: impl ChunkerTransformer + 'static) -> Self {
let chunker = Arc::new(chunker);
let concurrency = chunker.concurrency().unwrap_or(self.concurrency);
self.stream = self
.stream
.map_ok(move |node| {
Expand All @@ -181,8 +184,8 @@ impl IngestionPipeline {
)
.map_err(anyhow::Error::from)
})
.try_buffer_unordered(self.concurrency)
.try_flatten_unordered(self.concurrency)
.try_buffer_unordered(concurrency)
.try_flatten_unordered(concurrency)
.boxed();

self
Expand Down Expand Up @@ -215,7 +218,7 @@ impl IngestionPipeline {
})
.err_into::<anyhow::Error>()
.try_buffer_unordered(self.concurrency)
.try_flatten()
.try_flatten_unordered(self.concurrency)
.boxed();
} else {
self.stream = self
Expand Down Expand Up @@ -306,12 +309,14 @@ mod tests {
node.chunk = "transformed".to_string();
Ok(node)
});
transformer.expect_concurrency().returning(|| None);

batch_transformer
.expect_batch_transform()
.times(1)
.in_sequence(&mut seq)
.returning(|nodes| Box::pin(stream::iter(nodes.into_iter().map(Ok))));
batch_transformer.expect_concurrency().returning(|| None);

chunker
.expect_transform_node()
Expand All @@ -326,6 +331,7 @@ mod tests {
}
Box::pin(stream::iter(nodes))
});
chunker.expect_concurrency().returning(|| None);

storage.expect_setup().returning(|| Ok(()));
storage.expect_batch_size().returning(|| None);
Expand Down
18 changes: 18 additions & 0 deletions swiftide/src/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,29 @@ use mockall::{automock, predicate::*};
/// Transforms single nodes into single nodes.
///
/// Implementors must be `Send + Sync + Debug`.
pub trait Transformer: Send + Sync + Debug {
/// Transforms one `IngestionNode`, returning the resulting node or an error.
async fn transform_node(&self, node: IngestionNode) -> Result<IngestionNode>;

/// Overrides the default concurrency of the pipeline for this transformer.
///
/// The default implementation returns `None`, in which case
/// `IngestionPipeline::then` falls back to the pipeline's own concurrency
/// setting. Returning `Some(n)` makes the pipeline use `n` for this step.
fn concurrency(&self) -> Option<usize> {
None
}
}

#[cfg_attr(test, automock)]
#[async_trait]
/// Transforms batched single nodes into streams of nodes.
///
/// Implementors must be `Send + Sync + Debug`. A mock is generated for tests
/// via `automock`.
pub trait BatchableTransformer: Send + Sync + Debug {
/// Defines the batch size for the transformer.
///
/// The default implementation returns `None`.
fn batch_size(&self) -> Option<usize> {
None
}

/// Transforms a batch of nodes into a stream of nodes.
async fn batch_transform(&self, nodes: Vec<IngestionNode>) -> IngestionStream;

/// Overrides the default concurrency of the pipeline for this transformer.
///
/// The default implementation returns `None`, in which case
/// `IngestionPipeline::then_batch` falls back to the pipeline's own
/// concurrency setting. Returning `Some(n)` makes the pipeline use `n` both
/// for buffering the batch futures and for flattening the resulting streams.
fn concurrency(&self) -> Option<usize> {
None
}
}

/// Starting point of a stream
Expand All @@ -36,6 +49,11 @@ pub trait Loader {
/// Turns one node into many nodes.
///
/// Implementors must be `Send + Sync + Debug`.
pub trait ChunkerTransformer: Send + Sync + Debug {
/// Splits a single `IngestionNode` into a stream of nodes.
async fn transform_node(&self, node: IngestionNode) -> IngestionStream;

/// Overrides the default concurrency of the pipeline for this chunker.
///
/// The default implementation returns `None`, in which case
/// `IngestionPipeline::then_chunk` falls back to the pipeline's own
/// concurrency setting. Returning `Some(n)` makes the pipeline use `n` both
/// for buffering the chunking futures and for flattening the resulting streams.
fn concurrency(&self) -> Option<usize> {
None
}
}

#[cfg_attr(test, automock)]
Expand Down
12 changes: 12 additions & 0 deletions swiftide/src/transformers/chunk_code.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use crate::{
#[derive(Debug)]
pub struct ChunkCode {
chunker: CodeSplitter,
concurrency: Option<usize>,
}

impl ChunkCode {
Expand All @@ -30,6 +31,7 @@ impl ChunkCode {
pub fn try_for_language(lang: impl TryInto<SupportedLanguages>) -> Result<Self> {
Ok(Self {
chunker: CodeSplitter::builder().try_language(lang)?.build()?,
concurrency: None,
})
}

Expand All @@ -54,8 +56,14 @@ impl ChunkCode {
.chunk_size(chunk_size)
.build()
.expect("Failed to build code splitter"),
concurrency: None,
})
}

/// Sets a concurrency override for this chunker and returns the updated value.
///
/// The stored value is reported through `ChunkerTransformer::concurrency`,
/// which the ingestion pipeline prefers over its own default.
///
/// This method consumes `self`; the returned value must be used, otherwise
/// the setting is lost.
///
/// NOTE(review): `concurrency` is stored unchanged and `0` is not rejected
/// here. Confirm the pipeline's `try_buffer_unordered` step tolerates zero.
#[must_use]
pub fn with_concurrency(mut self, concurrency: usize) -> Self {
    self.concurrency = Some(concurrency);
    self
}
}

#[async_trait]
Expand Down Expand Up @@ -87,4 +95,8 @@ impl ChunkerTransformer for ChunkCode {
return stream::iter(vec![Err(split_result.unwrap_err())]).boxed();
}
}

/// Returns the concurrency override stored on this chunker, or `None` when
/// no override was configured.
fn concurrency(&self) -> Option<usize> {
self.concurrency
}
}
16 changes: 15 additions & 1 deletion swiftide/src/transformers/chunk_markdown.rs
Original file line number Diff line number Diff line change
@@ -1,25 +1,35 @@
use crate::{ingestion::IngestionNode, ingestion::IngestionStream, ChunkerTransformer};
use async_trait::async_trait;
use derive_builder::Builder;
use futures_util::{stream, StreamExt};
use text_splitter::{Characters, MarkdownSplitter};

#[derive(Debug)]
#[derive(Debug, Builder)]
#[builder(pattern = "owned")]
/// Chunker transformer backed by a `text_splitter` markdown splitter.
///
/// A `ChunkMarkdownBuilder` with owned-pattern setters is generated by
/// `derive_builder`.
pub struct ChunkMarkdown {
/// The markdown splitter this chunker is built around.
chunker: MarkdownSplitter<Characters>,
/// Optional concurrency override reported through
/// `ChunkerTransformer::concurrency`; the builder defaults it to `None`.
#[builder(default)]
concurrency: Option<usize>,
}

impl ChunkMarkdown {
    /// Creates a chunker whose splitter is built with
    /// `MarkdownSplitter::new(max_characters)`.
    ///
    /// No concurrency override is set.
    pub fn with_max_characters(max_characters: usize) -> Self {
        Self {
            chunker: MarkdownSplitter::new(max_characters),
            concurrency: None,
        }
    }

    /// Creates a chunker whose splitter is built with
    /// `MarkdownSplitter::new(range)`.
    ///
    /// No concurrency override is set.
    pub fn with_chunk_range(range: std::ops::Range<usize>) -> Self {
        Self {
            chunker: MarkdownSplitter::new(range),
            concurrency: None,
        }
    }

    /// Returns a default `ChunkMarkdownBuilder` for configuring every field
    /// explicitly.
    pub fn builder() -> ChunkMarkdownBuilder {
        ChunkMarkdownBuilder::default()
    }

    /// Sets a concurrency override for this chunker and returns the updated value.
    ///
    /// Mirrors the `with_concurrency` method on the other transformers so an
    /// override can be combined with `with_max_characters` / `with_chunk_range`
    /// without going through the builder. The stored value is reported through
    /// `ChunkerTransformer::concurrency`.
    ///
    /// NOTE(review): `concurrency` is stored unchanged and `0` is not rejected
    /// here. Confirm the pipeline's `try_buffer_unordered` step tolerates zero.
    #[must_use]
    pub fn with_concurrency(mut self, concurrency: usize) -> Self {
        self.concurrency = Some(concurrency);
        self
    }
}

#[async_trait]
Expand All @@ -40,4 +50,8 @@ impl ChunkerTransformer for ChunkMarkdown {
}))
.boxed()
}

/// Returns the concurrency override stored on this chunker, or `None` when
/// no override was configured.
fn concurrency(&self) -> Option<usize> {
self.concurrency
}
}
11 changes: 11 additions & 0 deletions swiftide/src/transformers/metadata_qa_code.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ pub struct MetadataQACode {
client: Arc<dyn SimplePrompt>,
prompt: String,
num_questions: usize,
concurrency: Option<usize>,
}

impl MetadataQACode {
Expand All @@ -30,8 +31,14 @@ impl MetadataQACode {
client: Arc::new(client),
prompt: default_prompt(),
num_questions: 5,
concurrency: None,
}
}

/// Sets a concurrency override for this transformer and returns the updated value.
///
/// The stored value is reported through `Transformer::concurrency`, which
/// the ingestion pipeline prefers over its own default.
///
/// This method consumes `self`; the returned value must be used, otherwise
/// the setting is lost.
///
/// NOTE(review): `concurrency` is stored unchanged and `0` is not rejected
/// here. Confirm the pipeline's `try_buffer_unordered` step tolerates zero.
#[must_use]
pub fn with_concurrency(mut self, concurrency: usize) -> Self {
    self.concurrency = Some(concurrency);
    self
}
}

/// Returns the default prompt template for generating questions and answers.
Expand Down Expand Up @@ -110,4 +117,8 @@ impl Transformer for MetadataQACode {

Ok(node)
}

/// Returns the concurrency override stored on this transformer, or `None`
/// when no override was configured.
fn concurrency(&self) -> Option<usize> {
self.concurrency
}
}
11 changes: 11 additions & 0 deletions swiftide/src/transformers/metadata_qa_text.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ pub struct MetadataQAText {
client: Arc<dyn SimplePrompt>,
prompt: String,
num_questions: usize,
concurrency: Option<usize>,
}

impl MetadataQAText {
Expand All @@ -35,8 +36,14 @@ impl MetadataQAText {
client: Arc::new(client),
prompt: default_prompt(),
num_questions: 5,
concurrency: None,
}
}

/// Sets a concurrency override for this transformer and returns the updated value.
///
/// The stored value is reported through `Transformer::concurrency`, which
/// the ingestion pipeline prefers over its own default.
///
/// This method consumes `self`; the returned value must be used, otherwise
/// the setting is lost.
///
/// NOTE(review): `concurrency` is stored unchanged and `0` is not rejected
/// here. Confirm the pipeline's `try_buffer_unordered` step tolerates zero.
#[must_use]
pub fn with_concurrency(mut self, concurrency: usize) -> Self {
    self.concurrency = Some(concurrency);
    self
}
}

/// Generates the default prompt template for generating questions and answers.
Expand Down Expand Up @@ -111,4 +118,8 @@ impl Transformer for MetadataQAText {

Ok(node)
}

/// Returns the concurrency override stored on this transformer, or `None`
/// when no override was configured.
fn concurrency(&self) -> Option<usize> {
self.concurrency
}
}
11 changes: 11 additions & 0 deletions swiftide/src/transformers/openai_embed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use futures_util::{stream, StreamExt};
#[derive(Debug)]
pub struct OpenAIEmbed {
client: Arc<OpenAI>,
concurrency: Option<usize>,
}

impl OpenAIEmbed {
Expand All @@ -31,8 +32,14 @@ impl OpenAIEmbed {
pub fn new(client: OpenAI) -> Self {
Self {
client: Arc::new(client),
concurrency: None,
}
}

/// Sets a concurrency override for this embedder and returns the updated value.
///
/// The stored value is reported through `BatchableTransformer::concurrency`,
/// which the ingestion pipeline prefers over its own default.
///
/// This method consumes `self`; the returned value must be used, otherwise
/// the setting is lost.
///
/// NOTE(review): `concurrency` is stored unchanged and `0` is not rejected
/// here. Confirm the pipeline's `try_buffer_unordered` step tolerates zero.
#[must_use]
pub fn with_concurrency(mut self, concurrency: usize) -> Self {
    self.concurrency = Some(concurrency);
    self
}
}

#[async_trait]
Expand Down Expand Up @@ -73,4 +80,8 @@ impl BatchableTransformer for OpenAIEmbed {
)
.boxed()
}

/// Returns the concurrency override stored on this embedder, or `None` when
/// no override was configured.
fn concurrency(&self) -> Option<usize> {
self.concurrency
}
}
6 changes: 3 additions & 3 deletions swiftide/tests/ingestion_pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,9 @@ async fn test_ingestion_pipeline() {
// );
let qdrant_url = "http://localhost:6334";

dbg!(&qdrant_url);
// dbg!(qdrant.stderr_to_vec().await.map(String::from_utf8).unwrap());
// dbg!(qdrant.stdout_to_vec().await.map(String::from_utf8).unwrap());
// Cleanup the collection before running the pipeline
let qdrant = QdrantClient::from_url(qdrant_url).build().unwrap();
let _ = qdrant.delete_collection("swiftide-test").await;

let result =
IngestionPipeline::from_loader(FileLoader::new(tempdir.path()).with_extensions(&["rs"]))
Expand Down

0 comments on commit fa74939

Please sign in to comment.