From fab1ae3877e245a158d3bf05b5a662f6d27d90eb Mon Sep 17 00:00:00 2001 From: ouguoyu Date: Mon, 16 Dec 2024 17:09:42 +0800 Subject: [PATCH 1/4] bugfix: modify root node's global metadata --- lazyllm/tools/rag/chroma_store.py | 2 +- lazyllm/tools/rag/doc_node.py | 4 ++++ lazyllm/tools/rag/milvus_store.py | 2 +- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/lazyllm/tools/rag/chroma_store.py b/lazyllm/tools/rag/chroma_store.py index 99772997..5574f6e5 100644 --- a/lazyllm/tools/rag/chroma_store.py +++ b/lazyllm/tools/rag/chroma_store.py @@ -170,7 +170,7 @@ def _make_chroma_metadata(self, node: DocNode) -> Dict[str, Any]: "metadata": obj2str(node._metadata), } - if not node.parent: + if node.is_root_node: metadata["global_metadata"] = obj2str(node.global_metadata) return metadata diff --git a/lazyllm/tools/rag/doc_node.py b/lazyllm/tools/rag/doc_node.py index 97d097af..f56a6450 100644 --- a/lazyllm/tools/rag/doc_node.py +++ b/lazyllm/tools/rag/doc_node.py @@ -83,6 +83,10 @@ def root_node(self) -> Optional["DocNode"]: root = root.parent return root or self + @property + def is_root_node(self) -> bool: + return (not self.parent) + @property def global_metadata(self) -> Dict[str, Any]: return self.root_node._global_metadata diff --git a/lazyllm/tools/rag/milvus_store.py b/lazyllm/tools/rag/milvus_store.py index 62cb5565..12e90d99 100644 --- a/lazyllm/tools/rag/milvus_store.py +++ b/lazyllm/tools/rag/milvus_store.py @@ -280,7 +280,7 @@ def _deserialize_node_partial(self, result: Dict) -> DocNode: if k.startswith(self._embedding_key_prefix): doc.embedding[k[len(self._embedding_key_prefix):]] = v elif k.startswith(self._global_metadata_key_prefix): - if doc.parent: + if doc.is_root_node: doc._global_metadata[k[len(self._global_metadata_key_prefix):]] = v return doc From f037e827add5f795b5a8c6f6d4d35ab387a1312a Mon Sep 17 00:00:00 2001 From: ouguoyu Date: Mon, 16 Dec 2024 18:47:48 +0800 Subject: [PATCH 2/4] add ut --- tests/basic_tests/test_store.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/basic_tests/test_store.py b/tests/basic_tests/test_store.py index 0d1d2517..9ae9c0c7 100644 --- a/tests/basic_tests/test_store.py +++ b/tests/basic_tests/test_store.py @@ -192,6 +192,9 @@ def setUp(self): } self.global_metadata_desc = { 'comment': GlobalMetadataDesc(data_type=GlobalMetadataDesc.DTYPE_VARCHAR, max_size=65535, default_value=' '), + 'signature': GlobalMetadataDesc(data_type=GlobalMetadataDesc.DTYPE_VARCHAR, max_size=256, default_value=' '), + 'tags': GlobalMetadataDesc(data_type=GlobalMetadataDesc.DTYPE_ARRAY, element_type=GlobalMetadataDesc.DTYPE_INT32, + max_size=128), } self.node_groups = [LAZY_ROOT_NAME, "group1", "group2"] @@ -278,6 +281,7 @@ def test_query_with_filter_non_exist(self): def test_reload(self): self.store.update_nodes([self.node1, self.node2, self.node3]) + # reload from storage del self.store self.store = MilvusStore(group_embed_keys=self.group_embed_keys, embed=self.mock_embed, embed_dims=self.embed_dims, global_metadata_desc=self.global_metadata_desc, @@ -291,6 +295,9 @@ def test_reload(self): for orig_node in orig_nodes: if node._uid == orig_node._uid: self.assertEqual(node.text, orig_node.text) + # builtin fields are not in orig node, so we can not use node.global_metadata == orig_node.global_metadata + for k, v in orig_node.global_metadata.items(): + self.assertEqual(node.global_metadata[k], v) break # XXX `array_contains_any` is not supported in local(aka lite) mode. skip this ut From 38da05324058d6fe90e14e21390c4c48286d5abd Mon Sep 17 00:00:00 2001 From: ouguoyu Date: Tue, 17 Dec 2024 10:14:22 +0800 Subject: [PATCH 3/4] test fix --- requirements.full.txt | 2 +- requirements.txt | 2 +- scripts/check_requirements.py | 2 +- tests/basic_tests/test_store.py | 2 +- tests/requirements.txt | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/requirements.full.txt b/requirements.full.txt index c3beac91..3fdf1b93 100644 --- a/requirements.full.txt +++ b/requirements.full.txt @@ -31,7 +31,7 @@ psutil pypdf pytest numpy==1.26.4 -pymilvus +pymilvus>=2.4.7, <2.5.0 async-timeout httpx<0.28.0 redis>=5.0.4 diff --git a/requirements.txt b/requirements.txt index 2544f32d..31a9f0bf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -31,6 +31,6 @@ psutil pypdf pytest numpy==1.26.4 -pymilvus +pymilvus>=2.4.7, <2.5.0 async-timeout httpx<0.28.0 diff --git a/scripts/check_requirements.py b/scripts/check_requirements.py index 980cf05a..09187353 100644 --- a/scripts/check_requirements.py +++ b/scripts/check_requirements.py @@ -28,7 +28,7 @@ def parse_requirement(line): return None, None def compare_versions(version_spec, req_version): - if version_spec.startswith('^') and req_version == '*': + if version_spec.startswith('^'): return True return version_spec == req_version diff --git a/tests/basic_tests/test_store.py b/tests/basic_tests/test_store.py index 9ae9c0c7..f07f22d5 100644 --- a/tests/basic_tests/test_store.py +++ b/tests/basic_tests/test_store.py @@ -194,7 +194,7 @@ def setUp(self): 'comment': GlobalMetadataDesc(data_type=GlobalMetadataDesc.DTYPE_VARCHAR, max_size=65535, default_value=' '), 'signature': GlobalMetadataDesc(data_type=GlobalMetadataDesc.DTYPE_VARCHAR, max_size=256, default_value=' '), 'tags': GlobalMetadataDesc(data_type=GlobalMetadataDesc.DTYPE_ARRAY, element_type=GlobalMetadataDesc.DTYPE_INT32, - max_size=128), + max_size=128, default_value=[]), } self.node_groups = [LAZY_ROOT_NAME, "group1", "group2"] diff --git a/tests/requirements.txt b/tests/requirements.txt index 5a33e8db..30955a06 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -3,4 +3,4 @@ docx2txt olefile pytest-rerunfailures pytest-order -pymilvus +pymilvus>=2.4.7, <2.5.0 From f835a91b140d12d691d5262853b599be7b8784b9 Mon Sep 17 00:00:00 2001 From: ouguoyu Date: Tue, 17 Dec 2024 16:28:19 +0800 Subject: [PATCH 4/4] fix lint --- tests/basic_tests/test_store.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/basic_tests/test_store.py b/tests/basic_tests/test_store.py index f07f22d5..240bc7cf 100644 --- a/tests/basic_tests/test_store.py +++ b/tests/basic_tests/test_store.py @@ -193,8 +193,8 @@ def setUp(self): self.global_metadata_desc = { 'comment': GlobalMetadataDesc(data_type=GlobalMetadataDesc.DTYPE_VARCHAR, max_size=65535, default_value=' '), 'signature': GlobalMetadataDesc(data_type=GlobalMetadataDesc.DTYPE_VARCHAR, max_size=256, default_value=' '), - 'tags': GlobalMetadataDesc(data_type=GlobalMetadataDesc.DTYPE_ARRAY, element_type=GlobalMetadataDesc.DTYPE_INT32, - max_size=128, default_value=[]), + 'tags': GlobalMetadataDesc(data_type=GlobalMetadataDesc.DTYPE_ARRAY, + element_type=GlobalMetadataDesc.DTYPE_INT32, max_size=128, default_value=[]), } self.node_groups = [LAZY_ROOT_NAME, "group1", "group2"] @@ -295,7 +295,8 @@ def test_reload(self): for orig_node in orig_nodes: if node._uid == orig_node._uid: self.assertEqual(node.text, orig_node.text) - # builtin fields are not in orig node, so we can not use node.global_metadata == orig_node.global_metadata + # builtin fields are not in orig node, so we can not use + # node.global_metadata == orig_node.global_metadata for k, v in orig_node.global_metadata.items(): self.assertEqual(node.global_metadata[k], v) break