From 1ba38a490d3037e77d18fdafd9d405f6f730feff Mon Sep 17 00:00:00 2001
From: GY <856454+ouonline@users.noreply.github.com>
Date: Tue, 17 Dec 2024 19:22:01 +0800
Subject: [PATCH] bugfix: modify root node's global metadata (#392)

---
 lazyllm/tools/rag/chroma_store.py | 2 +-
 lazyllm/tools/rag/doc_node.py     | 4 ++++
 lazyllm/tools/rag/milvus_store.py | 2 +-
 requirements.full.txt             | 2 +-
 requirements.txt                  | 2 +-
 scripts/check_requirements.py     | 2 +-
 tests/basic_tests/test_store.py   | 8 ++++++++
 tests/requirements.txt            | 2 +-
 8 files changed, 18 insertions(+), 6 deletions(-)

diff --git a/lazyllm/tools/rag/chroma_store.py b/lazyllm/tools/rag/chroma_store.py
index 99772997..5574f6e5 100644
--- a/lazyllm/tools/rag/chroma_store.py
+++ b/lazyllm/tools/rag/chroma_store.py
@@ -170,7 +170,7 @@ def _make_chroma_metadata(self, node: DocNode) -> Dict[str, Any]:
             "metadata": obj2str(node._metadata),
         }
 
-        if not node.parent:
+        if node.is_root_node:
             metadata["global_metadata"] = obj2str(node.global_metadata)
 
         return metadata
diff --git a/lazyllm/tools/rag/doc_node.py b/lazyllm/tools/rag/doc_node.py
index 97d097af..f56a6450 100644
--- a/lazyllm/tools/rag/doc_node.py
+++ b/lazyllm/tools/rag/doc_node.py
@@ -83,6 +83,10 @@ def root_node(self) -> Optional["DocNode"]:
             root = root.parent
         return root or self
 
+    @property
+    def is_root_node(self) -> bool:
+        return (not self.parent)
+
     @property
     def global_metadata(self) -> Dict[str, Any]:
         return self.root_node._global_metadata
diff --git a/lazyllm/tools/rag/milvus_store.py b/lazyllm/tools/rag/milvus_store.py
index 62cb5565..12e90d99 100644
--- a/lazyllm/tools/rag/milvus_store.py
+++ b/lazyllm/tools/rag/milvus_store.py
@@ -280,7 +280,7 @@ def _deserialize_node_partial(self, result: Dict) -> DocNode:
             if k.startswith(self._embedding_key_prefix):
                 doc.embedding[k[len(self._embedding_key_prefix):]] = v
             elif k.startswith(self._global_metadata_key_prefix):
-                if doc.parent:
+                if doc.is_root_node:
                     doc._global_metadata[k[len(self._global_metadata_key_prefix):]] = v
 
         return doc
diff --git a/requirements.full.txt b/requirements.full.txt
index c3beac91..3fdf1b93 100644
--- a/requirements.full.txt
+++ b/requirements.full.txt
@@ -31,7 +31,7 @@ psutil
 pypdf
 pytest
 numpy==1.26.4
-pymilvus
+pymilvus>=2.4.7, <2.5.0
 async-timeout
 httpx<0.28.0
 redis>=5.0.4
diff --git a/requirements.txt b/requirements.txt
index 2544f32d..31a9f0bf 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -31,6 +31,6 @@ psutil
 pypdf
 pytest
 numpy==1.26.4
-pymilvus
+pymilvus>=2.4.7, <2.5.0
 async-timeout
 httpx<0.28.0
diff --git a/scripts/check_requirements.py b/scripts/check_requirements.py
index 980cf05a..09187353 100644
--- a/scripts/check_requirements.py
+++ b/scripts/check_requirements.py
@@ -28,7 +28,7 @@ def parse_requirement(line):
     return None, None
 
 def compare_versions(version_spec, req_version):
-    if version_spec.startswith('^') and req_version == '*':
+    if version_spec.startswith('^'):
         return True
     return version_spec == req_version
 
diff --git a/tests/basic_tests/test_store.py b/tests/basic_tests/test_store.py
index 0d1d2517..240bc7cf 100644
--- a/tests/basic_tests/test_store.py
+++ b/tests/basic_tests/test_store.py
@@ -192,6 +192,9 @@ def setUp(self):
         }
         self.global_metadata_desc = {
             'comment': GlobalMetadataDesc(data_type=GlobalMetadataDesc.DTYPE_VARCHAR, max_size=65535, default_value=' '),
+            'signature': GlobalMetadataDesc(data_type=GlobalMetadataDesc.DTYPE_VARCHAR, max_size=256, default_value=' '),
+            'tags': GlobalMetadataDesc(data_type=GlobalMetadataDesc.DTYPE_ARRAY,
+                                       element_type=GlobalMetadataDesc.DTYPE_INT32, max_size=128, default_value=[]),
         }
 
         self.node_groups = [LAZY_ROOT_NAME, "group1", "group2"]
@@ -278,6 +281,7 @@ def test_query_with_filter_non_exist(self):
 
     def test_reload(self):
         self.store.update_nodes([self.node1, self.node2, self.node3])
+        # reload from storage
         del self.store
         self.store = MilvusStore(group_embed_keys=self.group_embed_keys, embed=self.mock_embed,
                                  embed_dims=self.embed_dims, global_metadata_desc=self.global_metadata_desc,
@@ -291,6 +295,10 @@ def test_reload(self):
             for orig_node in orig_nodes:
                 if node._uid == orig_node._uid:
                     self.assertEqual(node.text, orig_node.text)
+                    # builtin fields are not in orig node, so we can not use
+                    # node.global_metadata == orig_node.global_metadata
+                    for k, v in orig_node.global_metadata.items():
+                        self.assertEqual(node.global_metadata[k], v)
                     break
 
     # XXX `array_contains_any` is not supported in local(aka lite) mode. skip this ut
diff --git a/tests/requirements.txt b/tests/requirements.txt
index 5a33e8db..30955a06 100644
--- a/tests/requirements.txt
+++ b/tests/requirements.txt
@@ -3,4 +3,4 @@ docx2txt
 olefile
 pytest-rerunfailures
 pytest-order
-pymilvus
+pymilvus>=2.4.7, <2.5.0