From 9004323dc5b11a3556a47e11fb8912ffc49f1e9e Mon Sep 17 00:00:00 2001
From: Timon Vonk
Date: Sun, 23 Jun 2024 12:57:20 +0200
Subject: [PATCH] feat(integrations)!: implement Persist for Redis (#80)

---
 Cargo.lock                                    |   1 +
 examples/Cargo.toml                           |   5 +
 examples/ingest_codebase.rs                   |   7 +-
 examples/ingest_into_redis.rs                 |  51 +++++
 swiftide/src/ingestion/ingestion_node.rs      |   4 +-
 swiftide/src/integrations/redis/mod.rs        | 200 +++++++++++++++++-
 swiftide/src/integrations/redis/node_cache.rs | 132 +----------
 swiftide/src/integrations/redis/persist.rs    | 186 ++++++++++++++++
 swiftide/tests/ingestion_pipeline.rs          |   4 +-
 9 files changed, 451 insertions(+), 139 deletions(-)
 create mode 100644 examples/ingest_into_redis.rs
 create mode 100644 swiftide/src/integrations/redis/persist.rs

diff --git a/Cargo.lock b/Cargo.lock
index a9aeffb2..1cd0c77d 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -784,6 +784,7 @@ dependencies = [
 name = "examples"
 version = "0.0.0"
 dependencies = [
+ "serde_json",
  "swiftide",
  "tokio",
  "tracing-subscriber",
diff --git a/examples/Cargo.toml b/examples/Cargo.toml
index a9ec2e68..bb0e2feb 100644
--- a/examples/Cargo.toml
+++ b/examples/Cargo.toml
@@ -8,6 +8,7 @@ edition = "2021"
 tokio = { version = "1.0", features = ["full"] }
 swiftide = { path = "../swiftide/", features = ["all"] }
 tracing-subscriber = "0.3"
+serde_json = "1.0"
 
 [[example]]
 name = "ingest-codebase"
@@ -16,3 +17,7 @@ path = "ingest_codebase.rs"
 [[example]]
 name = "fastembed"
 path = "fastembed.rs"
+
+[[example]]
+name = "ingest-redis"
+path = "ingest_into_redis.rs"
diff --git a/examples/ingest_codebase.rs b/examples/ingest_codebase.rs
index 8a37fa19..0cf476c2 100644
--- a/examples/ingest_codebase.rs
+++ b/examples/ingest_codebase.rs
@@ -21,7 +21,7 @@ use swiftide::{
     ingestion,
-    integrations::{self, qdrant::Qdrant, redis::RedisNodeCache},
+    integrations::{self, qdrant::Qdrant, redis::Redis},
     loaders::FileLoader,
     transformers::{ChunkCode, Embed, MetadataQACode},
 };
@@ -46,10 +46,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
         .to_owned();
 
     ingestion::IngestionPipeline::from_loader(FileLoader::new(".").with_extensions(&["rs"]))
-        .filter_cached(RedisNodeCache::try_from_url(
-            redis_url,
-            "swiftide-examples",
-        )?)
+        .filter_cached(Redis::try_from_url(redis_url, "swiftide-examples")?)
         .then(MetadataQACode::new(openai_client.clone()))
         .then_chunk(ChunkCode::try_for_language_and_chunk_size(
             "rust",
diff --git a/examples/ingest_into_redis.rs b/examples/ingest_into_redis.rs
new file mode 100644
index 00000000..6dcd8a26
--- /dev/null
+++ b/examples/ingest_into_redis.rs
@@ -0,0 +1,51 @@
+//! # [Swiftide] Ingesting into Redis example
+//!
+//! This example demonstrates how to ingest the Swiftide codebase itself and persist the chunks
+//! into Redis. Note that for it to work correctly you need to have Redis running.
+//!
+//! The pipeline will:
+//! - Load all `.rs` files from the current directory
+//! - Chunk the code into pieces of 10 to 2048 bytes
+//! - Store the nodes in Redis, using the node's metadata serialized to JSON as the value
+//!
+//! Note that metadata is copied over to smaller chunks when chunking.
+//!
+//! [Swiftide]: https://github.com/bosun-ai/swiftide
+//! [examples]: https://github.com/bosun-ai/swiftide/blob/master/examples
+
+use swiftide::{
+    ingestion, integrations::redis::Redis, loaders::FileLoader, transformers::ChunkCode,
+};
+
+#[tokio::main]
+async fn main() -> Result<(), Box<dyn std::error::Error>> {
+    tracing_subscriber::fmt::init();
+
+    let redis_url = std::env::var("REDIS_URL")
+        .as_deref()
+        .unwrap_or("redis://localhost:6379")
+        .to_owned();
+
+    ingestion::IngestionPipeline::from_loader(FileLoader::new(".").with_extensions(&["rs"]))
+        .then_chunk(ChunkCode::try_for_language_and_chunk_size(
+            "rust",
+            10..2048,
+        )?)
+        .then_store_with(
+            // By default the value is the full node serialized to JSON.
+            // We can customize this by providing a custom function.
+            Redis::try_build_from_url(&redis_url)?
+                .persist_value_fn(|node| Ok(serde_json::to_string(&node.metadata)?))
+                .batch_size(50)
+                .build()?,
+        )
+        .run()
+        .await?;
+    Ok(())
+}
diff --git a/swiftide/src/ingestion/ingestion_node.rs b/swiftide/src/ingestion/ingestion_node.rs
index 69e3df26..7e4d9ee1 100644
--- a/swiftide/src/ingestion/ingestion_node.rs
+++ b/swiftide/src/ingestion/ingestion_node.rs
@@ -23,12 +23,14 @@ use std::{
     path::PathBuf,
 };
 
+use serde::{Deserialize, Serialize};
+
 /// Represents a unit of data in the ingestion process.
 ///
 /// `IngestionNode` encapsulates all necessary information for a single unit of data being processed
 /// in the ingestion pipeline. It includes fields for an identifier, file path, data chunk, optional
 /// vector representation, and metadata.
-#[derive(Debug, Default, Clone)]
+#[derive(Debug, Default, Clone, Serialize, Deserialize, PartialEq)]
 pub struct IngestionNode {
     /// Optional identifier for the node.
     pub id: Option<u64>,
diff --git a/swiftide/src/integrations/redis/mod.rs b/swiftide/src/integrations/redis/mod.rs
index 29ba808d..015e674a 100644
--- a/swiftide/src/integrations/redis/mod.rs
+++ b/swiftide/src/integrations/redis/mod.rs
@@ -1,12 +1,12 @@
 //! This module provides the integration with Redis for caching nodes in the Swiftide system.
 //!
-//! The primary component of this module is the `RedisNodeCache`, which is re-exported for use
-//! in other parts of the system. The `RedisNodeCache` struct is responsible for managing and
+//! The primary component of this module is the `Redis` struct, which is re-exported for use
+//! in other parts of the system. The `Redis` struct is responsible for managing and
 //! caching nodes during the ingestion process, leveraging Redis for efficient storage and retrieval.
 //!
 //! # Overview
 //!
-//! The `RedisNodeCache` struct provides methods for:
+//! The `Redis` struct provides methods for:
 //! - Connecting to a Redis database
 //! - Checking if a node is cached
 //! - Setting a node in the cache
@@ -14,6 +14,198 @@
 //!
 //! This integration is essential for ensuring efficient node management and caching in the Swiftide system.
 
+use anyhow::{Context as _, Result};
+use derive_builder::Builder;
+use tokio::sync::RwLock;
+
+use crate::ingestion::IngestionNode;
+
 mod node_cache;
+mod persist;
+
+/// `Redis` provides both caching and persistence for nodes using Redis.
+/// It helps optimize the ingestion process by skipping nodes that have already been processed,
+/// and it can store processed nodes for later retrieval.
+///
+/// # Fields
+///
+/// * `client` - The Redis client used to interact with the Redis server.
+/// * `connection_manager` - Manages the Redis connections asynchronously.
+/// * `cache_key_prefix` - A prefix used for cache keys stored in Redis to avoid collisions.
+#[derive(Builder)]
+#[builder(pattern = "owned", setter(strip_option))]
+pub struct Redis {
+    client: redis::Client,
+    #[builder(default, setter(skip))]
+    connection_manager: RwLock<Option<redis::aio::ConnectionManager>>,
+    #[builder(default)]
+    cache_key_prefix: String,
+    #[builder(default = "10")]
+    /// The batch size used for persisting nodes. Defaults to a safe 10.
+    batch_size: usize,
+    #[builder(default)]
+    /// Customize the key used for persisting nodes
+    persist_key_fn: Option<fn(&IngestionNode) -> Result<String>>,
+    #[builder(default)]
+    /// Customize the value used for persisting nodes
+    persist_value_fn: Option<fn(&IngestionNode) -> Result<String>>,
+}
+
+impl Redis {
+    /// Creates a new `Redis` instance from a given Redis URL and key prefix.
+    ///
+    /// # Parameters
+    ///
+    /// * `url` - The URL of the Redis server.
+    /// * `prefix` - The prefix to be used for keys stored in Redis.
+    ///
+    /// # Returns
+    ///
+    /// A `Result` containing the `Redis` instance or an error if the client could not be created.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if the Redis client cannot be opened.
+    pub fn try_from_url(url: impl AsRef<str>, prefix: impl AsRef<str>) -> Result<Self> {
+        let client = redis::Client::open(url.as_ref()).context("Failed to open redis client")?;
+        Ok(Self {
+            client,
+            connection_manager: RwLock::new(None),
+            cache_key_prefix: prefix.as_ref().to_string(),
+            batch_size: 10,
+            persist_key_fn: None,
+            persist_value_fn: None,
+        })
+    }
+
+    pub fn try_build_from_url(url: impl AsRef<str>) -> Result<RedisBuilder> {
+        Ok(RedisBuilder::default()
+            .client(redis::Client::open(url.as_ref()).context("Failed to open redis client")?))
+    }
+
+    /// Returns a builder for constructing a new `Redis` instance.
+    pub fn builder() -> RedisBuilder {
+        RedisBuilder::default()
+    }
+
+    /// Lazily connects to the Redis server and returns the connection manager.
+    ///
+    /// # Returns
+    ///
+    /// An `Option` containing the `ConnectionManager` if the connection is successful, or `None` if it fails.
+    ///
+    /// # Errors
+    ///
+    /// Logs an error and returns `None` if the connection manager cannot be obtained.
+    async fn lazy_connect(&self) -> Option<redis::aio::ConnectionManager> {
+        if self.connection_manager.read().await.is_none() {
+            let result = self.client.get_connection_manager().await;
+            if let Err(e) = result {
+                tracing::error!("Failed to get connection manager: {}", e);
+                return None;
+            }
+            let mut cm = self.connection_manager.write().await;
+            *cm = result.ok();
+        }
+
+        self.connection_manager.read().await.clone()
+    }
+
+    /// Generates a Redis key for a given node using the key prefix and the node's hash.
+    ///
+    /// # Parameters
+    ///
+    /// * `node` - The node for which the key is to be generated.
+    ///
+    /// # Returns
+    ///
+    /// A `String` representing the Redis key for the node.
+    fn cache_key_for_node(&self, node: &IngestionNode) -> String {
+        format!("{}:{}", self.cache_key_prefix, node.calculate_hash())
+    }
+
+    /// Generates a key for a given node to be persisted in Redis.
+    fn persist_key_for_node(&self, node: &IngestionNode) -> Result<String> {
+        if let Some(key_fn) = self.persist_key_fn {
+            key_fn(node)
+        } else {
+            let hash = node.calculate_hash();
+            Ok(format!("{}:{}", node.path.to_string_lossy(), hash))
+        }
+    }
+
+    /// Generates a value for a given node to be persisted in Redis.
+    /// If a custom function is provided, it is used to generate the value;
+    /// otherwise, the node is serialized as JSON.
+    fn persist_value_for_node(&self, node: &IngestionNode) -> Result<String> {
+        if let Some(value_fn) = self.persist_value_fn {
+            value_fn(node)
+        } else {
+            Ok(serde_json::to_string(node)?)
+        }
+    }
+
+    /// Resets the cache by deleting all keys with the specified prefix.
+    /// This function is intended for testing purposes and is inefficient for production use.
+    ///
+    /// # Errors
+    ///
+    /// Panics if the keys cannot be retrieved or deleted.
+    #[allow(dead_code)]
+    async fn reset_cache(&self) {
+        if let Some(mut cm) = self.lazy_connect().await {
+            let keys: Vec<String> = redis::cmd("KEYS")
+                .arg(format!("{}:*", self.cache_key_prefix))
+                .query_async(&mut cm)
+                .await
+                .expect("Could not get keys");
+
+            for key in &keys {
+                let _: usize = redis::cmd("DEL")
+                    .arg(key)
+                    .query_async(&mut cm)
+                    .await
+                    .expect("Failed to reset cache");
+            }
+        }
+    }
+
+    /// Gets a node persisted in Redis using the GET command.
+    /// Takes a node and returns a `Result<Option<String>>`.
+    #[allow(dead_code)]
+    async fn get_node(&self, node: &IngestionNode) -> Result<Option<String>> {
+        if let Some(mut cm) = self.lazy_connect().await {
+            let key = self.persist_key_for_node(node)?;
+            let result: Option<String> = redis::cmd("GET")
+                .arg(key)
+                .query_async(&mut cm)
+                .await
+                .context("Error getting from redis")?;
+            Ok(result)
+        } else {
+            anyhow::bail!("Failed to connect to Redis")
+        }
+    }
+}
+
+// Redis CM does not implement Debug
+impl std::fmt::Debug for Redis {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("Redis")
+            .field("client", &self.client)
+            .finish()
+    }
+}
 
-pub use node_cache::RedisNodeCache;
+impl Clone for Redis {
+    fn clone(&self) -> Self {
+        Self {
+            client: self.client.clone(),
+            connection_manager: RwLock::new(None),
+            cache_key_prefix: self.cache_key_prefix.clone(),
+            batch_size: self.batch_size,
+            persist_key_fn: self.persist_key_fn,
+            persist_value_fn: self.persist_value_fn,
+        }
+    }
+}
diff --git a/swiftide/src/integrations/redis/node_cache.rs b/swiftide/src/integrations/redis/node_cache.rs
index 40e7b821..85c8aab4 100644
--- a/swiftide/src/integrations/redis/node_cache.rs
+++ b/swiftide/src/integrations/redis/node_cache.rs
@@ -1,132 +1,12 @@
-use std::fmt::Debug;
-use tokio::sync::RwLock;
-
-use anyhow::{Context as _, Result};
+use anyhow::Result;
 use async_trait::async_trait;
 
 use crate::{ingestion::IngestionNode, traits::NodeCache};
 
-/// `RedisNodeCache` provides a caching mechanism for nodes using Redis.
-/// It helps in optimizing the ingestion process by skipping nodes that have already been processed.
-///
-/// # Fields
-///
-/// * `client` - The Redis client used to interact with the Redis server.
-/// * `connection_manager` - Manages the Redis connections asynchronously.
-/// * `key_prefix` - A prefix used for keys stored in Redis to avoid collisions.
-pub struct RedisNodeCache {
-    client: redis::Client,
-    connection_manager: RwLock<Option<redis::aio::ConnectionManager>>,
-    key_prefix: String,
-}
-
-impl RedisNodeCache {
-    /// Creates a new `RedisNodeCache` instance from a given Redis URL and key prefix.
-    ///
-    /// # Parameters
-    ///
-    /// * `url` - The URL of the Redis server.
-    /// * `prefix` - The prefix to be used for keys stored in Redis.
-    ///
-    /// # Returns
-    ///
-    /// A `Result` containing the `RedisNodeCache` instance or an error if the client could not be created.
-    ///
-    /// # Errors
-    ///
-    /// Returns an error if the Redis client cannot be opened.
-    pub fn try_from_url(url: impl AsRef<str>, prefix: impl AsRef<str>) -> Result<Self> {
-        let client = redis::Client::open(url.as_ref()).context("Failed to open redis client")?;
-        Ok(Self {
-            client,
-            connection_manager: RwLock::new(None),
-            key_prefix: prefix.as_ref().to_string(),
-        })
-    }
-
-    /// Lazily connects to the Redis server and returns the connection manager.
-    ///
-    /// # Returns
-    ///
-    /// An `Option` containing the `ConnectionManager` if the connection is successful, or `None` if it fails.
-    ///
-    /// # Errors
-    ///
-    /// Logs an error and returns `None` if the connection manager cannot be obtained.
-    async fn lazy_connect(&self) -> Option<redis::aio::ConnectionManager> {
-        if self.connection_manager.read().await.is_none() {
-            let result = self.client.get_connection_manager().await;
-            if let Err(e) = result {
-                tracing::error!("Failed to get connection manager: {}", e);
-                return None;
-            }
-            let mut cm = self.connection_manager.write().await;
-            *cm = result.ok();
-        }
-
-        self.connection_manager.read().await.clone()
-    }
-
-    /// Generates a Redis key for a given node using the key prefix and the node's hash.
-    ///
-    /// # Parameters
-    ///
-    /// * `node` - The node for which the key is to be generated.
-    ///
-    /// # Returns
-    ///
-    /// A `String` representing the Redis key for the node.
-    fn key_for_node(&self, node: &IngestionNode) -> String {
-        format!("{}:{}", self.key_prefix, node.calculate_hash())
-    }
-
-    /// Resets the cache by deleting all keys with the specified prefix.
-    /// This function is intended for testing purposes and is inefficient for production use.
-    ///
-    /// # Errors
-    ///
-    /// Panics if the keys cannot be retrieved or deleted.
-    #[allow(dead_code)]
-    async fn reset_cache(&self) {
-        if let Some(mut cm) = self.lazy_connect().await {
-            let keys: Vec<String> = redis::cmd("KEYS")
-                .arg(format!("{}:*", self.key_prefix))
-                .query_async(&mut cm)
-                .await
-                .expect("Could not get keys");
-
-            for key in &keys {
-                let _: usize = redis::cmd("DEL")
-                    .arg(key)
-                    .query_async(&mut cm)
-                    .await
-                    .expect("Failed to reset cache");
-            }
-        }
-    }
-}
-
-// Redis CM does not implement debug
-impl Debug for RedisNodeCache {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        f.debug_struct("Redis")
-            .field("client", &self.client)
-            .finish()
-    }
-}
-
-impl Clone for RedisNodeCache {
-    fn clone(&self) -> Self {
-        Self {
-            client: self.client.clone(),
-            connection_manager: RwLock::new(None),
-            key_prefix: self.key_prefix.clone(),
-        }
-    }
-}
+use super::Redis;
 
 #[async_trait]
-impl NodeCache for RedisNodeCache {
+impl NodeCache for Redis {
     /// Checks if a node is present in the cache.
     ///
     /// # Parameters
     ///
     /// * `node` - The node to be checked in the cache.
     ///
     /// # Returns
     ///
     /// `true` if the node is present in the cache, `false` otherwise.
     ///
     /// # Errors
     ///
     /// Logs an error and returns `false` if the cache check fails.
@@ -144,7 +24,7 @@ impl NodeCache for RedisNodeCache {
     async fn get(&self, node: &IngestionNode) -> bool {
         let cache_result = if let Some(mut cm) = self.lazy_connect().await {
            let result = redis::cmd("EXISTS")
-                .arg(self.key_for_node(node))
+                .arg(self.cache_key_for_node(node))
                .query_async(&mut cm)
                .await;
 
@@ -182,7 +62,7 @@
     async fn set(&self, node: &IngestionNode) {
         if let Some(mut cm) = self.lazy_connect().await {
             let result: Result<(), redis::RedisError> = redis::cmd("SET")
-                .arg(self.key_for_node(node))
+                .arg(self.cache_key_for_node(node))
                 .arg(1)
                 .query_async(&mut cm)
                 .await;
 
@@ -214,7 +94,7 @@ mod tests {
         let host = redis.get_host().await.unwrap();
         let port = redis.get_host_port_ipv4(6379).await.unwrap();
 
-        let cache = RedisNodeCache::try_from_url(format!("redis://{host}:{port}"), "test")
+        let cache = Redis::try_from_url(format!("redis://{host}:{port}"), "test")
             .expect("Could not build redis client");
         cache.reset_cache().await;
diff --git a/swiftide/src/integrations/redis/persist.rs b/swiftide/src/integrations/redis/persist.rs
new file mode 100644
index 00000000..ef2bf9de
--- /dev/null
+++ b/swiftide/src/integrations/redis/persist.rs
@@ -0,0 +1,186 @@
+use anyhow::{Context as _, Result};
+use async_trait::async_trait;
+use futures_util::{stream, StreamExt};
+
+use crate::{
+    ingestion::{IngestionNode, IngestionStream},
+    Persist,
+};
+
+use super::Redis;
+
+#[async_trait]
+impl Persist for Redis {
+    async fn setup(&self) -> Result<()> {
+        Ok(())
+    }
+
+    fn batch_size(&self) -> Option<usize> {
+        Some(self.batch_size)
+    }
+
+    /// Stores a node in Redis using the SET command.
+    ///
+    /// By default, nodes are stored with the path and hash as the key and the node serialized as JSON as the value.
+    ///
+    /// You can customize the key and value used for storing nodes by setting the `persist_key_fn` and `persist_value_fn` fields.
+    async fn store(&self, node: IngestionNode) -> Result<IngestionNode> {
+        if let Some(mut cm) = self.lazy_connect().await {
+            redis::cmd("SET")
+                .arg(self.persist_key_for_node(&node)?)
+                .arg(self.persist_value_for_node(&node)?)
+                .query_async(&mut cm)
+                .await
+                .context("Error persisting to redis")?;
+
+            Ok(node)
+        } else {
+            anyhow::bail!("Failed to connect to Redis")
+        }
+    }
+
+    /// Stores a batch of nodes in Redis using the MSET command.
+    ///
+    /// By default, nodes are stored with the path and hash as the key and the node serialized as JSON as the value.
+    ///
+    /// You can customize the key and value used for storing nodes by setting the `persist_key_fn` and `persist_value_fn` fields.
+    async fn batch_store(&self, nodes: Vec<IngestionNode>) -> IngestionStream {
+        // Use MSET to store the whole batch in a single round trip
+        if let Some(mut cm) = self.lazy_connect().await {
+            let args = nodes
+                .iter()
+                .map(|node| -> Result<Vec<String>> {
+                    let key = self.persist_key_for_node(node)?;
+                    let value = self.persist_value_for_node(node)?;
+
+                    Ok(vec![key, value])
+                })
+                .collect::<Result<Vec<_>>>();
+
+            if args.is_err() {
+                return stream::iter(vec![Err(args.unwrap_err())]).boxed();
+            }
+
+            let args = args.unwrap();
+
+            let result: Result<()> = redis::cmd("MSET")
+                .arg(args)
+                .query_async(&mut cm)
+                .await
+                .context("Error persisting to redis");
+
+            if result.is_ok() {
+                stream::iter(nodes.into_iter().map(Ok)).boxed()
+            } else {
+                stream::iter(vec![Err(result.unwrap_err())]).boxed()
+            }
+        } else {
+            stream::iter(vec![Err(anyhow::anyhow!("Failed to connect to Redis"))]).boxed()
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use futures_util::TryStreamExt;
+    use std::collections::HashMap;
+    use testcontainers::{runners::AsyncRunner, ContainerAsync, GenericImage};
+
+    async fn start_redis() -> ContainerAsync<GenericImage> {
+        testcontainers::GenericImage::new("redis", "7.2.4")
+            .with_exposed_port(6379.into())
+            .with_wait_for(testcontainers::core::WaitFor::message_on_stdout(
+                "Ready to accept connections",
+            ))
+            .start()
+            .await
+            .expect("Redis started")
+    }
+
+    #[test_log::test(tokio::test)]
+    async fn test_redis_persist() {
+        let redis_container = start_redis().await;
+
+        let host = redis_container.get_host().await.unwrap();
+        let port = redis_container.get_host_port_ipv4(6379).await.unwrap();
+        let redis = Redis::try_build_from_url(format!("redis://{host}:{port}"))
+            .unwrap()
+            .build()
+            .unwrap();
+
+        let node = IngestionNode {
+            id: Some(1),
+            path: "test".into(),
+            chunk: "chunk".into(),
+            vector: None,
+            metadata: HashMap::new(),
+        };
+
+        redis.store(node.clone()).await.unwrap();
+        let stored_node = serde_json::from_str(&redis.get_node(&node).await.unwrap().unwrap());
+
+        assert_eq!(node, stored_node.unwrap());
+    }
+
+    // Test batch store
+    #[test_log::test(tokio::test)]
+    async fn test_redis_batch_persist() {
+        let redis_container = start_redis().await;
+        let host = redis_container.get_host().await.unwrap();
+        let port = redis_container.get_host_port_ipv4(6379).await.unwrap();
+        let redis = Redis::try_build_from_url(format!("redis://{host}:{port}"))
+            .unwrap()
+            .batch_size(20)
+            .build()
+            .unwrap();
+        let nodes = vec![
+            IngestionNode {
+                id: Some(1),
+                path: "test".into(),
+                ..Default::default()
+            },
+            IngestionNode {
+                id: Some(2),
+                path: "other".into(),
+                ..Default::default()
+            },
+        ];
+
+        let stream = redis.batch_store(nodes).await;
+        let streamed_nodes: Vec<IngestionNode> = stream.try_collect().await.unwrap();
+
+        assert_eq!(streamed_nodes.len(), 2);
+
+        for node in streamed_nodes {
+            let stored_node = serde_json::from_str(&redis.get_node(&node).await.unwrap().unwrap());
+            assert_eq!(node, stored_node.unwrap())
+        }
+    }
+
+    #[test_log::test(tokio::test)]
+    async fn test_redis_custom_persist() {
+        let redis_container = start_redis().await;
+        let host = redis_container.get_host().await.unwrap();
+        let port = redis_container.get_host_port_ipv4(6379).await.unwrap();
+        let redis = Redis::try_build_from_url(format!("redis://{host}:{port}"))
+            .unwrap()
+            .persist_key_fn(|_node| Ok("test".to_string()))
+            .persist_value_fn(|_node| Ok("hello world".to_string()))
+            .build()
+            .unwrap();
+        let node = IngestionNode {
+            id: Some(1),
+            ..Default::default()
+        };
+
+        redis.store(node.clone()).await.unwrap();
+        let stored_node = redis.get_node(&node).await.unwrap();
+
+        assert_eq!(stored_node.unwrap(), "hello world");
+        assert_eq!(
+            redis.persist_key_for_node(&node).unwrap(),
+            "test".to_string()
+        )
+    }
+}
diff --git a/swiftide/tests/ingestion_pipeline.rs b/swiftide/tests/ingestion_pipeline.rs
index f81959c7..754daa68 100644
--- a/swiftide/tests/ingestion_pipeline.rs
+++ b/swiftide/tests/ingestion_pipeline.rs
@@ -135,9 +135,7 @@ async fn test_ingestion_pipeline() {
     IngestionPipeline::from_loader(FileLoader::new(tempdir.path()).with_extensions(&["rs"]))
         .then_chunk(transformers::ChunkCode::try_for_language("rust").unwrap())
         .then(transformers::MetadataQACode::new(openai_client.clone()))
-        .filter_cached(
-            integrations::redis::RedisNodeCache::try_from_url(&redis_url, "prefix").unwrap(),
-        )
+        .filter_cached(integrations::redis::Redis::try_from_url(&redis_url, "prefix").unwrap())
         .then_in_batch(1, transformers::Embed::new(openai_client.clone()))
         .then_store_with(
             integrations::qdrant::Qdrant::try_from_url(&qdrant_url)