From 581d30b216037202d398b09b0f2eca4b51ba557c Mon Sep 17 00:00:00 2001 From: ouguoyu Date: Wed, 6 Nov 2024 10:42:39 +0800 Subject: [PATCH] review3 --- lazyllm/tools/rag/chroma_store.py | 16 +++++++++------- lazyllm/tools/rag/doc_impl.py | 2 +- tests/basic_tests/test_document.py | 1 + tests/basic_tests/test_store.py | 9 ++++----- 4 files changed, 15 insertions(+), 13 deletions(-) diff --git a/lazyllm/tools/rag/chroma_store.py b/lazyllm/tools/rag/chroma_store.py index 9422890f..77770e86 100644 --- a/lazyllm/tools/rag/chroma_store.py +++ b/lazyllm/tools/rag/chroma_store.py @@ -10,6 +10,7 @@ from .default_index import DefaultIndex from .map_store import MapStore import pickle +import base64 # ---------------------------------------------------------------------------- # @@ -23,14 +24,14 @@ def __init__(self, dir: str, node_groups: List[str], embed: Dict[str, Callable], for group in node_groups } + self._map_store = MapStore(node_groups=node_groups, embed=embed) + self._load_store(embed_dim) + self._name2index = { 'default': DefaultIndex(embed, self._map_store), 'file_node_map': _FileNodeIndex(), } - self._map_store = MapStore(node_groups=node_groups, embed=embed) - self._load_store(embed_dim) - @override def update_nodes(self, nodes: List[DocNode]) -> None: self._map_store.update_nodes(nodes) @@ -132,13 +133,14 @@ def _build_nodes_from_chroma(self, results: Dict[str, List], embed_dim: Dict[str chroma_metadata = results['metadatas'][i] parent = chroma_metadata['parent'] - fields = pickle.loads(chroma_metadata['fields']) if parent else None + fields = pickle.loads(base64.b64decode(chroma_metadata['fields'].encode('utf-8')))\ + if parent else None node = DocNode( uid=uid, text=results["documents"][i], group=chroma_metadata["group"], - embedding=pickle.loads(chroma_metadata['embedding']), + embedding=pickle.loads(base64.b64decode(chroma_metadata['embedding'].encode('utf-8'))), parent=parent, fields=fields, ) @@ -167,11 +169,11 @@ def _make_chroma_metadata(self, node: DocNode) -> Dict[str, Any]: metadata = { "group": node.group, "parent": node.parent.uid if node.parent else "", - "embedding": pickle.dumps(node.embedding), + "embedding": base64.b64encode(pickle.dumps(node.embedding)).decode('utf-8'), } if node.parent: - metadata["fields"] = pickle.dumps(node.fields) + metadata["fields"] = base64.b64encode(pickle.dumps(node.fields)).decode('utf-8') return metadata diff --git a/lazyllm/tools/rag/doc_impl.py b/lazyllm/tools/rag/doc_impl.py index 5252ea8c..30015d1b 100644 --- a/lazyllm/tools/rag/doc_impl.py +++ b/lazyllm/tools/rag/doc_impl.py @@ -219,7 +219,7 @@ def _add_files(self, input_files: List[str]): if len(input_files) == 0: return root_nodes = self._reader.load_data(input_files) - temp_store = self._create_store("map") + temp_store = self._create_store({"type": "map"}) temp_store.update_nodes(root_nodes) all_groups = self.store.all_groups() LOG.info(f"add_files: Trying to merge store with {all_groups}") diff --git a/tests/basic_tests/test_document.py b/tests/basic_tests/test_document.py index 11a33585..141df678 100644 --- a/tests/basic_tests/test_document.py +++ b/tests/basic_tests/test_document.py @@ -69,6 +69,7 @@ def test_add_files(self): assert len(self.doc_impl.store.get_nodes(LAZY_ROOT_NAME)) == 2 def test_delete_files(self): + self.doc_impl._lazy_init() self.doc_impl._delete_files(["dummy_file.txt"]) assert len(self.doc_impl.store.get_nodes(LAZY_ROOT_NAME)) == 0 diff --git a/tests/basic_tests/test_store.py b/tests/basic_tests/test_store.py index 67d032ea..ab78bfe6 100644 --- a/tests/basic_tests/test_store.py +++ b/tests/basic_tests/test_store.py @@ -35,9 +35,8 @@ def setUp(self): } self.embed_dim = {"default": 3} - self.store = ChromadbStore(dir=self.store_dir, embed=self.mock_embed, embed_dim=self.embed_dim) - for group in self.node_groups: - self.store.activate_group(name=group, embed_keys=self.mock_embed.keys()) + self.store = ChromadbStore(dir=self.store_dir, node_groups=self.node_groups, + embed=self.mock_embed, embed_dim=self.embed_dim) self.store.update_nodes( [DocNode(uid="1", text="text1", group=LAZY_ROOT_NAME, parent=None)], @@ -75,7 +74,7 @@ def test_load_store(self): # Reset store and load from "persistent" storage self.store._map_store._group2docs = {group: {} for group in self.node_groups} - self.store._load_store() + self.store._load_store(self.embed_dim) nodes = self.store.get_nodes("group1") self.assertEqual(len(nodes), 2) @@ -93,7 +92,7 @@ def test_insert_dict_as_sparse_embedding(self): self.store.update_nodes([node1, node2]) results = self.store._peek_all_documents('group1') - nodes = self.store._build_nodes_from_chroma(results) + nodes = self.store._build_nodes_from_chroma(results, self.embed_dim) nodes_dict = { node.uid: node for node in nodes }