review3
ouonline committed Nov 6, 2024
1 parent f0cdd57 commit 581d30b
Showing 4 changed files with 15 additions and 13 deletions.
16 changes: 9 additions & 7 deletions lazyllm/tools/rag/chroma_store.py
@@ -10,6 +10,7 @@
from .default_index import DefaultIndex
from .map_store import MapStore
import pickle
+import base64

# ---------------------------------------------------------------------------- #

@@ -23,14 +24,14 @@ def __init__(self, dir: str, node_groups: List[str], embed: Dict[str, Callable],
for group in node_groups
}

+self._map_store = MapStore(node_groups=node_groups, embed=embed)
+self._load_store(embed_dim)

self._name2index = {
'default': DefaultIndex(embed, self._map_store),
'file_node_map': _FileNodeIndex(),
}

-self._map_store = MapStore(node_groups=node_groups, embed=embed)
-self._load_store(embed_dim)

@override
def update_nodes(self, nodes: List[DocNode]) -> None:
self._map_store.update_nodes(nodes)
@@ -132,13 +133,14 @@ def _build_nodes_from_chroma(self, results: Dict[str, List], embed_dim: Dict[str
chroma_metadata = results['metadatas'][i]

parent = chroma_metadata['parent']
-fields = pickle.loads(chroma_metadata['fields']) if parent else None
+fields = pickle.loads(base64.b64decode(chroma_metadata['fields'].encode('utf-8')))\
+    if parent else None

node = DocNode(
uid=uid,
text=results["documents"][i],
group=chroma_metadata["group"],
-embedding=pickle.loads(chroma_metadata['embedding']),
+embedding=pickle.loads(base64.b64decode(chroma_metadata['embedding'].encode('utf-8'))),
parent=parent,
fields=fields,
)
@@ -167,11 +169,11 @@ def _make_chroma_metadata(self, node: DocNode) -> Dict[str, Any]:
metadata = {
"group": node.group,
"parent": node.parent.uid if node.parent else "",
"embedding": pickle.dumps(node.embedding),
"embedding": base64.b64encode(pickle.dumps(node.embedding)).decode('utf-8'),
}

if node.parent:
metadata["fields"] = pickle.dumps(node.fields)
metadata["fields"] = base64.b64encode(pickle.dumps(node.fields)).decode('utf-8')

return metadata

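For context on the base64 wrapping above: Chroma metadata values are generally limited to plain scalar types such as strings, so the commit base64-encodes the pickled bytes before storing them and reverses the transformation when nodes are rebuilt. A minimal round-trip sketch of the idea (the helper names here are illustrative, not part of the codebase):

import base64
import pickle
from typing import Any

def to_metadata_str(value: Any) -> str:
    # pickle the object, then base64-encode so it can live in a string-only metadata field
    return base64.b64encode(pickle.dumps(value)).decode('utf-8')

def from_metadata_str(text: str) -> Any:
    # reverse transformation used when reading nodes back from the store
    return pickle.loads(base64.b64decode(text.encode('utf-8')))

embedding = {"default": [0.1, 0.2, 0.3]}
stored = to_metadata_str(embedding)           # safe to place in a metadata dict
assert from_metadata_str(stored) == embedding
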
2 changes: 1 addition & 1 deletion lazyllm/tools/rag/doc_impl.py
@@ -219,7 +219,7 @@ def _add_files(self, input_files: List[str]):
if len(input_files) == 0:
return
root_nodes = self._reader.load_data(input_files)
temp_store = self._create_store("map")
temp_store = self._create_store({"type": "map"})
temp_store.update_nodes(root_nodes)
all_groups = self.store.all_groups()
LOG.info(f"add_files: Trying to merge store with {all_groups}")
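
The _create_store call now receives a configuration dict rather than a bare type name, which leaves room for store-specific options alongside the type. A minimal, hypothetical dispatch sketch in that spirit (create_store, register_store, and InMemoryStore are illustrative names, not part of LazyLLM; the real _create_store may differ):

from typing import Any, Callable, Dict

STORE_REGISTRY: Dict[str, Callable[..., Any]] = {}

def register_store(name: str) -> Callable[[type], type]:
    # decorator mapping a "type" string to a store class
    def deco(cls: type) -> type:
        STORE_REGISTRY[name] = cls
        return cls
    return deco

@register_store("map")
class InMemoryStore:
    def __init__(self, **options: Any) -> None:
        self.options = options  # e.g. node_groups, embed

def create_store(config: Dict[str, Any]) -> Any:
    kind = config["type"]
    options = {k: v for k, v in config.items() if k != "type"}
    return STORE_REGISTRY[kind](**options)

store = create_store({"type": "map"})  # same shape as the call in _add_files
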
1 change: 1 addition & 0 deletions tests/basic_tests/test_document.py
@@ -69,6 +69,7 @@ def test_add_files(self):
assert len(self.doc_impl.store.get_nodes(LAZY_ROOT_NAME)) == 2

def test_delete_files(self):
+self.doc_impl._lazy_init()
self.doc_impl._delete_files(["dummy_file.txt"])
assert len(self.doc_impl.store.get_nodes(LAZY_ROOT_NAME)) == 0

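The added _lazy_init() call suggests that _delete_files assumes the document store has already been initialized, so the test now triggers initialization explicitly. A generic guard of this shape illustrates the pattern (a sketch only, not the actual DocImpl implementation):

class LazyStoreOwner:
    def __init__(self) -> None:
        self._initialized = False
        self.store: dict = {}

    def _lazy_init(self) -> None:
        # run expensive setup (loading files, building indices) only once, on first use
        if self._initialized:
            return
        self.store = {"dummy_file.txt": ["node-1", "node-2"]}
        self._initialized = True

    def _delete_files(self, files: list) -> None:
        # assumes _lazy_init() has already run
        for name in files:
            self.store.pop(name, None)
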
9 changes: 4 additions & 5 deletions tests/basic_tests/test_store.py
@@ -35,9 +35,8 @@ def setUp(self):
}
self.embed_dim = {"default": 3}

-self.store = ChromadbStore(dir=self.store_dir, embed=self.mock_embed, embed_dim=self.embed_dim)
-for group in self.node_groups:
-    self.store.activate_group(name=group, embed_keys=self.mock_embed.keys())
+self.store = ChromadbStore(dir=self.store_dir, node_groups=self.node_groups,
+                           embed=self.mock_embed, embed_dim=self.embed_dim)

self.store.update_nodes(
[DocNode(uid="1", text="text1", group=LAZY_ROOT_NAME, parent=None)],
@@ -75,7 +74,7 @@ def test_load_store(self):

# Reset store and load from "persistent" storage
self.store._map_store._group2docs = {group: {} for group in self.node_groups}
-self.store._load_store()
+self.store._load_store(self.embed_dim)

nodes = self.store.get_nodes("group1")
self.assertEqual(len(nodes), 2)
@@ -93,7 +92,7 @@ def test_insert_dict_as_sparse_embedding(self):
self.store.update_nodes([node1, node2])

results = self.store._peek_all_documents('group1')
-nodes = self.store._build_nodes_from_chroma(results)
+nodes = self.store._build_nodes_from_chroma(results, self.embed_dim)
nodes_dict = {
node.uid: node for node in nodes
}
