From ec1fb04573ab75fe140cbeff17bc3179e316ff0c Mon Sep 17 00:00:00 2001 From: Timon Vonk Date: Sun, 28 Jul 2024 15:55:04 +0200 Subject: [PATCH] feat(indexing): Metadata as first class citizen (#204) Adds our own implementation for metadata, internally still using a BTreeMap. The Value type is now a `serde_json::Value` enum. This allows us to store the metadata in the same format as the rest of the document, and also allows us to use values programmatically later. As is, all current meta data is still stored as Strings. Closes #162 --- examples/aws_bedrock.rs | 3 +- examples/index_groq.rs | 2 +- swiftide-core/src/lib.rs | 3 + swiftide-core/src/metadata.rs | 176 ++++++++++++++++++ swiftide-core/src/node.rs | 19 +- swiftide-indexing/src/transformers/embed.rs | 28 ++- .../src/transformers/metadata_keywords.rs | 2 +- .../src/transformers/metadata_qa_code.rs | 2 +- .../src/transformers/metadata_qa_text.rs | 2 +- .../transformers/metadata_refs_defs_code.rs | 16 +- .../src/transformers/metadata_summary.rs | 2 +- .../src/transformers/metadata_title.rs | 2 +- .../src/qdrant/indexing_node.rs | 40 ++-- swiftide/tests/indexing_pipeline.rs | 2 + 14 files changed, 246 insertions(+), 53 deletions(-) create mode 100644 swiftide-core/src/metadata.rs diff --git a/examples/aws_bedrock.rs b/examples/aws_bedrock.rs index d5881bda..be95e711 100644 --- a/examples/aws_bedrock.rs +++ b/examples/aws_bedrock.rs @@ -44,8 +44,7 @@ async fn main() -> Result<(), Box> { .get_all_values() .await .iter() - .filter_map(|n| n.metadata.get("Summary")) - .cloned() + .filter_map(|n| n.metadata.get("Summary").map(|v| v.to_string())) .collect::>() .join("\n---\n") ); diff --git a/examples/index_groq.rs b/examples/index_groq.rs index 990bc78c..95c8adad 100644 --- a/examples/index_groq.rs +++ b/examples/index_groq.rs @@ -47,7 +47,7 @@ async fn main() -> Result<(), Box> { .get_all_values() .await .into_iter() - .flat_map(|n| n.metadata.into_values()) + .flat_map(|n| n.metadata.into_values().map(|v| v.to_string())) .collect::>() .join("\n") ); diff --git a/swiftide-core/src/lib.rs b/swiftide-core/src/lib.rs index 0005d56a..6daa4c9f 100644 --- a/swiftide-core/src/lib.rs +++ b/swiftide-core/src/lib.rs @@ -10,6 +10,8 @@ pub mod type_aliases; pub mod prompt; pub use type_aliases::*; +mod metadata; + /// All traits are available from the root pub use crate::indexing_traits::*; pub use crate::query_traits::*; @@ -17,6 +19,7 @@ pub use crate::query_traits::*; pub mod indexing { pub use crate::indexing_stream::IndexingStream; pub use crate::indexing_traits::*; + pub use crate::metadata::*; pub use crate::node::*; } diff --git a/swiftide-core/src/metadata.rs b/swiftide-core/src/metadata.rs new file mode 100644 index 00000000..838b7768 --- /dev/null +++ b/swiftide-core/src/metadata.rs @@ -0,0 +1,176 @@ +//! Metadata is a key-value store for indexation nodes +//! +//! Typically metadata is used to extract or generate additional information about the node +//! +//! Internally it uses a `BTreeMap` to store the key-value pairs, to ensure the data is sorted. +use std::collections::{btree_map::IntoValues, BTreeMap}; + +use serde::Deserializer; + +#[derive(Debug, Clone, Default, PartialEq)] +pub struct Metadata { + inner: BTreeMap, +} + +impl Metadata { + pub fn iter(&self) -> impl Iterator { + self.inner.iter() + } + + pub fn insert(&mut self, key: K, value: V) + where + K: Into, + V: Into, + { + self.inner.insert(key.into(), value.into()); + } + + pub fn get(&self, key: impl AsRef) -> Option<&serde_json::Value> { + self.inner.get(key.as_ref()) + } + + pub fn into_values(self) -> IntoValues { + self.inner.into_values() + } +} + +impl Extend<(K, V)> for Metadata +where + K: Into, + V: Into, +{ + fn extend>(&mut self, iter: T) { + self.inner + .extend(iter.into_iter().map(|(k, v)| (k.into(), v.into()))); + } +} + +impl From> for Metadata +where + K: Into, + V: Into, +{ + fn from(items: Vec<(K, V)>) -> Self { + let inner = items + .into_iter() + .map(|(k, v)| (k.into(), v.into())) + .collect(); + Metadata { inner } + } +} + +impl<'a, K, V> From<&'a [(K, V)]> for Metadata +where + K: Into + Clone, + V: Into + Clone, +{ + fn from(items: &'a [(K, V)]) -> Self { + let inner = items + .iter() + .cloned() + .map(|(k, v)| (k.into(), v.into())) + .collect(); + Metadata { inner } + } +} + +impl From<[(K, V); N]> for Metadata +where + K: Ord + Into, + V: Into, +{ + fn from(mut arr: [(K, V); N]) -> Self { + if N == 0 { + return Metadata { + inner: BTreeMap::new(), + }; + } + arr.sort_by(|a, b| a.0.cmp(&b.0)); + let inner: BTreeMap = + arr.into_iter().map(|(k, v)| (k.into(), v.into())).collect(); + Metadata { inner } + } +} + +impl IntoIterator for Metadata { + type Item = (String, serde_json::Value); + type IntoIter = std::collections::btree_map::IntoIter; + fn into_iter(self) -> Self::IntoIter { + self.inner.into_iter() + } +} + +impl<'iter> IntoIterator for &'iter Metadata { + type Item = (&'iter String, &'iter serde_json::Value); + type IntoIter = std::collections::btree_map::Iter<'iter, String, serde_json::Value>; + fn into_iter(self) -> Self::IntoIter { + self.inner.iter() + } +} + +impl<'de> serde::Deserialize<'de> for Metadata { + fn deserialize>(deserializer: D) -> Result { + BTreeMap::deserialize(deserializer).map(|inner| Metadata { inner }) + } +} + +impl serde::Serialize for Metadata { + fn serialize(&self, serializer: S) -> Result { + self.inner.serialize(serializer) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn test_insert_and_get() { + let mut metadata = Metadata::default(); + let key = "key"; + let value = "value"; + metadata.insert(key, "value"); + + assert_eq!(metadata.get(key).unwrap().as_str(), Some(value)); + } + + #[test] + fn test_iter() { + let mut metadata = Metadata::default(); + metadata.insert("key1", json!("value1")); + metadata.insert("key2", json!("value2")); + + let mut iter = metadata.iter(); + assert_eq!(iter.next(), Some((&"key1".to_string(), &json!("value1")))); + assert_eq!(iter.next(), Some((&"key2".to_string(), &json!("value2")))); + assert_eq!(iter.next(), None); + } + + #[test] + fn test_extend() { + let mut metadata = Metadata::default(); + metadata.extend(vec![("key1", json!("value1")), ("key2", json!("value2"))]); + + assert_eq!(metadata.get("key1"), Some(&json!("value1"))); + assert_eq!(metadata.get("key2"), Some(&json!("value2"))); + } + + #[test] + fn test_from_vec() { + let metadata = Metadata::from(vec![("key1", json!("value1")), ("key2", json!("value2"))]); + + assert_eq!(metadata.get("key1"), Some(&json!("value1"))); + assert_eq!(metadata.get("key2"), Some(&json!("value2"))); + } + + #[test] + fn test_into_values() { + let mut metadata = Metadata::default(); + metadata.insert("key1", json!("value1")); + metadata.insert("key2", json!("value2")); + + let values: Vec<_> = metadata.into_values().collect(); + assert_eq!(values, vec![json!("value1"), json!("value2")]); + } +} diff --git a/swiftide-core/src/node.rs b/swiftide-core/src/node.rs index 9c69efed..e567e3bc 100644 --- a/swiftide-core/src/node.rs +++ b/swiftide-core/src/node.rs @@ -18,7 +18,7 @@ //! individual units of data. It is particularly useful in scenarios where metadata and data chunks //! need to be processed together. use std::{ - collections::{BTreeMap, HashMap}, + collections::HashMap, fmt::Debug, hash::{Hash, Hasher}, path::PathBuf, @@ -27,6 +27,8 @@ use std::{ use itertools::Itertools; use serde::{Deserialize, Serialize}; +use crate::metadata::Metadata; + /// Represents a unit of data in the indexing process. /// /// `Node` encapsulates all necessary information for a single unit of data being processed @@ -43,7 +45,7 @@ pub struct Node { /// Optional vector representation of embedded data. pub vectors: Option>>, /// Metadata associated with the node. - pub metadata: BTreeMap, + pub metadata: Metadata, /// Mode of embedding data Chunk and Metadata pub embed_mode: EmbedMode, } @@ -99,7 +101,10 @@ impl Node { if self.embed_mode == EmbedMode::PerField || self.embed_mode == EmbedMode::Both { embeddables.push((EmbeddedField::Chunk, self.chunk.clone())); for (name, value) in &self.metadata { - embeddables.push((EmbeddedField::Metadata(name.clone()), value.clone())); + let value = value + .as_str() + .map_or_else(|| value.to_string(), ToString::to_string); + embeddables.push((EmbeddedField::Metadata(name.clone()), value)); } } @@ -119,7 +124,13 @@ impl Node { let metadata = self .metadata .iter() - .map(|(k, v)| format!("{k}: {v}")) + .map(|(k, v)| { + let v = v + .as_str() + .map_or_else(|| v.to_string(), ToString::to_string); + + format!("{k}: {v}") + }) .collect::>() .join("\n"); diff --git a/swiftide-indexing/src/transformers/embed.rs b/swiftide-indexing/src/transformers/embed.rs index 294ddf00..a191ba46 100644 --- a/swiftide-indexing/src/transformers/embed.rs +++ b/swiftide-indexing/src/transformers/embed.rs @@ -115,13 +115,11 @@ impl BatchableTransformer for Embed { #[cfg(test)] mod tests { - use swiftide_core::indexing::{EmbedMode, EmbeddedField, Node}; + use swiftide_core::indexing::{EmbedMode, EmbeddedField, Metadata, Node}; use swiftide_core::{BatchableTransformer, MockEmbeddingModel}; use super::Embed; - use std::collections::HashMap; - use futures_util::StreamExt; use mockall::predicate::*; use test_case::test_case; @@ -130,7 +128,7 @@ mod tests { struct TestData<'a> { pub embed_mode: EmbedMode, pub chunk: &'a str, - pub metadata: HashMap<&'a str, &'a str>, + pub metadata: Metadata, pub expected_embedables: Vec<&'a str>, pub expected_vectors: Vec<(EmbeddedField, Vec)>, } @@ -139,14 +137,14 @@ mod tests { TestData { embed_mode: EmbedMode::SingleWithMetadata, chunk: "chunk_1", - metadata: HashMap::from([("meta_1", "prompt_1")]), + metadata: Metadata::from([("meta_1", "prompt_1")]), expected_embedables: vec!["meta_1: prompt_1\nchunk_1"], expected_vectors: vec![(EmbeddedField::Combined, vec![1f32])] }, TestData { embed_mode: EmbedMode::SingleWithMetadata, chunk: "chunk_2", - metadata: HashMap::from([("meta_2", "prompt_2")]), + metadata: Metadata::from([("meta_2", "prompt_2")]), expected_embedables: vec!["meta_2: prompt_2\nchunk_2"], expected_vectors: vec![(EmbeddedField::Combined, vec![2f32])] } @@ -155,7 +153,7 @@ mod tests { TestData { embed_mode: EmbedMode::PerField, chunk: "chunk_1", - metadata: HashMap::from([("meta_1", "prompt 1")]), + metadata: Metadata::from([("meta_1", "prompt 1")]), expected_embedables: vec!["chunk_1", "prompt 1"], expected_vectors: vec![ (EmbeddedField::Chunk, vec![10f32]), @@ -165,7 +163,7 @@ mod tests { TestData { embed_mode: EmbedMode::PerField, chunk: "chunk_2", - metadata: HashMap::from([("meta_2", "prompt 2")]), + metadata: Metadata::from([("meta_2", "prompt 2")]), expected_embedables: vec!["chunk_2", "prompt 2"], expected_vectors: vec![ (EmbeddedField::Chunk, vec![20f32]), @@ -177,7 +175,7 @@ mod tests { TestData { embed_mode: EmbedMode::Both, chunk: "chunk_1", - metadata: HashMap::from([("meta_1", "prompt 1")]), + metadata: Metadata::from([("meta_1", "prompt 1")]), expected_embedables: vec!["meta_1: prompt 1\nchunk_1", "chunk_1", "prompt 1"], expected_vectors: vec![ (EmbeddedField::Combined, vec![10f32]), @@ -188,7 +186,7 @@ mod tests { TestData { embed_mode: EmbedMode::Both, chunk: "chunk_2", - metadata: HashMap::from([("meta_2", "prompt 2")]), + metadata: Metadata::from([("meta_2", "prompt 2")]), expected_embedables: vec!["meta_2: prompt 2\nchunk_2", "chunk_2", "prompt 2"], expected_vectors: vec![ (EmbeddedField::Combined, vec![20f32]), @@ -201,7 +199,7 @@ mod tests { TestData { embed_mode: EmbedMode::Both, chunk: "chunk_1", - metadata: HashMap::from([("meta_10", "prompt 10"), ("meta_11", "prompt 11"), ("meta_12", "prompt 12")]), + metadata: Metadata::from([("meta_10", "prompt 10"), ("meta_11", "prompt 11"), ("meta_12", "prompt 12")]), expected_embedables: vec!["meta_10: prompt 10\nmeta_11: prompt 11\nmeta_12: prompt 12\nchunk_1", "chunk_1", "prompt 10", "prompt 11", "prompt 12"], expected_vectors: vec![ (EmbeddedField::Combined, vec![10f32]), @@ -214,7 +212,7 @@ mod tests { TestData { embed_mode: EmbedMode::Both, chunk: "chunk_2", - metadata: HashMap::from([("meta_20", "prompt 20"), ("meta_21", "prompt 21"), ("meta_22", "prompt 22")]), + metadata: Metadata::from([("meta_20", "prompt 20"), ("meta_21", "prompt 21"), ("meta_22", "prompt 22")]), expected_embedables: vec!["meta_20: prompt 20\nmeta_21: prompt 21\nmeta_22: prompt 22\nchunk_2", "chunk_2", "prompt 20", "prompt 21", "prompt 22"], expected_vectors: vec![ (EmbeddedField::Combined, vec![20f32]), @@ -232,11 +230,7 @@ mod tests { .iter() .map(|data| Node { chunk: data.chunk.into(), - metadata: data - .metadata - .iter() - .map(|(k, v)| ((*k).to_string(), (*v).to_string())) - .collect(), + metadata: data.metadata.clone(), embed_mode: data.embed_mode, ..Default::default() }) diff --git a/swiftide-indexing/src/transformers/metadata_keywords.rs b/swiftide-indexing/src/transformers/metadata_keywords.rs index 87f8d957..966b7ec2 100644 --- a/swiftide-indexing/src/transformers/metadata_keywords.rs +++ b/swiftide-indexing/src/transformers/metadata_keywords.rs @@ -96,7 +96,7 @@ impl Transformer for MetadataKeywords { let prompt = self.prompt_template.to_prompt().with_node(&node); let response = self.client.prompt(prompt).await?; - node.metadata.insert(NAME.into(), response); + node.metadata.insert(NAME, response); Ok(node) } diff --git a/swiftide-indexing/src/transformers/metadata_qa_code.rs b/swiftide-indexing/src/transformers/metadata_qa_code.rs index 561d1fb1..7d01dbdd 100644 --- a/swiftide-indexing/src/transformers/metadata_qa_code.rs +++ b/swiftide-indexing/src/transformers/metadata_qa_code.rs @@ -101,7 +101,7 @@ impl Transformer for MetadataQACode { let response = self.client.prompt(prompt).await?; - node.metadata.insert(NAME.into(), response); + node.metadata.insert(NAME, response); Ok(node) } diff --git a/swiftide-indexing/src/transformers/metadata_qa_text.rs b/swiftide-indexing/src/transformers/metadata_qa_text.rs index 0d073c44..4b7538eb 100644 --- a/swiftide-indexing/src/transformers/metadata_qa_text.rs +++ b/swiftide-indexing/src/transformers/metadata_qa_text.rs @@ -104,7 +104,7 @@ impl Transformer for MetadataQAText { let response = self.client.prompt(prompt).await?; - node.metadata.insert(NAME.into(), response); + node.metadata.insert(NAME, response); Ok(node) } diff --git a/swiftide-indexing/src/transformers/metadata_refs_defs_code.rs b/swiftide-indexing/src/transformers/metadata_refs_defs_code.rs index 76220979..7d6aadc8 100644 --- a/swiftide-indexing/src/transformers/metadata_refs_defs_code.rs +++ b/swiftide-indexing/src/transformers/metadata_refs_defs_code.rs @@ -24,12 +24,12 @@ //! node = transformer.transform_node(node).await.unwrap(); //! //! assert_eq!( -//! node.metadata.get(NAME_REFERENCES), -//! Some(&"println".to_string()) +//! node.metadata.get(NAME_REFERENCES).unwrap().as_str().unwrap(), +//! "println" //! ); //! assert_eq!( -//! node.metadata.get(NAME_DEFINITIONS), -//! Some(&"main".to_string()) +//! node.metadata.get(NAME_DEFINITIONS).unwrap().as_str().unwrap(), +//! "main" //! ); //! # Ok(()) //! # } @@ -129,12 +129,12 @@ mod test { node = transformer.transform_node(node).await.unwrap(); assert_eq!( - node.metadata.get(NAME_REFERENCES), - Some(&"println".to_string()) + node.metadata.get(NAME_REFERENCES).unwrap().as_str(), + "println".into() ); assert_eq!( - node.metadata.get(NAME_DEFINITIONS), - Some(&"main".to_string()) + node.metadata.get(NAME_DEFINITIONS).unwrap().as_str(), + "main".into() ); } } diff --git a/swiftide-indexing/src/transformers/metadata_summary.rs b/swiftide-indexing/src/transformers/metadata_summary.rs index f66ac5c2..6c820fa3 100644 --- a/swiftide-indexing/src/transformers/metadata_summary.rs +++ b/swiftide-indexing/src/transformers/metadata_summary.rs @@ -97,7 +97,7 @@ impl Transformer for MetadataSummary { let response = self.client.prompt(prompt).await?; - node.metadata.insert(NAME.into(), response); + node.metadata.insert(NAME, response); Ok(node) } diff --git a/swiftide-indexing/src/transformers/metadata_title.rs b/swiftide-indexing/src/transformers/metadata_title.rs index 81186f3f..ba98efa1 100644 --- a/swiftide-indexing/src/transformers/metadata_title.rs +++ b/swiftide-indexing/src/transformers/metadata_title.rs @@ -97,7 +97,7 @@ impl Transformer for MetadataTitle { let response = self.client.prompt(prompt).await?; - node.metadata.insert(NAME.into(), response); + node.metadata.insert(NAME, response); Ok(node) } diff --git a/swiftide-integrations/src/qdrant/indexing_node.rs b/swiftide-integrations/src/qdrant/indexing_node.rs index 1faf457a..4205f446 100644 --- a/swiftide-integrations/src/qdrant/indexing_node.rs +++ b/swiftide-integrations/src/qdrant/indexing_node.rs @@ -4,7 +4,10 @@ //! data compatibility with Qdrant's required format. use anyhow::{bail, Result}; -use std::collections::{HashMap, HashSet}; +use std::{ + collections::{HashMap, HashSet}, + string::ToString, +}; use qdrant_client::{ client::Payload, @@ -36,19 +39,24 @@ impl TryInto for NodeWithVectors { // Extend the metadata with additional information. node.metadata.extend([ - ("path".to_string(), node.path.to_string_lossy().to_string()), - ("content".to_string(), node.chunk), - ( - "last_updated_at".to_string(), - chrono::Utc::now().to_rfc3339(), - ), + ("path", node.path.to_string_lossy().to_string()), + ("content", node.chunk), + ("last_updated_at", chrono::Utc::now().to_rfc3339()), ]); // Create a payload compatible with Qdrant's API. let payload: Payload = node .metadata .iter() - .map(|(k, v)| (k.as_str(), Value::from(v.as_str()))) + .map(|(k, v)| { + ( + k.as_str(), + Value::from( + v.as_str() + .map_or_else(|| v.to_string(), ToString::to_string), + ), + ) + }) .collect::>() .into(); @@ -84,21 +92,21 @@ fn try_create_vectors( #[cfg(test)] mod tests { - use std::collections::{BTreeMap, HashMap, HashSet}; + use std::collections::{HashMap, HashSet}; use qdrant_client::qdrant::{ vectors::VectorsOptions, NamedVectors, PointId, PointStruct, Value, Vector, Vectors, }; + use swiftide_core::indexing::{EmbeddedField, Metadata, Node}; use test_case::test_case; use crate::qdrant::indexing_node::NodeWithVectors; - use swiftide_core::indexing::{EmbedMode, EmbeddedField, Node}; #[test_case( Node { id: Some(1), path: "/path".into(), chunk: "data".into(), vectors: Some(HashMap::from([(EmbeddedField::Chunk, vec![1.0])])), - metadata: BTreeMap::from([("m1".into(), "mv1".into())]), - embed_mode: EmbedMode::SingleWithMetadata + metadata: Metadata::from([("m1", "mv1")]), + embed_mode: swiftide_core::indexing::EmbedMode::SingleWithMetadata }, HashSet::from([EmbeddedField::Combined]), PointStruct { id: Some(PointId::from(6_516_159_902_038_153_111)), payload: HashMap::from([ @@ -115,8 +123,8 @@ mod tests { (EmbeddedField::Chunk, vec![1.0]), (EmbeddedField::Metadata("m1".into()), vec![2.0]) ])), - metadata: BTreeMap::from([("m1".into(), "mv1".into())]), - embed_mode: EmbedMode::PerField + metadata: Metadata::from([("m1", "mv1")]), + embed_mode: swiftide_core::indexing::EmbedMode::PerField }, HashSet::from([EmbeddedField::Chunk, EmbeddedField::Metadata("m1".into())]), PointStruct { id: Some(PointId::from(6_516_159_902_038_153_111)), payload: HashMap::from([ @@ -142,8 +150,8 @@ mod tests { (EmbeddedField::Metadata("m1".into()), vec![1.0]), (EmbeddedField::Metadata("m2".into()), vec![2.0]) ])), - metadata: BTreeMap::from([("m1".into(), "mv1".into()), ("m2".into(), "mv2".into())]), - embed_mode: EmbedMode::Both + metadata: Metadata::from([("m1", "mv1"), ("m2", "mv2")]), + embed_mode: swiftide_core::indexing::EmbedMode::Both }, HashSet::from([EmbeddedField::Combined]), PointStruct { id: Some(PointId::from(6_516_159_902_038_153_111)), payload: HashMap::from([ diff --git a/swiftide/tests/indexing_pipeline.rs b/swiftide/tests/indexing_pipeline.rs index 4f1d31de..0975a30b 100644 --- a/swiftide/tests/indexing_pipeline.rs +++ b/swiftide/tests/indexing_pipeline.rs @@ -54,6 +54,7 @@ async fn test_indexing_pipeline() { .then(transformers::MetadataQACode::new(openai_client.clone())) .filter_cached(integrations::redis::Redis::try_from_url(&redis_url, "prefix").unwrap()) .then_in_batch(1, transformers::Embed::new(openai_client.clone())) + .log_nodes() .then_store_with( integrations::qdrant::Qdrant::try_from_url(&qdrant_url) .unwrap() @@ -104,6 +105,7 @@ async fn test_indexing_pipeline() { let first = search_response.result.first().unwrap(); + dbg!(first); assert!(first .payload .get("path")