diff --git a/examples/aws_bedrock.rs b/examples/aws_bedrock.rs index d5881bda..be95e711 100644 --- a/examples/aws_bedrock.rs +++ b/examples/aws_bedrock.rs @@ -44,8 +44,7 @@ async fn main() -> Result<(), Box> { .get_all_values() .await .iter() - .filter_map(|n| n.metadata.get("Summary")) - .cloned() + .filter_map(|n| n.metadata.get("Summary").map(|v| v.to_string())) .collect::>() .join("\n---\n") ); diff --git a/examples/index_groq.rs b/examples/index_groq.rs index 990bc78c..95c8adad 100644 --- a/examples/index_groq.rs +++ b/examples/index_groq.rs @@ -47,7 +47,7 @@ async fn main() -> Result<(), Box> { .get_all_values() .await .into_iter() - .flat_map(|n| n.metadata.into_values()) + .flat_map(|n| n.metadata.into_values().map(|v| v.to_string())) .collect::>() .join("\n") ); diff --git a/swiftide-core/src/lib.rs b/swiftide-core/src/lib.rs index 0005d56a..6daa4c9f 100644 --- a/swiftide-core/src/lib.rs +++ b/swiftide-core/src/lib.rs @@ -10,6 +10,8 @@ pub mod type_aliases; pub mod prompt; pub use type_aliases::*; +mod metadata; + /// All traits are available from the root pub use crate::indexing_traits::*; pub use crate::query_traits::*; @@ -17,6 +19,7 @@ pub use crate::query_traits::*; pub mod indexing { pub use crate::indexing_stream::IndexingStream; pub use crate::indexing_traits::*; + pub use crate::metadata::*; pub use crate::node::*; } diff --git a/swiftide-core/src/metadata.rs b/swiftide-core/src/metadata.rs new file mode 100644 index 00000000..838b7768 --- /dev/null +++ b/swiftide-core/src/metadata.rs @@ -0,0 +1,176 @@ +//! Metadata is a key-value store for indexation nodes +//! +//! Typically metadata is used to extract or generate additional information about the node +//! +//! Internally it uses a `BTreeMap` to store the key-value pairs, to ensure the data is sorted. +use std::collections::{btree_map::IntoValues, BTreeMap}; + +use serde::Deserializer; + +#[derive(Debug, Clone, Default, PartialEq)] +pub struct Metadata { + inner: BTreeMap, +} + +impl Metadata { + pub fn iter(&self) -> impl Iterator { + self.inner.iter() + } + + pub fn insert(&mut self, key: K, value: V) + where + K: Into, + V: Into, + { + self.inner.insert(key.into(), value.into()); + } + + pub fn get(&self, key: impl AsRef) -> Option<&serde_json::Value> { + self.inner.get(key.as_ref()) + } + + pub fn into_values(self) -> IntoValues { + self.inner.into_values() + } +} + +impl Extend<(K, V)> for Metadata +where + K: Into, + V: Into, +{ + fn extend>(&mut self, iter: T) { + self.inner + .extend(iter.into_iter().map(|(k, v)| (k.into(), v.into()))); + } +} + +impl From> for Metadata +where + K: Into, + V: Into, +{ + fn from(items: Vec<(K, V)>) -> Self { + let inner = items + .into_iter() + .map(|(k, v)| (k.into(), v.into())) + .collect(); + Metadata { inner } + } +} + +impl<'a, K, V> From<&'a [(K, V)]> for Metadata +where + K: Into + Clone, + V: Into + Clone, +{ + fn from(items: &'a [(K, V)]) -> Self { + let inner = items + .iter() + .cloned() + .map(|(k, v)| (k.into(), v.into())) + .collect(); + Metadata { inner } + } +} + +impl From<[(K, V); N]> for Metadata +where + K: Ord + Into, + V: Into, +{ + fn from(mut arr: [(K, V); N]) -> Self { + if N == 0 { + return Metadata { + inner: BTreeMap::new(), + }; + } + arr.sort_by(|a, b| a.0.cmp(&b.0)); + let inner: BTreeMap = + arr.into_iter().map(|(k, v)| (k.into(), v.into())).collect(); + Metadata { inner } + } +} + +impl IntoIterator for Metadata { + type Item = (String, serde_json::Value); + type IntoIter = std::collections::btree_map::IntoIter; + fn into_iter(self) -> Self::IntoIter { + self.inner.into_iter() + } +} + +impl<'iter> IntoIterator for &'iter Metadata { + type Item = (&'iter String, &'iter serde_json::Value); + type IntoIter = std::collections::btree_map::Iter<'iter, String, serde_json::Value>; + fn into_iter(self) -> Self::IntoIter { + self.inner.iter() + } +} + +impl<'de> serde::Deserialize<'de> for Metadata { + fn deserialize>(deserializer: D) -> Result { + BTreeMap::deserialize(deserializer).map(|inner| Metadata { inner }) + } +} + +impl serde::Serialize for Metadata { + fn serialize(&self, serializer: S) -> Result { + self.inner.serialize(serializer) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn test_insert_and_get() { + let mut metadata = Metadata::default(); + let key = "key"; + let value = "value"; + metadata.insert(key, "value"); + + assert_eq!(metadata.get(key).unwrap().as_str(), Some(value)); + } + + #[test] + fn test_iter() { + let mut metadata = Metadata::default(); + metadata.insert("key1", json!("value1")); + metadata.insert("key2", json!("value2")); + + let mut iter = metadata.iter(); + assert_eq!(iter.next(), Some((&"key1".to_string(), &json!("value1")))); + assert_eq!(iter.next(), Some((&"key2".to_string(), &json!("value2")))); + assert_eq!(iter.next(), None); + } + + #[test] + fn test_extend() { + let mut metadata = Metadata::default(); + metadata.extend(vec![("key1", json!("value1")), ("key2", json!("value2"))]); + + assert_eq!(metadata.get("key1"), Some(&json!("value1"))); + assert_eq!(metadata.get("key2"), Some(&json!("value2"))); + } + + #[test] + fn test_from_vec() { + let metadata = Metadata::from(vec![("key1", json!("value1")), ("key2", json!("value2"))]); + + assert_eq!(metadata.get("key1"), Some(&json!("value1"))); + assert_eq!(metadata.get("key2"), Some(&json!("value2"))); + } + + #[test] + fn test_into_values() { + let mut metadata = Metadata::default(); + metadata.insert("key1", json!("value1")); + metadata.insert("key2", json!("value2")); + + let values: Vec<_> = metadata.into_values().collect(); + assert_eq!(values, vec![json!("value1"), json!("value2")]); + } +} diff --git a/swiftide-core/src/node.rs b/swiftide-core/src/node.rs index 9c69efed..e567e3bc 100644 --- a/swiftide-core/src/node.rs +++ b/swiftide-core/src/node.rs @@ -18,7 +18,7 @@ //! individual units of data. It is particularly useful in scenarios where metadata and data chunks //! need to be processed together. use std::{ - collections::{BTreeMap, HashMap}, + collections::HashMap, fmt::Debug, hash::{Hash, Hasher}, path::PathBuf, @@ -27,6 +27,8 @@ use std::{ use itertools::Itertools; use serde::{Deserialize, Serialize}; +use crate::metadata::Metadata; + /// Represents a unit of data in the indexing process. /// /// `Node` encapsulates all necessary information for a single unit of data being processed @@ -43,7 +45,7 @@ pub struct Node { /// Optional vector representation of embedded data. pub vectors: Option>>, /// Metadata associated with the node. - pub metadata: BTreeMap, + pub metadata: Metadata, /// Mode of embedding data Chunk and Metadata pub embed_mode: EmbedMode, } @@ -99,7 +101,10 @@ impl Node { if self.embed_mode == EmbedMode::PerField || self.embed_mode == EmbedMode::Both { embeddables.push((EmbeddedField::Chunk, self.chunk.clone())); for (name, value) in &self.metadata { - embeddables.push((EmbeddedField::Metadata(name.clone()), value.clone())); + let value = value + .as_str() + .map_or_else(|| value.to_string(), ToString::to_string); + embeddables.push((EmbeddedField::Metadata(name.clone()), value)); } } @@ -119,7 +124,13 @@ impl Node { let metadata = self .metadata .iter() - .map(|(k, v)| format!("{k}: {v}")) + .map(|(k, v)| { + let v = v + .as_str() + .map_or_else(|| v.to_string(), ToString::to_string); + + format!("{k}: {v}") + }) .collect::>() .join("\n"); diff --git a/swiftide-indexing/src/transformers/embed.rs b/swiftide-indexing/src/transformers/embed.rs index 294ddf00..a191ba46 100644 --- a/swiftide-indexing/src/transformers/embed.rs +++ b/swiftide-indexing/src/transformers/embed.rs @@ -115,13 +115,11 @@ impl BatchableTransformer for Embed { #[cfg(test)] mod tests { - use swiftide_core::indexing::{EmbedMode, EmbeddedField, Node}; + use swiftide_core::indexing::{EmbedMode, EmbeddedField, Metadata, Node}; use swiftide_core::{BatchableTransformer, MockEmbeddingModel}; use super::Embed; - use std::collections::HashMap; - use futures_util::StreamExt; use mockall::predicate::*; use test_case::test_case; @@ -130,7 +128,7 @@ mod tests { struct TestData<'a> { pub embed_mode: EmbedMode, pub chunk: &'a str, - pub metadata: HashMap<&'a str, &'a str>, + pub metadata: Metadata, pub expected_embedables: Vec<&'a str>, pub expected_vectors: Vec<(EmbeddedField, Vec)>, } @@ -139,14 +137,14 @@ mod tests { TestData { embed_mode: EmbedMode::SingleWithMetadata, chunk: "chunk_1", - metadata: HashMap::from([("meta_1", "prompt_1")]), + metadata: Metadata::from([("meta_1", "prompt_1")]), expected_embedables: vec!["meta_1: prompt_1\nchunk_1"], expected_vectors: vec![(EmbeddedField::Combined, vec![1f32])] }, TestData { embed_mode: EmbedMode::SingleWithMetadata, chunk: "chunk_2", - metadata: HashMap::from([("meta_2", "prompt_2")]), + metadata: Metadata::from([("meta_2", "prompt_2")]), expected_embedables: vec!["meta_2: prompt_2\nchunk_2"], expected_vectors: vec![(EmbeddedField::Combined, vec![2f32])] } @@ -155,7 +153,7 @@ mod tests { TestData { embed_mode: EmbedMode::PerField, chunk: "chunk_1", - metadata: HashMap::from([("meta_1", "prompt 1")]), + metadata: Metadata::from([("meta_1", "prompt 1")]), expected_embedables: vec!["chunk_1", "prompt 1"], expected_vectors: vec![ (EmbeddedField::Chunk, vec![10f32]), @@ -165,7 +163,7 @@ mod tests { TestData { embed_mode: EmbedMode::PerField, chunk: "chunk_2", - metadata: HashMap::from([("meta_2", "prompt 2")]), + metadata: Metadata::from([("meta_2", "prompt 2")]), expected_embedables: vec!["chunk_2", "prompt 2"], expected_vectors: vec![ (EmbeddedField::Chunk, vec![20f32]), @@ -177,7 +175,7 @@ mod tests { TestData { embed_mode: EmbedMode::Both, chunk: "chunk_1", - metadata: HashMap::from([("meta_1", "prompt 1")]), + metadata: Metadata::from([("meta_1", "prompt 1")]), expected_embedables: vec!["meta_1: prompt 1\nchunk_1", "chunk_1", "prompt 1"], expected_vectors: vec![ (EmbeddedField::Combined, vec![10f32]), @@ -188,7 +186,7 @@ mod tests { TestData { embed_mode: EmbedMode::Both, chunk: "chunk_2", - metadata: HashMap::from([("meta_2", "prompt 2")]), + metadata: Metadata::from([("meta_2", "prompt 2")]), expected_embedables: vec!["meta_2: prompt 2\nchunk_2", "chunk_2", "prompt 2"], expected_vectors: vec![ (EmbeddedField::Combined, vec![20f32]), @@ -201,7 +199,7 @@ mod tests { TestData { embed_mode: EmbedMode::Both, chunk: "chunk_1", - metadata: HashMap::from([("meta_10", "prompt 10"), ("meta_11", "prompt 11"), ("meta_12", "prompt 12")]), + metadata: Metadata::from([("meta_10", "prompt 10"), ("meta_11", "prompt 11"), ("meta_12", "prompt 12")]), expected_embedables: vec!["meta_10: prompt 10\nmeta_11: prompt 11\nmeta_12: prompt 12\nchunk_1", "chunk_1", "prompt 10", "prompt 11", "prompt 12"], expected_vectors: vec![ (EmbeddedField::Combined, vec![10f32]), @@ -214,7 +212,7 @@ mod tests { TestData { embed_mode: EmbedMode::Both, chunk: "chunk_2", - metadata: HashMap::from([("meta_20", "prompt 20"), ("meta_21", "prompt 21"), ("meta_22", "prompt 22")]), + metadata: Metadata::from([("meta_20", "prompt 20"), ("meta_21", "prompt 21"), ("meta_22", "prompt 22")]), expected_embedables: vec!["meta_20: prompt 20\nmeta_21: prompt 21\nmeta_22: prompt 22\nchunk_2", "chunk_2", "prompt 20", "prompt 21", "prompt 22"], expected_vectors: vec![ (EmbeddedField::Combined, vec![20f32]), @@ -232,11 +230,7 @@ mod tests { .iter() .map(|data| Node { chunk: data.chunk.into(), - metadata: data - .metadata - .iter() - .map(|(k, v)| ((*k).to_string(), (*v).to_string())) - .collect(), + metadata: data.metadata.clone(), embed_mode: data.embed_mode, ..Default::default() }) diff --git a/swiftide-indexing/src/transformers/metadata_keywords.rs b/swiftide-indexing/src/transformers/metadata_keywords.rs index 87f8d957..966b7ec2 100644 --- a/swiftide-indexing/src/transformers/metadata_keywords.rs +++ b/swiftide-indexing/src/transformers/metadata_keywords.rs @@ -96,7 +96,7 @@ impl Transformer for MetadataKeywords { let prompt = self.prompt_template.to_prompt().with_node(&node); let response = self.client.prompt(prompt).await?; - node.metadata.insert(NAME.into(), response); + node.metadata.insert(NAME, response); Ok(node) } diff --git a/swiftide-indexing/src/transformers/metadata_qa_code.rs b/swiftide-indexing/src/transformers/metadata_qa_code.rs index 561d1fb1..7d01dbdd 100644 --- a/swiftide-indexing/src/transformers/metadata_qa_code.rs +++ b/swiftide-indexing/src/transformers/metadata_qa_code.rs @@ -101,7 +101,7 @@ impl Transformer for MetadataQACode { let response = self.client.prompt(prompt).await?; - node.metadata.insert(NAME.into(), response); + node.metadata.insert(NAME, response); Ok(node) } diff --git a/swiftide-indexing/src/transformers/metadata_qa_text.rs b/swiftide-indexing/src/transformers/metadata_qa_text.rs index 0d073c44..4b7538eb 100644 --- a/swiftide-indexing/src/transformers/metadata_qa_text.rs +++ b/swiftide-indexing/src/transformers/metadata_qa_text.rs @@ -104,7 +104,7 @@ impl Transformer for MetadataQAText { let response = self.client.prompt(prompt).await?; - node.metadata.insert(NAME.into(), response); + node.metadata.insert(NAME, response); Ok(node) } diff --git a/swiftide-indexing/src/transformers/metadata_refs_defs_code.rs b/swiftide-indexing/src/transformers/metadata_refs_defs_code.rs index 76220979..7d6aadc8 100644 --- a/swiftide-indexing/src/transformers/metadata_refs_defs_code.rs +++ b/swiftide-indexing/src/transformers/metadata_refs_defs_code.rs @@ -24,12 +24,12 @@ //! node = transformer.transform_node(node).await.unwrap(); //! //! assert_eq!( -//! node.metadata.get(NAME_REFERENCES), -//! Some(&"println".to_string()) +//! node.metadata.get(NAME_REFERENCES).unwrap().as_str().unwrap(), +//! "println" //! ); //! assert_eq!( -//! node.metadata.get(NAME_DEFINITIONS), -//! Some(&"main".to_string()) +//! node.metadata.get(NAME_DEFINITIONS).unwrap().as_str().unwrap(), +//! "main" //! ); //! # Ok(()) //! # } @@ -129,12 +129,12 @@ mod test { node = transformer.transform_node(node).await.unwrap(); assert_eq!( - node.metadata.get(NAME_REFERENCES), - Some(&"println".to_string()) + node.metadata.get(NAME_REFERENCES).unwrap().as_str(), + "println".into() ); assert_eq!( - node.metadata.get(NAME_DEFINITIONS), - Some(&"main".to_string()) + node.metadata.get(NAME_DEFINITIONS).unwrap().as_str(), + "main".into() ); } } diff --git a/swiftide-indexing/src/transformers/metadata_summary.rs b/swiftide-indexing/src/transformers/metadata_summary.rs index f66ac5c2..6c820fa3 100644 --- a/swiftide-indexing/src/transformers/metadata_summary.rs +++ b/swiftide-indexing/src/transformers/metadata_summary.rs @@ -97,7 +97,7 @@ impl Transformer for MetadataSummary { let response = self.client.prompt(prompt).await?; - node.metadata.insert(NAME.into(), response); + node.metadata.insert(NAME, response); Ok(node) } diff --git a/swiftide-indexing/src/transformers/metadata_title.rs b/swiftide-indexing/src/transformers/metadata_title.rs index 81186f3f..ba98efa1 100644 --- a/swiftide-indexing/src/transformers/metadata_title.rs +++ b/swiftide-indexing/src/transformers/metadata_title.rs @@ -97,7 +97,7 @@ impl Transformer for MetadataTitle { let response = self.client.prompt(prompt).await?; - node.metadata.insert(NAME.into(), response); + node.metadata.insert(NAME, response); Ok(node) } diff --git a/swiftide-integrations/src/qdrant/indexing_node.rs b/swiftide-integrations/src/qdrant/indexing_node.rs index 1faf457a..4205f446 100644 --- a/swiftide-integrations/src/qdrant/indexing_node.rs +++ b/swiftide-integrations/src/qdrant/indexing_node.rs @@ -4,7 +4,10 @@ //! data compatibility with Qdrant's required format. use anyhow::{bail, Result}; -use std::collections::{HashMap, HashSet}; +use std::{ + collections::{HashMap, HashSet}, + string::ToString, +}; use qdrant_client::{ client::Payload, @@ -36,19 +39,24 @@ impl TryInto for NodeWithVectors { // Extend the metadata with additional information. node.metadata.extend([ - ("path".to_string(), node.path.to_string_lossy().to_string()), - ("content".to_string(), node.chunk), - ( - "last_updated_at".to_string(), - chrono::Utc::now().to_rfc3339(), - ), + ("path", node.path.to_string_lossy().to_string()), + ("content", node.chunk), + ("last_updated_at", chrono::Utc::now().to_rfc3339()), ]); // Create a payload compatible with Qdrant's API. let payload: Payload = node .metadata .iter() - .map(|(k, v)| (k.as_str(), Value::from(v.as_str()))) + .map(|(k, v)| { + ( + k.as_str(), + Value::from( + v.as_str() + .map_or_else(|| v.to_string(), ToString::to_string), + ), + ) + }) .collect::>() .into(); @@ -84,21 +92,21 @@ fn try_create_vectors( #[cfg(test)] mod tests { - use std::collections::{BTreeMap, HashMap, HashSet}; + use std::collections::{HashMap, HashSet}; use qdrant_client::qdrant::{ vectors::VectorsOptions, NamedVectors, PointId, PointStruct, Value, Vector, Vectors, }; + use swiftide_core::indexing::{EmbeddedField, Metadata, Node}; use test_case::test_case; use crate::qdrant::indexing_node::NodeWithVectors; - use swiftide_core::indexing::{EmbedMode, EmbeddedField, Node}; #[test_case( Node { id: Some(1), path: "/path".into(), chunk: "data".into(), vectors: Some(HashMap::from([(EmbeddedField::Chunk, vec![1.0])])), - metadata: BTreeMap::from([("m1".into(), "mv1".into())]), - embed_mode: EmbedMode::SingleWithMetadata + metadata: Metadata::from([("m1", "mv1")]), + embed_mode: swiftide_core::indexing::EmbedMode::SingleWithMetadata }, HashSet::from([EmbeddedField::Combined]), PointStruct { id: Some(PointId::from(6_516_159_902_038_153_111)), payload: HashMap::from([ @@ -115,8 +123,8 @@ mod tests { (EmbeddedField::Chunk, vec![1.0]), (EmbeddedField::Metadata("m1".into()), vec![2.0]) ])), - metadata: BTreeMap::from([("m1".into(), "mv1".into())]), - embed_mode: EmbedMode::PerField + metadata: Metadata::from([("m1", "mv1")]), + embed_mode: swiftide_core::indexing::EmbedMode::PerField }, HashSet::from([EmbeddedField::Chunk, EmbeddedField::Metadata("m1".into())]), PointStruct { id: Some(PointId::from(6_516_159_902_038_153_111)), payload: HashMap::from([ @@ -142,8 +150,8 @@ mod tests { (EmbeddedField::Metadata("m1".into()), vec![1.0]), (EmbeddedField::Metadata("m2".into()), vec![2.0]) ])), - metadata: BTreeMap::from([("m1".into(), "mv1".into()), ("m2".into(), "mv2".into())]), - embed_mode: EmbedMode::Both + metadata: Metadata::from([("m1", "mv1"), ("m2", "mv2")]), + embed_mode: swiftide_core::indexing::EmbedMode::Both }, HashSet::from([EmbeddedField::Combined]), PointStruct { id: Some(PointId::from(6_516_159_902_038_153_111)), payload: HashMap::from([ diff --git a/swiftide/tests/indexing_pipeline.rs b/swiftide/tests/indexing_pipeline.rs index 4f1d31de..0975a30b 100644 --- a/swiftide/tests/indexing_pipeline.rs +++ b/swiftide/tests/indexing_pipeline.rs @@ -54,6 +54,7 @@ async fn test_indexing_pipeline() { .then(transformers::MetadataQACode::new(openai_client.clone())) .filter_cached(integrations::redis::Redis::try_from_url(&redis_url, "prefix").unwrap()) .then_in_batch(1, transformers::Embed::new(openai_client.clone())) + .log_nodes() .then_store_with( integrations::qdrant::Qdrant::try_from_url(&qdrant_url) .unwrap() @@ -104,6 +105,7 @@ async fn test_indexing_pipeline() { let first = search_response.result.first().unwrap(); + dbg!(first); assert!(first .payload .get("path")