diff --git a/swiftide-core/src/indexing_traits.rs b/swiftide-core/src/indexing_traits.rs index 32954f0c..37f934f2 100644 --- a/swiftide-core/src/indexing_traits.rs +++ b/swiftide-core/src/indexing_traits.rs @@ -31,6 +31,26 @@ pub trait Transformer: Send + Sync { } } +#[async_trait] +impl Transformer for Box { + async fn transform_node(&self, node: Node) -> Result { + self.as_ref().transform_node(node).await + } + fn concurrency(&self) -> Option { + self.as_ref().concurrency() + } +} + +#[async_trait] +impl Transformer for &dyn Transformer { + async fn transform_node(&self, node: Node) -> Result { + (*self).transform_node(node).await + } + fn concurrency(&self) -> Option { + (*self).concurrency() + } +} + #[async_trait] /// Use a closure as a transformer impl Transformer for F @@ -66,10 +86,65 @@ where } } +#[async_trait] +impl BatchableTransformer for Box { + async fn batch_transform(&self, nodes: Vec) -> IndexingStream { + self.as_ref().batch_transform(nodes).await + } + fn concurrency(&self) -> Option { + self.as_ref().concurrency() + } +} + +#[async_trait] +impl BatchableTransformer for &dyn BatchableTransformer { + async fn batch_transform(&self, nodes: Vec) -> IndexingStream { + (*self).batch_transform(nodes).await + } + fn concurrency(&self) -> Option { + (*self).concurrency() + } +} + /// Starting point of a stream #[cfg_attr(feature = "test-utils", automock, doc(hidden))] pub trait Loader { fn into_stream(self) -> IndexingStream; + + /// Intended for use with Box + /// + /// Only needed if you use trait objects (Box) + /// + /// # Example + /// + /// ```ignore + /// fn into_stream_boxed(self: Box) -> IndexingStream { + /// self.into_stream() + /// } + /// ``` + fn into_stream_boxed(self: Box) -> IndexingStream { + unimplemented!("Please implement into_stream_boxed for your loader, it needs to be implemented on the concrete type") + } +} + +impl Loader for Box { + fn into_stream(self) -> IndexingStream { + Loader::into_stream_boxed(self) + } + + fn into_stream_boxed(self: Box) -> IndexingStream { + Loader::into_stream(*self) + } +} + +impl Loader for &dyn Loader { + fn into_stream(self) -> IndexingStream { + Loader::into_stream_boxed(Box::new(self)) + } + + fn into_stream_boxed(self: Box) -> IndexingStream { + Loader::into_stream(*self) + } } #[cfg_attr(feature = "test-utils", automock, doc(hidden))] @@ -84,6 +159,26 @@ pub trait ChunkerTransformer: Send + Sync + Debug { } } +#[async_trait] +impl ChunkerTransformer for Box { + async fn transform_node(&self, node: Node) -> IndexingStream { + self.as_ref().transform_node(node).await + } + fn concurrency(&self) -> Option { + self.as_ref().concurrency() + } +} + +#[async_trait] +impl ChunkerTransformer for &dyn ChunkerTransformer { + async fn transform_node(&self, node: Node) -> IndexingStream { + (*self).transform_node(node).await + } + fn concurrency(&self) -> Option { + (*self).concurrency() + } +} + #[cfg_attr(feature = "test-utils", automock)] #[async_trait] /// Caches nodes, typically by their path and hash @@ -95,6 +190,26 @@ pub trait NodeCache: Send + Sync + Debug { async fn set(&self, node: &Node); } +#[async_trait] +impl NodeCache for Box { + async fn get(&self, node: &Node) -> bool { + self.as_ref().get(node).await + } + async fn set(&self, node: &Node) { + self.as_ref().set(node).await; + } +} + +#[async_trait] +impl NodeCache for &dyn NodeCache { + async fn get(&self, node: &Node) -> bool { + (*self).get(node).await + } + async fn set(&self, node: &Node) { + (*self).set(node).await; + } +} + #[cfg_attr(feature = "test-utils", automock)] #[async_trait] /// Embeds a list of strings and returns its embeddings. @@ -103,6 +218,20 @@ pub trait EmbeddingModel: Send + Sync + Debug { async fn embed(&self, input: Vec) -> Result; } +#[async_trait] +impl EmbeddingModel for Box { + async fn embed(&self, input: Vec) -> Result { + self.as_ref().embed(input).await + } +} + +#[async_trait] +impl EmbeddingModel for &dyn EmbeddingModel { + async fn embed(&self, input: Vec) -> Result { + (*self).embed(input).await + } +} + #[cfg_attr(feature = "test-utils", automock)] #[async_trait] /// Embeds a list of strings and returns its embeddings. @@ -111,6 +240,20 @@ pub trait SparseEmbeddingModel: Send + Sync + Debug { async fn sparse_embed(&self, input: Vec) -> Result; } +#[async_trait] +impl SparseEmbeddingModel for Box { + async fn sparse_embed(&self, input: Vec) -> Result { + self.as_ref().sparse_embed(input).await + } +} + +#[async_trait] +impl SparseEmbeddingModel for &dyn SparseEmbeddingModel { + async fn sparse_embed(&self, input: Vec) -> Result { + (*self).sparse_embed(input).await + } +} + #[cfg_attr(feature = "test-utils", automock)] #[async_trait] /// Given a string prompt, queries an LLM @@ -119,6 +262,20 @@ pub trait SimplePrompt: Debug + Send + Sync { async fn prompt(&self, prompt: Prompt) -> Result; } +#[async_trait] +impl SimplePrompt for Box { + async fn prompt(&self, prompt: Prompt) -> Result { + self.as_ref().prompt(prompt).await + } +} + +#[async_trait] +impl SimplePrompt for &dyn SimplePrompt { + async fn prompt(&self, prompt: Prompt) -> Result { + (*self).prompt(prompt).await + } +} + #[cfg_attr(feature = "test-utils", automock)] #[async_trait] /// Persists nodes @@ -131,6 +288,38 @@ pub trait Persist: Debug + Send + Sync { } } +#[async_trait] +impl Persist for Box { + async fn setup(&self) -> Result<()> { + self.as_ref().setup().await + } + async fn store(&self, node: Node) -> Result { + self.as_ref().store(node).await + } + async fn batch_store(&self, nodes: Vec) -> IndexingStream { + self.as_ref().batch_store(nodes).await + } + fn batch_size(&self) -> Option { + self.as_ref().batch_size() + } +} + +#[async_trait] +impl Persist for &dyn Persist { + async fn setup(&self) -> Result<()> { + (*self).setup().await + } + async fn store(&self, node: Node) -> Result { + (*self).store(node).await + } + async fn batch_store(&self, nodes: Vec) -> IndexingStream { + (*self).batch_store(nodes).await + } + fn batch_size(&self) -> Option { + (*self).batch_size() + } +} + /// Allows for passing defaults from the pipeline to the transformer /// Required for batch transformers as at least a marker, implementation is not required pub trait WithIndexingDefaults { @@ -144,7 +333,17 @@ pub trait WithBatchIndexingDefaults { } impl WithIndexingDefaults for dyn Transformer {} +impl WithIndexingDefaults for Box { + fn with_indexing_defaults(&mut self, indexing_defaults: IndexingDefaults) { + self.as_mut().with_indexing_defaults(indexing_defaults); + } +} impl WithBatchIndexingDefaults for dyn BatchableTransformer {} +impl WithBatchIndexingDefaults for Box { + fn with_indexing_defaults(&mut self, indexing_defaults: IndexingDefaults) { + self.as_mut().with_indexing_defaults(indexing_defaults); + } +} impl WithIndexingDefaults for F where F: Fn(Node) -> Result {} impl WithBatchIndexingDefaults for F where F: Fn(Vec) -> IndexingStream {} diff --git a/swiftide-core/src/query_traits.rs b/swiftide-core/src/query_traits.rs index e447308c..f650c1b5 100644 --- a/swiftide-core/src/query_traits.rs +++ b/swiftide-core/src/query_traits.rs @@ -11,7 +11,7 @@ use crate::{ /// Can transform queries before retrieval #[async_trait] -pub trait TransformQuery: Send + Sync + ToOwned { +pub trait TransformQuery: Send + Sync { async fn transform_query( &self, query: Query, @@ -21,7 +21,7 @@ pub trait TransformQuery: Send + Sync + ToOwned { #[async_trait] impl TransformQuery for F where - F: Fn(Query) -> Result> + Send + Sync + ToOwned, + F: Fn(Query) -> Result> + Send + Sync, { async fn transform_query( &self, @@ -31,12 +31,22 @@ where } } +#[async_trait] +impl TransformQuery for Box { + async fn transform_query( + &self, + query: Query, + ) -> Result> { + self.as_ref().transform_query(query).await + } +} + /// A search strategy for the query pipeline pub trait SearchStrategy: Clone + Send + Sync + Default {} /// Can retrieve documents given a SearchStrategy #[async_trait] -pub trait Retrieve: Send + Sync + ToOwned { +pub trait Retrieve: Send + Sync { async fn retrieve( &self, search_strategy: &S, @@ -44,11 +54,22 @@ pub trait Retrieve: Send + Sync + ToOwned { ) -> Result>; } +#[async_trait] +impl Retrieve for Box> { + async fn retrieve( + &self, + search_strategy: &S, + query: Query, + ) -> Result> { + self.as_ref().retrieve(search_strategy, query).await + } +} + #[async_trait] impl Retrieve for F where S: SearchStrategy, - F: Fn(&S, Query) -> Result> + Send + Sync + ToOwned, + F: Fn(&S, Query) -> Result> + Send + Sync, { async fn retrieve( &self, @@ -61,7 +82,7 @@ where /// Can transform a response after retrieval #[async_trait] -pub trait TransformResponse: Send + Sync + ToOwned { +pub trait TransformResponse: Send + Sync { async fn transform_response(&self, query: Query) -> Result>; } @@ -69,29 +90,43 @@ pub trait TransformResponse: Send + Sync + ToOwned { #[async_trait] impl TransformResponse for F where - F: Fn(Query) -> Result> + Send + Sync + ToOwned, + F: Fn(Query) -> Result> + Send + Sync, { async fn transform_response(&self, query: Query) -> Result> { (self)(query) } } +#[async_trait] +impl TransformResponse for Box { + async fn transform_response(&self, query: Query) -> Result> { + self.as_ref().transform_response(query).await + } +} + /// Can answer the original query #[async_trait] -pub trait Answer: Send + Sync + ToOwned { +pub trait Answer: Send + Sync { async fn answer(&self, query: Query) -> Result>; } #[async_trait] impl Answer for F where - F: Fn(Query) -> Result> + Send + Sync + ToOwned, + F: Fn(Query) -> Result> + Send + Sync, { async fn answer(&self, query: Query) -> Result> { (self)(query) } } +#[async_trait] +impl Answer for Box { + async fn answer(&self, query: Query) -> Result> { + self.as_ref().answer(query).await + } +} + /// Evaluates a query /// /// An evaluator needs to be able to respond to each step in the query pipeline @@ -99,3 +134,10 @@ where pub trait EvaluateQuery: Send + Sync { async fn evaluate(&self, evaluation: QueryEvaluation) -> Result<()>; } + +#[async_trait] +impl EvaluateQuery for Box { + async fn evaluate(&self, evaluation: QueryEvaluation) -> Result<()> { + self.as_ref().evaluate(evaluation).await + } +} diff --git a/swiftide-indexing/src/loaders/file_loader.rs b/swiftide-indexing/src/loaders/file_loader.rs index d3b3b462..21dfead9 100644 --- a/swiftide-indexing/src/loaders/file_loader.rs +++ b/swiftide-indexing/src/loaders/file_loader.rs @@ -112,6 +112,10 @@ impl Loader for FileLoader { IndexingStream::iter(files) } + + fn into_stream_boxed(self: Box) -> IndexingStream { + self.into_stream() + } } #[cfg(test)]