From 3bd468ee84abcde902e1da082fb80b48daedb432 Mon Sep 17 00:00:00 2001 From: yewentao256 Date: Wed, 17 Jul 2024 15:27:09 +0800 Subject: [PATCH 01/16] feature: chromadb --- lazyllm/tools/rag/store.py | 257 ++++++++++++++++++++++++++++++------- 1 file changed, 209 insertions(+), 48 deletions(-) diff --git a/lazyllm/tools/rag/store.py b/lazyllm/tools/rag/store.py index 33075544..037e827c 100644 --- a/lazyllm/tools/rag/store.py +++ b/lazyllm/tools/rag/store.py @@ -1,6 +1,16 @@ +from abc import ABC, abstractmethod +import ast +import atexit from enum import Enum, auto import uuid -from typing import Any, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional +import chromadb +from lazyllm import LOG, config +from chromadb.api.models.Collection import Collection + +LAZY_ROOT_NAME = "lazyllm_root" +config.add("rag_store", str, "map", "RAG_STORE") # "map", "chroma" +config.add("rag_persistent_path", str, "./lazyllm_chroma", "RAG_PERSISTENT_PATH") class MetadataMode(str, Enum): @@ -17,60 +27,79 @@ def __init__( text: Optional[str] = None, ntype: Optional[str] = None, embedding: Optional[List[float]] = None, - metadata: Optional[Dict[str, Any]] = None, - excluded_embed_metadata_keys: Optional[List[str]] = None, - excluded_llm_metadata_keys: Optional[List[str]] = None, parent: Optional["DocNode"] = None, + children: Optional[Dict[str, List]] = None, ) -> None: self.uid: str = uid if uid else str(uuid.uuid4()) self.text: Optional[str] = text self.ntype: Optional[str] = ntype - self.embedding: Optional[List[float]] = embedding - self.metadata: Dict[str, Any] = metadata if metadata is not None else {} + self.embedding: List[float] = embedding or [-1] + self._metadata: Dict[str, Any] = {} # Metadata keys that are excluded from text for the embed model. - self.excluded_embed_metadata_keys: List[str] = ( - excluded_embed_metadata_keys - if excluded_embed_metadata_keys is not None - else [] - ) + self._excluded_embed_metadata_keys: List[str] = [] # Metadata keys that are excluded from text for the LLM. - self.excluded_llm_metadata_keys: List[str] = ( - excluded_llm_metadata_keys if excluded_llm_metadata_keys is not None else [] - ) - # Relationships to other node. + self._excluded_llm_metadata_keys: List[str] = [] self.parent = parent - self.children: Dict[str, List["DocNode"]] = {} + self.children: Dict[str, List["DocNode"]] = children or {} + self.is_saved = False @property def root_node(self) -> Optional["DocNode"]: root = self.parent while root and root.parent: root = root.parent - return root + return root or self + + @property + def metadata(self) -> Dict: + return self.root_node._metadata + + @metadata.setter + def metadata(self, metadata: Dict) -> None: + self._metadata = metadata + + @property + def excluded_embed_metadata_keys(self) -> List: + return self.root_node._excluded_embed_metadata_keys + + @excluded_embed_metadata_keys.setter + def excluded_embed_metadata_keys(self, excluded_embed_metadata_keys: List) -> None: + self._excluded_embed_metadata_keys = excluded_embed_metadata_keys + + @property + def excluded_llm_metadata_keys(self) -> List: + return self.root_node._excluded_llm_metadata_keys + + @excluded_llm_metadata_keys.setter + def excluded_llm_metadata_keys(self, excluded_llm_metadata_keys: List) -> None: + self._excluded_llm_metadata_keys = excluded_llm_metadata_keys + + def get_children_str(self) -> str: + return str( + {key: [node.uid for node in nodes] for key, nodes in self.children.items()} + ) def __str__(self) -> str: - children_str = { - key: [node.uid for node in self.children[key]] - for key in self.children.keys() - } return ( f"DocNode(id: {self.uid}, ntype: {self.ntype}, text: {self.get_content()}) parent: " - f"{self.parent.uid if self.parent else None}, children: {children_str}" + f"{self.parent.uid if self.parent else None}, children: {self.get_children_str()} " + f"is_embed: {self.has_embedding()}" ) def __repr__(self) -> str: return str(self) - def get_embedding(self) -> List[float]: - if self.embedding is None: - raise ValueError("embedding not set.") - return self.embedding + def has_embedding(self) -> bool: + return self.embedding != [-1] + + def do_embedding(self, embed: Callable) -> None: + self.embedding = embed(self.text) + self.is_saved = False def get_content(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str: metadata_str = self.get_metadata_str(mode=metadata_mode).strip() if not metadata_str: return self.text if self.text else "" - return f"{metadata_str}\n\n{self.text}".strip() def get_metadata_str(self, mode: MetadataMode = MetadataMode.ALL) -> str: @@ -94,32 +123,164 @@ def get_text(self) -> str: return self.get_content(metadata_mode=MetadataMode.NONE) -# TODO: Have a common Base store class -class MapStore: - def __init__(self): - self.store: Dict[str, Dict[str, DocNode]] = {} - - def add_nodes(self, category: str, nodes: List[DocNode]): - if category not in self.store: - self.store[category] = {} +class BaseStore(ABC): + def __init__(self, node_groups: List[str]): + self._store: Dict[str, Dict[str, DocNode]] = { + group: {} for group in node_groups + } + def add_nodes(self, group: str, nodes: List[DocNode]): + if group not in self._store: + self._store[group] = {} for node in nodes: - self.store[category][node.uid] = node + self._store[group][node.uid] = node - def has_nodes(self, category: str) -> bool: - return category in self.store.keys() + def has_nodes(self, group: str) -> bool: + return len(self._store[group]) > 0 - def get_node(self, category: str, node_id: str) -> Optional[DocNode]: - return self.store.get(category, {}).get(node_id) + def get_node(self, group: str, node_id: str) -> Optional[DocNode]: + return self._store.get(group, {}).get(node_id) - def delete_node(self, category: str, node_id: str): - if category in self.store and node_id in self.store[category]: - del self.store[category][node_id] - # TODO: delete node's relationship + def traverse_nodes(self, group: str) -> List[DocNode]: + return list(self._store.get(group, {}).values()) - def traverse_nodes(self, category: str) -> List[DocNode]: - return list(self.store.get(category, {}).values()) + def get_node(self, group: str, node_id: str) -> Optional[DocNode]: + return self._store.get(group, {}).get(node_id) + @abstractmethod + def save_store(self) -> None: + raise NotImplementedError("Not implemented yet.") + + @abstractmethod + def try_load_store(self) -> None: + raise NotImplementedError("Not implemented yet.") + + +class MapStore(BaseStore): + def __init__(self, node_groups: List[str], *args, **kwargs): + super().__init__(node_groups, *args, **kwargs) + + def save_store(self) -> None: + pass + + def try_load_store(self) -> None: + pass + + +class ChromadbStore(BaseStore): + def __init__(self, node_groups: List[str], *args, **kwargs) -> None: + super().__init__(node_groups, *args, **kwargs) + self._db_client = chromadb.PersistentClient(path=config["rag_persistent_path"]) + LOG.success(f"Initialzed chromadb in path: {config['rag_persistent_path']}") + self._collections: Dict[str, Collection] = { + group: self._db_client.get_or_create_collection(group) + for group in node_groups + } + self.try_load_store() + atexit.register(self.save_store) + + def try_load_store(self) -> None: + if not self._collections[LAZY_ROOT_NAME].peek(1)["ids"]: + LOG.info("No persistent data found, skip the rebuilding phrase.") + return + + # Restore all nodes + for group in self._collections.keys(): + results = self._peek_all_documents(group) + nodes = self._build_node_from_chroma(results) + self.add_nodes(group, nodes) + + # Rebuild relationships + for group, nodes_dict in self._store.items(): + for node in nodes_dict.values(): + # Set parent + if node.parent: + parent_uid = node.parent + parent_node = self._find_node_by_uid(parent_uid) + node.parent = parent_node + + # Set children + for ntype, child_uids in node.children.items(): + child_nodes = [] + for child_uid in child_uids: + child_node = self._find_node_by_uid(child_uid) + child_nodes.append(child_node) + node.children[ntype] = child_nodes + LOG.debug(f"build {group} nodes from chromadb: {nodes_dict.values()}") + LOG.success("Successfully Built nodes from chromadb.") + + def save_store(self) -> None: + LOG.warning( + "Begin to save nodes to chromadb, please do not exit the program..." + ) + for group_name, collection in self._collections.items(): + nodes = self.traverse_nodes(group_name) + ids, embeddings, metadatas, documents = [], [], [], [] + insert_all = False + coll = collection._client._get_collection(collection.id) + if coll["dimension"] is None: + # empty collection, no insertion before + insert_all = True + elif coll["dimension"] == 1: + # collection with placeholder embedding + if nodes and nodes[0].has_embedding(): + # The current group has been embed, clear collection and reinsert + self._db_client.delete_collection(group_name) + collection = self._db_client.create_collection(group_name) + self._collections[group_name] = collection + insert_all = True + LOG.debug( + "Found placeholder embedding is replaced, reupsert from scratch" + ) + + for node in nodes: + if node.is_saved and not insert_all: + continue + ids.append(node.uid) + embeddings.append(node.embedding) + metadatas.append(self._make_chroma_metadata(node)) + documents.append(node.get_content(metadata_mode=MetadataMode.NONE)) + if ids: + collection.upsert( + embeddings=embeddings, + ids=ids, + metadatas=metadatas, + documents=documents, + ) + LOG.debug(f"Saved {group_name} nodes {nodes} to chromadb") + LOG.success("All nodes saved, exit.") + + def _find_node_by_uid(self, uid: str) -> Optional[DocNode]: + for nodes_by_category in self._store.values(): + if uid in nodes_by_category: + return nodes_by_category[uid] + raise ValueError(f"UID {uid} not found in store.") + + def _build_node_from_chroma(self, results: Dict[str, List]) -> List[DocNode]: + nodes: List[DocNode] = [] + for i, uid in enumerate(results["ids"]): + chroma_metadata = results["metadatas"][i] + node = DocNode( + uid=uid, + text=results["documents"][i], + ntype=chroma_metadata["ntype"], + embedding=results["embeddings"][i], + parent=chroma_metadata["parent"], + children=ast.literal_eval(chroma_metadata["children"]), + ) + node.is_saved = True + nodes.append(node) + return nodes + + def _make_chroma_metadata(self, node: DocNode) -> Dict[str, Any]: + metadata = { + "ntype": node.ntype, + "parent": node.parent.uid if node.parent else "", + "children": node.get_children_str(), + } + return metadata -class ChromadbStore: - pass + def _peek_all_documents(self, group: str) -> Dict[str, List]: + assert group in self._collections, f"group {group} not found." + collection = self._collections[group] + return collection.peek(collection.count()) From 4de0fda990c03dc04f317dae864ccfd272275fb9 Mon Sep 17 00:00:00 2001 From: yewentao256 Date: Wed, 17 Jul 2024 15:28:10 +0800 Subject: [PATCH 02/16] feature: chromadb --- lazyllm/tools/rag/data_loaders.py | 6 ++-- lazyllm/tools/rag/doc_impl.py | 59 ++++++++++++++++++++----------- lazyllm/tools/rag/index.py | 18 +++++++--- lazyllm/tools/rag/transform.py | 4 --- 4 files changed, 55 insertions(+), 32 deletions(-) diff --git a/lazyllm/tools/rag/data_loaders.py b/lazyllm/tools/rag/data_loaders.py index 88125c29..df2cd3c1 100644 --- a/lazyllm/tools/rag/data_loaders.py +++ b/lazyllm/tools/rag/data_loaders.py @@ -18,10 +18,10 @@ def load_data(self, ntype: str = "root") -> List["DocNode"]: node = DocNode( text=doc.text, ntype=ntype, - metadata=doc.metadata, - excluded_embed_metadata_keys=doc.excluded_embed_metadata_keys, - excluded_llm_metadata_keys=doc.excluded_llm_metadata_keys, ) + node.metadata = doc.metadata + node.excluded_embed_metadata_keys = doc.excluded_embed_metadata_keys + node.excluded_llm_metadata_keys = doc.excluded_llm_metadata_keys nodes.append(node) if not nodes: LOG.warning( diff --git a/lazyllm/tools/rag/doc_impl.py b/lazyllm/tools/rag/doc_impl.py index 897e379f..96a7d73c 100644 --- a/lazyllm/tools/rag/doc_impl.py +++ b/lazyllm/tools/rag/doc_impl.py @@ -1,21 +1,48 @@ -from functools import partial +import ast +from functools import partial, wraps from typing import Dict, List, Optional, Set -from lazyllm import ModuleBase, LOG +from lazyllm import ModuleBase, LOG, config, once_flag, call_once from lazyllm.common import LazyLlmRequest from .transform import FuncNodeTransform, SentenceSplitter -from .store import MapStore, DocNode +from .store import MapStore, DocNode, ChromadbStore, LAZY_ROOT_NAME from .data_loaders import DirectoryReader from .index import DefaultIndex +def embed_wrapper(func): + if not func: + return None + + @wraps(func) + def wrapper(*args, **kwargs): + result = func(*args, **kwargs) + return ast.literal_eval(result) + + return wrapper + + class DocImplV2: def __init__(self, embed, doc_files=Optional[List[str]], **kwargs): super().__init__() self.directory_reader = DirectoryReader(input_files=doc_files) - self.node_groups: Dict[str, Dict] = {} + self.node_groups: Dict[str, Dict] = {LAZY_ROOT_NAME: {}} self.create_node_group_default() - self.store = MapStore() - self.index = DefaultIndex(embed) + self.embed = embed_wrapper(embed) + self.init_flag = once_flag() + + def _lazy_init(self) -> None: + rag_store = config["rag_store"] + if rag_store == "map": + self.store = MapStore(node_groups=self.node_groups.keys()) + elif rag_store == "chroma": + self.store = ChromadbStore(node_groups=self.node_groups.keys()) + else: + raise NotImplementedError(f"Not implemented store type for {rag_store}") + self.index = DefaultIndex(self.embed) + if not self.store.has_nodes(LAZY_ROOT_NAME): + docs = self.directory_reader.load_data() + self.store.add_nodes(LAZY_ROOT_NAME, docs) + LOG.debug(f"building {LAZY_ROOT_NAME} nodes: {docs}") def create_node_group_default(self): self.create_node_group( @@ -38,7 +65,7 @@ def create_node_group_default(self): ) def create_node_group( - self, name, transform, parent="_lazyllm_root", **kwargs + self, name, transform, parent=LAZY_ROOT_NAME, **kwargs ) -> None: if name in self.node_groups: LOG.warning(f"Duplicate group name: {name}") @@ -67,25 +94,17 @@ def _dynamic_create_nodes(self, group_name) -> None: if self.store.has_nodes(group_name): return transform = self._get_transform(group_name) - parent_name = node_group["parent_name"] - self._dynamic_create_nodes(parent_name) - - parent_nodes = self.store.traverse_nodes(parent_name) - - sub_nodes = transform(parent_nodes, group_name) - self.store.add_nodes(group_name, sub_nodes) - LOG.debug(f"building {group_name} nodes: {sub_nodes}") + parent_nodes = self._get_nodes(node_group["parent_name"]) + nodes = transform(parent_nodes, group_name) + self.store.add_nodes(group_name, nodes) + LOG.debug(f"building {group_name} nodes: {nodes}") def _get_nodes(self, group_name: str) -> List[DocNode]: - # lazy load files, if group isn't set, create the group - if not self.store.has_nodes("_lazyllm_root"): - docs = self.directory_reader.load_data() - self.store.add_nodes("_lazyllm_root", docs) - LOG.debug(f"building _lazyllm_root nodes: {docs}") self._dynamic_create_nodes(group_name) return self.store.traverse_nodes(group_name) def retrieve(self, query, group_name, similarity, index, topk, similarity_kws): + call_once(self.init_flag, self._lazy_init) if index: assert index == "default", "we only support default index currently" if isinstance(query, LazyLlmRequest): diff --git a/lazyllm/tools/rag/index.py b/lazyllm/tools/rag/index.py index 42262013..8d129d84 100644 --- a/lazyllm/tools/rag/index.py +++ b/lazyllm/tools/rag/index.py @@ -1,4 +1,5 @@ -import ast +from typing import List +from .store import DocNode import numpy as np @@ -18,16 +19,23 @@ def decorator(f): return decorator(func) if func else decorator - def query(self, query, nodes, similarity_name, topk=None, **kwargs): + def query( + self, + query: str, + nodes: List[DocNode], + similarity_name: str, + topk: int, + **kwargs, + ) -> List[DocNode]: similarity_func, mode, descend = self.registered_similarity[similarity_name] if mode == "embedding": assert self.embed, "Chosen similarity needs embed model." assert len(query) > 0, "Query should not be empty." - query_embedding = ast.literal_eval(self.embed(query)) + query_embedding = self.embed(query) for node in nodes: - if not node.embedding: - node.embedding = ast.literal_eval(self.embed(node.text)) + if not node.has_embedding(): + node.do_embedding(self.embed) similarities = [ (node, similarity_func(query_embedding, node.embedding, **kwargs)) for node in nodes diff --git a/lazyllm/tools/rag/transform.py b/lazyllm/tools/rag/transform.py index 0dbd1d30..be35d001 100644 --- a/lazyllm/tools/rag/transform.py +++ b/lazyllm/tools/rag/transform.py @@ -23,10 +23,6 @@ def build_nodes_from_splits( node = DocNode( text=text_chunk, ntype=node_group, - embedding=doc.embedding, - metadata=doc.metadata, - excluded_embed_metadata_keys=doc.excluded_embed_metadata_keys, - excluded_llm_metadata_keys=doc.excluded_llm_metadata_keys, parent=doc, ) nodes.append(node) From eb0992378760d8e410af7efb5ba1d024dd191b56 Mon Sep 17 00:00:00 2001 From: yewentao256 Date: Wed, 17 Jul 2024 15:28:19 +0800 Subject: [PATCH 03/16] update unit test --- tests/basic_tests/test_doc_node.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/tests/basic_tests/test_doc_node.py b/tests/basic_tests/test_doc_node.py index 397e3890..e9873694 100644 --- a/tests/basic_tests/test_doc_node.py +++ b/tests/basic_tests/test_doc_node.py @@ -9,11 +9,11 @@ def setup_method(self): self.embedding = [0.1, 0.2, 0.3] self.node = DocNode( text=self.text, - metadata=self.metadata, embedding=self.embedding, - excluded_embed_metadata_keys=["author"], - excluded_llm_metadata_keys=["date"], ) + self.node.metadata = self.metadata + self.excluded_embed_metadata_keys = ["author"] + self.excluded_llm_metadata_keys = ["date"] def test_node_creation(self): """Test the creation of a DocNode.""" @@ -50,11 +50,6 @@ def test_get_metadata_str(self): metadata_str_none = self.node.get_metadata_str(mode=MetadataMode.NONE) assert metadata_str_none == "" - def test_get_embedding(self): - """Test the get_embedding method.""" - embedding = self.node.get_embedding() - assert embedding == self.embedding - def test_root_node(self): """Test the root_node property.""" child_node = DocNode(text="Child node", parent=self.node) From 7e68c4f88ca9ace5eb9e38958400c79936e87450 Mon Sep 17 00:00:00 2001 From: yewentao256 Date: Wed, 17 Jul 2024 15:32:15 +0800 Subject: [PATCH 04/16] fix: flake8 lint --- lazyllm/tools/rag/store.py | 3 --- tests/basic_tests/test_doc_node.py | 2 +- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/lazyllm/tools/rag/store.py b/lazyllm/tools/rag/store.py index 037e827c..9500ac45 100644 --- a/lazyllm/tools/rag/store.py +++ b/lazyllm/tools/rag/store.py @@ -144,9 +144,6 @@ def get_node(self, group: str, node_id: str) -> Optional[DocNode]: def traverse_nodes(self, group: str) -> List[DocNode]: return list(self._store.get(group, {}).values()) - def get_node(self, group: str, node_id: str) -> Optional[DocNode]: - return self._store.get(group, {}).get(node_id) - @abstractmethod def save_store(self) -> None: raise NotImplementedError("Not implemented yet.") diff --git a/tests/basic_tests/test_doc_node.py b/tests/basic_tests/test_doc_node.py index e9873694..956085a5 100644 --- a/tests/basic_tests/test_doc_node.py +++ b/tests/basic_tests/test_doc_node.py @@ -13,7 +13,7 @@ def setup_method(self): ) self.node.metadata = self.metadata self.excluded_embed_metadata_keys = ["author"] - self.excluded_llm_metadata_keys = ["date"] + self.excluded_llm_metadata_keys = ["date"] def test_node_creation(self): """Test the creation of a DocNode.""" From 6b6e4ad762af8669308db54d289d81a2a11fcfb3 Mon Sep 17 00:00:00 2001 From: yewentao256 Date: Wed, 17 Jul 2024 15:34:58 +0800 Subject: [PATCH 05/16] fix basic test --- tests/basic_tests/test_doc_node.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/basic_tests/test_doc_node.py b/tests/basic_tests/test_doc_node.py index 956085a5..7029c0c7 100644 --- a/tests/basic_tests/test_doc_node.py +++ b/tests/basic_tests/test_doc_node.py @@ -12,8 +12,8 @@ def setup_method(self): embedding=self.embedding, ) self.node.metadata = self.metadata - self.excluded_embed_metadata_keys = ["author"] - self.excluded_llm_metadata_keys = ["date"] + self.node.excluded_embed_metadata_keys = ["author"] + self.node.excluded_llm_metadata_keys = ["date"] def test_node_creation(self): """Test the creation of a DocNode.""" From 70c9a21d789bb54acb441825cb7f96357c929ddd Mon Sep 17 00:00:00 2001 From: yewentao256 Date: Thu, 18 Jul 2024 14:38:31 +0800 Subject: [PATCH 06/16] feature: runtime save nodes without parents --- lazyllm/tools/rag/store.py | 97 +++++++++++++------------------------- 1 file changed, 33 insertions(+), 64 deletions(-) diff --git a/lazyllm/tools/rag/store.py b/lazyllm/tools/rag/store.py index 9500ac45..9931e80e 100644 --- a/lazyllm/tools/rag/store.py +++ b/lazyllm/tools/rag/store.py @@ -1,6 +1,5 @@ from abc import ABC, abstractmethod -import ast -import atexit +from collections import defaultdict from enum import Enum, auto import uuid from typing import Any, Callable, Dict, List, Optional @@ -28,19 +27,18 @@ def __init__( ntype: Optional[str] = None, embedding: Optional[List[float]] = None, parent: Optional["DocNode"] = None, - children: Optional[Dict[str, List]] = None, ) -> None: self.uid: str = uid if uid else str(uuid.uuid4()) self.text: Optional[str] = text self.ntype: Optional[str] = ntype - self.embedding: List[float] = embedding or [-1] + self.embedding: Optional[List[float]] = embedding or None self._metadata: Dict[str, Any] = {} # Metadata keys that are excluded from text for the embed model. self._excluded_embed_metadata_keys: List[str] = [] # Metadata keys that are excluded from text for the LLM. self._excluded_llm_metadata_keys: List[str] = [] self.parent = parent - self.children: Dict[str, List["DocNode"]] = children or {} + self.children: Dict[str, List["DocNode"]] = defaultdict(list) self.is_saved = False @property @@ -90,11 +88,10 @@ def __repr__(self) -> str: return str(self) def has_embedding(self) -> bool: - return self.embedding != [-1] + return self.embedding and len(self.embedding) > 0 def do_embedding(self, embed: Callable) -> None: self.embedding = embed(self.text) - self.is_saved = False def get_content(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str: metadata_str = self.get_metadata_str(mode=metadata_mode).strip() @@ -134,6 +131,7 @@ def add_nodes(self, group: str, nodes: List[DocNode]): self._store[group] = {} for node in nodes: self._store[group][node.uid] = node + self.save_nodes(group) def has_nodes(self, group: str) -> bool: return len(self._store[group]) > 0 @@ -145,7 +143,7 @@ def traverse_nodes(self, group: str) -> List[DocNode]: return list(self._store.get(group, {}).values()) @abstractmethod - def save_store(self) -> None: + def save_nodes(self, group: str) -> None: raise NotImplementedError("Not implemented yet.") @abstractmethod @@ -157,7 +155,7 @@ class MapStore(BaseStore): def __init__(self, node_groups: List[str], *args, **kwargs): super().__init__(node_groups, *args, **kwargs) - def save_store(self) -> None: + def save_nodes(self) -> None: pass def try_load_store(self) -> None: @@ -165,7 +163,7 @@ def try_load_store(self) -> None: class ChromadbStore(BaseStore): - def __init__(self, node_groups: List[str], *args, **kwargs) -> None: + def __init__(self, node_groups: List[str], embed: Callable, *args, **kwargs) -> None: super().__init__(node_groups, *args, **kwargs) self._db_client = chromadb.PersistentClient(path=config["rag_persistent_path"]) LOG.success(f"Initialzed chromadb in path: {config['rag_persistent_path']}") @@ -173,8 +171,8 @@ def __init__(self, node_groups: List[str], *args, **kwargs) -> None: group: self._db_client.get_or_create_collection(group) for group in node_groups } + self.embed = embed self.try_load_store() - atexit.register(self.save_store) def try_load_store(self) -> None: if not self._collections[LAZY_ROOT_NAME].peek(1)["ids"]: @@ -184,68 +182,41 @@ def try_load_store(self) -> None: # Restore all nodes for group in self._collections.keys(): results = self._peek_all_documents(group) - nodes = self._build_node_from_chroma(results) + nodes = self._build_nodes_from_chroma(results) self.add_nodes(group, nodes) # Rebuild relationships for group, nodes_dict in self._store.items(): for node in nodes_dict.values(): - # Set parent if node.parent: parent_uid = node.parent parent_node = self._find_node_by_uid(parent_uid) node.parent = parent_node - - # Set children - for ntype, child_uids in node.children.items(): - child_nodes = [] - for child_uid in child_uids: - child_node = self._find_node_by_uid(child_uid) - child_nodes.append(child_node) - node.children[ntype] = child_nodes + parent_node.children[node.ntype].append(node) LOG.debug(f"build {group} nodes from chromadb: {nodes_dict.values()}") LOG.success("Successfully Built nodes from chromadb.") - def save_store(self) -> None: - LOG.warning( - "Begin to save nodes to chromadb, please do not exit the program..." - ) - for group_name, collection in self._collections.items(): - nodes = self.traverse_nodes(group_name) - ids, embeddings, metadatas, documents = [], [], [], [] - insert_all = False - coll = collection._client._get_collection(collection.id) - if coll["dimension"] is None: - # empty collection, no insertion before - insert_all = True - elif coll["dimension"] == 1: - # collection with placeholder embedding - if nodes and nodes[0].has_embedding(): - # The current group has been embed, clear collection and reinsert - self._db_client.delete_collection(group_name) - collection = self._db_client.create_collection(group_name) - self._collections[group_name] = collection - insert_all = True - LOG.debug( - "Found placeholder embedding is replaced, reupsert from scratch" - ) - - for node in nodes: - if node.is_saved and not insert_all: - continue - ids.append(node.uid) - embeddings.append(node.embedding) - metadatas.append(self._make_chroma_metadata(node)) - documents.append(node.get_content(metadata_mode=MetadataMode.NONE)) - if ids: - collection.upsert( - embeddings=embeddings, - ids=ids, - metadatas=metadatas, - documents=documents, - ) - LOG.debug(f"Saved {group_name} nodes {nodes} to chromadb") - LOG.success("All nodes saved, exit.") + def save_nodes(self, group: str) -> None: + nodes = self.traverse_nodes(group) + ids, embeddings, metadatas, documents = [], [], [], [] + collection = self._collections.get(group) + for node in nodes: + if node.is_saved: + continue + if not node.has_embedding(): + node.do_embedding(self.embed) + ids.append(node.uid) + embeddings.append(node.embedding) + metadatas.append(self._make_chroma_metadata(node)) + documents.append(node.get_content(metadata_mode=MetadataMode.NONE)) + if ids: + collection.upsert( + embeddings=embeddings, + ids=ids, + metadatas=metadatas, + documents=documents, + ) + LOG.debug(f"Saved {group} nodes {ids} to chromadb") def _find_node_by_uid(self, uid: str) -> Optional[DocNode]: for nodes_by_category in self._store.values(): @@ -253,7 +224,7 @@ def _find_node_by_uid(self, uid: str) -> Optional[DocNode]: return nodes_by_category[uid] raise ValueError(f"UID {uid} not found in store.") - def _build_node_from_chroma(self, results: Dict[str, List]) -> List[DocNode]: + def _build_nodes_from_chroma(self, results: Dict[str, List]) -> List[DocNode]: nodes: List[DocNode] = [] for i, uid in enumerate(results["ids"]): chroma_metadata = results["metadatas"][i] @@ -263,7 +234,6 @@ def _build_node_from_chroma(self, results: Dict[str, List]) -> List[DocNode]: ntype=chroma_metadata["ntype"], embedding=results["embeddings"][i], parent=chroma_metadata["parent"], - children=ast.literal_eval(chroma_metadata["children"]), ) node.is_saved = True nodes.append(node) @@ -273,7 +243,6 @@ def _make_chroma_metadata(self, node: DocNode) -> Dict[str, Any]: metadata = { "ntype": node.ntype, "parent": node.parent.uid if node.parent else "", - "children": node.get_children_str(), } return metadata From 83825babb04ad6516c6f3238917d1dd00699d6f3 Mon Sep 17 00:00:00 2001 From: yewentao256 Date: Thu, 18 Jul 2024 14:38:41 +0800 Subject: [PATCH 07/16] adding embed for chromadb --- lazyllm/tools/rag/doc_impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lazyllm/tools/rag/doc_impl.py b/lazyllm/tools/rag/doc_impl.py index 96a7d73c..3b0ed8fa 100644 --- a/lazyllm/tools/rag/doc_impl.py +++ b/lazyllm/tools/rag/doc_impl.py @@ -35,7 +35,7 @@ def _lazy_init(self) -> None: if rag_store == "map": self.store = MapStore(node_groups=self.node_groups.keys()) elif rag_store == "chroma": - self.store = ChromadbStore(node_groups=self.node_groups.keys()) + self.store = ChromadbStore(node_groups=self.node_groups.keys(), embed=self.embed) else: raise NotImplementedError(f"Not implemented store type for {rag_store}") self.index = DefaultIndex(self.embed) From 6cd6a2aad52f41055535afb6182825c2a2bb0abd Mon Sep 17 00:00:00 2001 From: yewentao256 Date: Thu, 18 Jul 2024 15:02:31 +0800 Subject: [PATCH 08/16] fix bugs and lint --- lazyllm/tools/rag/doc_impl.py | 4 +++- lazyllm/tools/rag/store.py | 6 ++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/lazyllm/tools/rag/doc_impl.py b/lazyllm/tools/rag/doc_impl.py index 3b0ed8fa..83a66625 100644 --- a/lazyllm/tools/rag/doc_impl.py +++ b/lazyllm/tools/rag/doc_impl.py @@ -35,7 +35,9 @@ def _lazy_init(self) -> None: if rag_store == "map": self.store = MapStore(node_groups=self.node_groups.keys()) elif rag_store == "chroma": - self.store = ChromadbStore(node_groups=self.node_groups.keys(), embed=self.embed) + self.store = ChromadbStore( + node_groups=self.node_groups.keys(), embed=self.embed + ) else: raise NotImplementedError(f"Not implemented store type for {rag_store}") self.index = DefaultIndex(self.embed) diff --git a/lazyllm/tools/rag/store.py b/lazyllm/tools/rag/store.py index 9931e80e..3cfe314b 100644 --- a/lazyllm/tools/rag/store.py +++ b/lazyllm/tools/rag/store.py @@ -155,7 +155,7 @@ class MapStore(BaseStore): def __init__(self, node_groups: List[str], *args, **kwargs): super().__init__(node_groups, *args, **kwargs) - def save_nodes(self) -> None: + def save_nodes(self, group: str) -> None: pass def try_load_store(self) -> None: @@ -163,7 +163,9 @@ def try_load_store(self) -> None: class ChromadbStore(BaseStore): - def __init__(self, node_groups: List[str], embed: Callable, *args, **kwargs) -> None: + def __init__( + self, node_groups: List[str], embed: Callable, *args, **kwargs + ) -> None: super().__init__(node_groups, *args, **kwargs) self._db_client = chromadb.PersistentClient(path=config["rag_persistent_path"]) LOG.success(f"Initialzed chromadb in path: {config['rag_persistent_path']}") From 4babd7e238351804884f7ac059fd161fe0fc00db Mon Sep 17 00:00:00 2001 From: yewentao256 Date: Thu, 18 Jul 2024 15:02:38 +0800 Subject: [PATCH 09/16] add unit test for store --- tests/basic_tests/test_store.py | 55 +++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 tests/basic_tests/test_store.py diff --git a/tests/basic_tests/test_store.py b/tests/basic_tests/test_store.py new file mode 100644 index 00000000..8b336d6f --- /dev/null +++ b/tests/basic_tests/test_store.py @@ -0,0 +1,55 @@ +import unittest +from unittest.mock import MagicMock +from lazyllm.tools.rag.store import DocNode, ChromadbStore, LAZY_ROOT_NAME + + +# Test class for ChromadbStore +class TestChromadbStore(unittest.TestCase): + def setUp(self): + self.node_groups = [LAZY_ROOT_NAME, "group1", "group2"] + self.embed = MagicMock(side_effect=lambda text: [0.1, 0.2, 0.3]) + self.store = ChromadbStore(self.node_groups, self.embed) + self.store.add_nodes( + LAZY_ROOT_NAME, + [DocNode(uid="1", text="text1", ntype="group1", parent=None)], + ) + + def test_initialization(self): + self.assertEqual(set(self.store._collections.keys()), set(self.node_groups)) + + def test_add_and_traverse_nodes(self): + node1 = DocNode(uid="1", text="text1", ntype="type1") + node2 = DocNode(uid="2", text="text2", ntype="type2") + self.store.add_nodes("group1", [node1, node2]) + nodes = self.store.traverse_nodes("group1") + self.assertEqual(nodes, [node1, node2]) + + def test_save_nodes(self): + node1 = DocNode(uid="1", text="text1", ntype="type1") + node2 = DocNode(uid="2", text="text2", ntype="type2") + self.store.add_nodes("group1", [node1, node2]) + self.store.save_nodes("group1") + collection = self.store._collections["group1"] + self.assertEqual(collection.peek(collection.count())["ids"], ["1", "2"]) + self.assertTrue(node1.has_embedding()) + self.assertTrue(node2.has_embedding()) + + def test_try_load_store(self): + # Set up initial data to be loaded + node1 = DocNode(uid="1", text="text1", ntype="group1", parent=None) + node2 = DocNode(uid="2", text="text2", ntype="group1", parent=node1) + self.store.add_nodes("group1", [node1, node2]) + + # Reset store and load from "persistent" storage + self.store._store = {group: {} for group in self.node_groups} + self.store.try_load_store() + + nodes = self.store.traverse_nodes("group1") + self.assertEqual(len(nodes), 2) + self.assertEqual(nodes[0].uid, "1") + self.assertEqual(nodes[1].uid, "2") + self.assertEqual(nodes[1].parent.uid, "1") + + +if __name__ == "__main__": + unittest.main() From dec3d6675627334b2eb59bde4b658ee70ec4b72e Mon Sep 17 00:00:00 2001 From: yewentao256 Date: Fri, 19 Jul 2024 16:11:13 +0800 Subject: [PATCH 10/16] fix: root name error --- lazyllm/tools/rag/data_loaders.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lazyllm/tools/rag/data_loaders.py b/lazyllm/tools/rag/data_loaders.py index df2cd3c1..6c54171d 100644 --- a/lazyllm/tools/rag/data_loaders.py +++ b/lazyllm/tools/rag/data_loaders.py @@ -1,5 +1,5 @@ from typing import List -from .store import DocNode +from .store import DocNode, LAZY_ROOT_NAME from lazyllm import LOG @@ -7,7 +7,7 @@ class DirectoryReader: def __init__(self, input_files: List[str]): self.input_files = input_files - def load_data(self, ntype: str = "root") -> List["DocNode"]: + def load_data(self, group: str = LAZY_ROOT_NAME) -> List["DocNode"]: from llama_index.core import SimpleDirectoryReader llama_index_docs = SimpleDirectoryReader( @@ -17,7 +17,7 @@ def load_data(self, ntype: str = "root") -> List["DocNode"]: for doc in llama_index_docs: node = DocNode( text=doc.text, - ntype=ntype, + group=group, ) node.metadata = doc.metadata node.excluded_embed_metadata_keys = doc.excluded_embed_metadata_keys From e3727e0d3b6ad3d09982132711e296384179f77e Mon Sep 17 00:00:00 2001 From: yewentao256 Date: Fri, 19 Jul 2024 16:13:46 +0800 Subject: [PATCH 11/16] rename ntype to group --- lazyllm/tools/rag/doc_impl.py | 36 ++++++++++++++++----------------- lazyllm/tools/rag/transform.py | 2 +- tests/basic_tests/test_store.py | 14 ++++++------- 3 files changed, 26 insertions(+), 26 deletions(-) diff --git a/lazyllm/tools/rag/doc_impl.py b/lazyllm/tools/rag/doc_impl.py index 83a66625..eeb0f6f2 100644 --- a/lazyllm/tools/rag/doc_impl.py +++ b/lazyllm/tools/rag/doc_impl.py @@ -40,7 +40,7 @@ def _lazy_init(self) -> None: ) else: raise NotImplementedError(f"Not implemented store type for {rag_store}") - self.index = DefaultIndex(self.embed) + self.index = DefaultIndex(self.embed, self.store) if not self.store.has_nodes(LAZY_ROOT_NAME): docs = self.directory_reader.load_data() self.store.add_nodes(LAZY_ROOT_NAME, docs) @@ -115,10 +115,10 @@ def retrieve(self, query, group_name, similarity, index, topk, similarity_kws): nodes = self._get_nodes(group_name) return self.index.query(query, nodes, similarity, topk, **similarity_kws) - def _find_parent(self, nodes: List[DocNode], name: str) -> List[DocNode]: + def _find_parent(self, nodes: List[DocNode], group: str) -> List[DocNode]: def recurse_parents(node: DocNode, visited: Set[DocNode]) -> None: if node.parent: - if node.parent.ntype == name: + if node.parent.group == group: visited.add(node.parent) recurse_parents(node.parent, visited) @@ -127,18 +127,18 @@ def recurse_parents(node: DocNode, visited: Set[DocNode]) -> None: recurse_parents(node, result) if not result: LOG.warning( - f"We can not find any nodes for name `{name}`, please check your input" + f"We can not find any nodes for group `{group}`, please check your input" ) - LOG.debug(f"Found parent node for {name}: {result}") + LOG.debug(f"Found parent node for {group}: {result}") return list(result) - def find_parent(self, name: str) -> List[DocNode]: - return partial(self._find_parent, name=name) + def find_parent(self, group: str) -> List[DocNode]: + return partial(self._find_parent, group=group) - def _find_children(self, nodes: List[DocNode], name: str) -> List[DocNode]: + def _find_children(self, nodes: List[DocNode], group: str) -> List[DocNode]: def recurse_children(node: DocNode, visited: Set[DocNode]) -> bool: - if name in node.children: - visited.update(node.children[name]) + if group in node.children: + visited.update(node.children[group]) return True found_in_any_child = False @@ -155,11 +155,11 @@ def recurse_children(node: DocNode, visited: Set[DocNode]) -> bool: result = set() # case when user hasn't used the group before. - _ = self._get_nodes(name) + _ = self._get_nodes(group) for node in nodes: - if name in node.children: - result.update(node.children[name]) + if group in node.children: + result.update(node.children[group]) else: LOG.log_once( f"Fetching children that are not in direct relationship might be slower. " @@ -170,21 +170,21 @@ def recurse_children(node: DocNode, visited: Set[DocNode]) -> bool: # Note: the input nodes are the same type if not recurse_children(node, result): LOG.warning( - f"Node {node} and its children do not contain any nodes with the name `{name}`. " + f"Node {node} and its children do not contain any nodes with the group `{group}`. " "Skipping further search in this branch." ) break if not result: LOG.warning( - f"We cannot find any nodes for name `{name}`, please check your input." + f"We cannot find any nodes for group `{group}`, please check your input." ) - LOG.debug(f"Found children nodes for {name}: {result}") + LOG.debug(f"Found children nodes for {group}: {result}") return list(result) - def find_children(self, name: str) -> List[DocNode]: - return partial(self._find_children, name=name) + def find_children(self, group: str) -> List[DocNode]: + return partial(self._find_children, group=group) class RetrieverV2(ModuleBase): diff --git a/lazyllm/tools/rag/transform.py b/lazyllm/tools/rag/transform.py index be35d001..9bf64422 100644 --- a/lazyllm/tools/rag/transform.py +++ b/lazyllm/tools/rag/transform.py @@ -22,7 +22,7 @@ def build_nodes_from_splits( continue node = DocNode( text=text_chunk, - ntype=node_group, + group=node_group, parent=doc, ) nodes.append(node) diff --git a/tests/basic_tests/test_store.py b/tests/basic_tests/test_store.py index 8b336d6f..bfc8d1a2 100644 --- a/tests/basic_tests/test_store.py +++ b/tests/basic_tests/test_store.py @@ -11,22 +11,22 @@ def setUp(self): self.store = ChromadbStore(self.node_groups, self.embed) self.store.add_nodes( LAZY_ROOT_NAME, - [DocNode(uid="1", text="text1", ntype="group1", parent=None)], + [DocNode(uid="1", text="text1", group="group1", parent=None)], ) def test_initialization(self): self.assertEqual(set(self.store._collections.keys()), set(self.node_groups)) def test_add_and_traverse_nodes(self): - node1 = DocNode(uid="1", text="text1", ntype="type1") - node2 = DocNode(uid="2", text="text2", ntype="type2") + node1 = DocNode(uid="1", text="text1", group="type1") + node2 = DocNode(uid="2", text="text2", group="type2") self.store.add_nodes("group1", [node1, node2]) nodes = self.store.traverse_nodes("group1") self.assertEqual(nodes, [node1, node2]) def test_save_nodes(self): - node1 = DocNode(uid="1", text="text1", ntype="type1") - node2 = DocNode(uid="2", text="text2", ntype="type2") + node1 = DocNode(uid="1", text="text1", group="type1") + node2 = DocNode(uid="2", text="text2", group="type2") self.store.add_nodes("group1", [node1, node2]) self.store.save_nodes("group1") collection = self.store._collections["group1"] @@ -36,8 +36,8 @@ def test_save_nodes(self): def test_try_load_store(self): # Set up initial data to be loaded - node1 = DocNode(uid="1", text="text1", ntype="group1", parent=None) - node2 = DocNode(uid="2", text="text2", ntype="group1", parent=node1) + node1 = DocNode(uid="1", text="text1", group="group1", parent=None) + node2 = DocNode(uid="2", text="text2", group="group1", parent=node1) self.store.add_nodes("group1", [node1, node2]) # Reset store and load from "persistent" storage From d30ccbfad988750a169d2e8d1f1c58b3fc2d6816 Mon Sep 17 00:00:00 2001 From: yewentao256 Date: Fri, 19 Jul 2024 16:14:05 +0800 Subject: [PATCH 12/16] placeholder and update in runtime --- lazyllm/tools/rag/index.py | 13 +++++++++---- lazyllm/tools/rag/store.py | 40 ++++++++++++++++++++++---------------- 2 files changed, 32 insertions(+), 21 deletions(-) diff --git a/lazyllm/tools/rag/index.py b/lazyllm/tools/rag/index.py index 8d129d84..e4cbc706 100644 --- a/lazyllm/tools/rag/index.py +++ b/lazyllm/tools/rag/index.py @@ -1,5 +1,5 @@ -from typing import List -from .store import DocNode +from typing import List, Callable +from .store import DocNode, BaseStore import numpy as np @@ -8,8 +8,9 @@ class DefaultIndex: registered_similarity = dict() - def __init__(self, embed, **kwargs): + def __init__(self, embed: Callable, store: BaseStore, **kwargs): self.embed = embed + self.store = store @classmethod def register_similarity(cls, func=None, mode=None, descend=True): @@ -33,9 +34,13 @@ def query( assert self.embed, "Chosen similarity needs embed model." assert len(query) > 0, "Query should not be empty." query_embedding = self.embed(query) + updated_nodes = [] for node in nodes: if not node.has_embedding(): node.do_embedding(self.embed) + updated_nodes.append(node) + if updated_nodes: + self.store.save_nodes(updated_nodes[0].group, updated_nodes) similarities = [ (node, similarity_func(query_embedding, node.embedding, **kwargs)) for node in nodes @@ -54,7 +59,7 @@ def query( @DefaultIndex.register_similarity(mode="text", descend=True) -def dummy(query, node, **kwargs): +def dummy(query: str, node, **kwargs): return len(node.text) diff --git a/lazyllm/tools/rag/store.py b/lazyllm/tools/rag/store.py index 3cfe314b..bc9debfe 100644 --- a/lazyllm/tools/rag/store.py +++ b/lazyllm/tools/rag/store.py @@ -22,15 +22,15 @@ class MetadataMode(str, Enum): class DocNode: def __init__( self, + group: str, uid: Optional[str] = None, text: Optional[str] = None, - ntype: Optional[str] = None, embedding: Optional[List[float]] = None, parent: Optional["DocNode"] = None, ) -> None: self.uid: str = uid if uid else str(uuid.uuid4()) self.text: Optional[str] = text - self.ntype: Optional[str] = ntype + self.group: str = group self.embedding: Optional[List[float]] = embedding or None self._metadata: Dict[str, Any] = {} # Metadata keys that are excluded from text for the embed model. @@ -79,7 +79,7 @@ def get_children_str(self) -> str: def __str__(self) -> str: return ( - f"DocNode(id: {self.uid}, ntype: {self.ntype}, text: {self.get_content()}) parent: " + f"DocNode(id: {self.uid}, group: {self.group}, text: {self.get_content()}) parent: " f"{self.parent.uid if self.parent else None}, children: {self.get_children_str()} " f"is_embed: {self.has_embedding()}" ) @@ -88,10 +88,11 @@ def __repr__(self) -> str: return str(self) def has_embedding(self) -> bool: - return self.embedding and len(self.embedding) > 0 + return self.embedding and self.embedding[0] != -1 # placeholder def do_embedding(self, embed: Callable) -> None: self.embedding = embed(self.text) + self.is_saved = False def get_content(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str: metadata_str = self.get_metadata_str(mode=metadata_mode).strip() @@ -121,17 +122,20 @@ def get_text(self) -> str: class BaseStore(ABC): - def __init__(self, node_groups: List[str]): + def __init__(self, node_groups: List[str]) -> None: self._store: Dict[str, Dict[str, DocNode]] = { group: {} for group in node_groups } - - def add_nodes(self, group: str, nodes: List[DocNode]): + + def _add_nodes(self, group: str, nodes: List[DocNode]) -> None: if group not in self._store: self._store[group] = {} for node in nodes: self._store[group][node.uid] = node - self.save_nodes(group) + + def add_nodes(self, group: str, nodes: List[DocNode]) -> None: + self._add_nodes(group, nodes) + self.save_nodes(group, nodes) def has_nodes(self, group: str) -> bool: return len(self._store[group]) > 0 @@ -143,7 +147,7 @@ def traverse_nodes(self, group: str) -> List[DocNode]: return list(self._store.get(group, {}).values()) @abstractmethod - def save_nodes(self, group: str) -> None: + def save_nodes(self, group: str, nodes: List[DocNode]) -> None: raise NotImplementedError("Not implemented yet.") @abstractmethod @@ -155,7 +159,7 @@ class MapStore(BaseStore): def __init__(self, node_groups: List[str], *args, **kwargs): super().__init__(node_groups, *args, **kwargs) - def save_nodes(self, group: str) -> None: + def save_nodes(self, group: str, nodes: List[DocNode]) -> None: pass def try_load_store(self) -> None: @@ -174,6 +178,7 @@ def __init__( for group in node_groups } self.embed = embed + self.placeholder_length = len(embed("a")) self.try_load_store() def try_load_store(self) -> None: @@ -194,23 +199,24 @@ def try_load_store(self) -> None: parent_uid = node.parent parent_node = self._find_node_by_uid(parent_uid) node.parent = parent_node - parent_node.children[node.ntype].append(node) + parent_node.children[node.group].append(node) LOG.debug(f"build {group} nodes from chromadb: {nodes_dict.values()}") LOG.success("Successfully Built nodes from chromadb.") - def save_nodes(self, group: str) -> None: - nodes = self.traverse_nodes(group) + def save_nodes(self, group: str, nodes: List[DocNode]) -> None: ids, embeddings, metadatas, documents = [], [], [], [] collection = self._collections.get(group) + assert collection, f"Group {group} is not found in collections {self._collections}" for node in nodes: if node.is_saved: continue if not node.has_embedding(): - node.do_embedding(self.embed) + node.embedding = [-1] * self.placeholder_length ids.append(node.uid) embeddings.append(node.embedding) metadatas.append(self._make_chroma_metadata(node)) documents.append(node.get_content(metadata_mode=MetadataMode.NONE)) + node.is_saved = True if ids: collection.upsert( embeddings=embeddings, @@ -218,7 +224,7 @@ def save_nodes(self, group: str) -> None: metadatas=metadatas, documents=documents, ) - LOG.debug(f"Saved {group} nodes {ids} to chromadb") + LOG.debug(f"Saved {group} nodes {ids} to chromadb.") def _find_node_by_uid(self, uid: str) -> Optional[DocNode]: for nodes_by_category in self._store.values(): @@ -233,7 +239,7 @@ def _build_nodes_from_chroma(self, results: Dict[str, List]) -> List[DocNode]: node = DocNode( uid=uid, text=results["documents"][i], - ntype=chroma_metadata["ntype"], + group=chroma_metadata["group"], embedding=results["embeddings"][i], parent=chroma_metadata["parent"], ) @@ -243,7 +249,7 @@ def _build_nodes_from_chroma(self, results: Dict[str, List]) -> List[DocNode]: def _make_chroma_metadata(self, node: DocNode) -> Dict[str, Any]: metadata = { - "ntype": node.ntype, + "group": node.group, "parent": node.parent.uid if node.parent else "", } return metadata From c19ba5e2d3a45ceb409c47e777f1d3f7f0ceb716 Mon Sep 17 00:00:00 2001 From: yewentao256 Date: Fri, 19 Jul 2024 16:15:27 +0800 Subject: [PATCH 13/16] fix lint issue --- lazyllm/tools/rag/store.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/lazyllm/tools/rag/store.py b/lazyllm/tools/rag/store.py index bc9debfe..b34ba11b 100644 --- a/lazyllm/tools/rag/store.py +++ b/lazyllm/tools/rag/store.py @@ -88,7 +88,7 @@ def __repr__(self) -> str: return str(self) def has_embedding(self) -> bool: - return self.embedding and self.embedding[0] != -1 # placeholder + return self.embedding and self.embedding[0] != -1 # placeholder def do_embedding(self, embed: Callable) -> None: self.embedding = embed(self.text) @@ -126,7 +126,7 @@ def __init__(self, node_groups: List[str]) -> None: self._store: Dict[str, Dict[str, DocNode]] = { group: {} for group in node_groups } - + def _add_nodes(self, group: str, nodes: List[DocNode]) -> None: if group not in self._store: self._store[group] = {} @@ -206,7 +206,9 @@ def try_load_store(self) -> None: def save_nodes(self, group: str, nodes: List[DocNode]) -> None: ids, embeddings, metadatas, documents = [], [], [], [] collection = self._collections.get(group) - assert collection, f"Group {group} is not found in collections {self._collections}" + assert ( + collection + ), f"Group {group} is not found in collections {self._collections}" for node in nodes: if node.is_saved: continue From 68147d301146f6d6c7f8cb394507a379eeb13137 Mon Sep 17 00:00:00 2001 From: yewentao256 Date: Fri, 19 Jul 2024 16:29:02 +0800 Subject: [PATCH 14/16] fix: unit test --- lazyllm/tools/rag/store.py | 6 +++--- tests/basic_tests/test_store.py | 3 --- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/lazyllm/tools/rag/store.py b/lazyllm/tools/rag/store.py index b34ba11b..acc14eaa 100644 --- a/lazyllm/tools/rag/store.py +++ b/lazyllm/tools/rag/store.py @@ -22,15 +22,15 @@ class MetadataMode(str, Enum): class DocNode: def __init__( self, - group: str, uid: Optional[str] = None, text: Optional[str] = None, + group: Optional[str] = None, embedding: Optional[List[float]] = None, parent: Optional["DocNode"] = None, ) -> None: self.uid: str = uid if uid else str(uuid.uuid4()) self.text: Optional[str] = text - self.group: str = group + self.group: Optional[str] = group self.embedding: Optional[List[float]] = embedding or None self._metadata: Dict[str, Any] = {} # Metadata keys that are excluded from text for the embed model. @@ -190,7 +190,7 @@ def try_load_store(self) -> None: for group in self._collections.keys(): results = self._peek_all_documents(group) nodes = self._build_nodes_from_chroma(results) - self.add_nodes(group, nodes) + self._add_nodes(group, nodes) # Rebuild relationships for group, nodes_dict in self._store.items(): diff --git a/tests/basic_tests/test_store.py b/tests/basic_tests/test_store.py index bfc8d1a2..52123ae2 100644 --- a/tests/basic_tests/test_store.py +++ b/tests/basic_tests/test_store.py @@ -28,11 +28,8 @@ def test_save_nodes(self): node1 = DocNode(uid="1", text="text1", group="type1") node2 = DocNode(uid="2", text="text2", group="type2") self.store.add_nodes("group1", [node1, node2]) - self.store.save_nodes("group1") collection = self.store._collections["group1"] self.assertEqual(collection.peek(collection.count())["ids"], ["1", "2"]) - self.assertTrue(node1.has_embedding()) - self.assertTrue(node2.has_embedding()) def test_try_load_store(self): # Set up initial data to be loaded From b16e0fa437363259ca18374874135eb9175e92ba Mon Sep 17 00:00:00 2001 From: yewentao256 Date: Fri, 19 Jul 2024 17:45:49 +0800 Subject: [PATCH 15/16] optimize: remove duplicate "for" statement --- lazyllm/tools/rag/index.py | 5 +---- lazyllm/tools/rag/store.py | 9 ++++----- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/lazyllm/tools/rag/index.py b/lazyllm/tools/rag/index.py index e4cbc706..27a333ae 100644 --- a/lazyllm/tools/rag/index.py +++ b/lazyllm/tools/rag/index.py @@ -34,13 +34,10 @@ def query( assert self.embed, "Chosen similarity needs embed model." assert len(query) > 0, "Query should not be empty." query_embedding = self.embed(query) - updated_nodes = [] for node in nodes: if not node.has_embedding(): node.do_embedding(self.embed) - updated_nodes.append(node) - if updated_nodes: - self.store.save_nodes(updated_nodes[0].group, updated_nodes) + self.store.try_save_nodes(nodes[0].group, nodes) similarities = [ (node, similarity_func(query_embedding, node.embedding, **kwargs)) for node in nodes diff --git a/lazyllm/tools/rag/store.py b/lazyllm/tools/rag/store.py index acc14eaa..4824428c 100644 --- a/lazyllm/tools/rag/store.py +++ b/lazyllm/tools/rag/store.py @@ -135,7 +135,7 @@ def _add_nodes(self, group: str, nodes: List[DocNode]) -> None: def add_nodes(self, group: str, nodes: List[DocNode]) -> None: self._add_nodes(group, nodes) - self.save_nodes(group, nodes) + self.try_save_nodes(group, nodes) def has_nodes(self, group: str) -> bool: return len(self._store[group]) > 0 @@ -147,7 +147,7 @@ def traverse_nodes(self, group: str) -> List[DocNode]: return list(self._store.get(group, {}).values()) @abstractmethod - def save_nodes(self, group: str, nodes: List[DocNode]) -> None: + def try_save_nodes(self, group: str, nodes: List[DocNode]) -> None: raise NotImplementedError("Not implemented yet.") @abstractmethod @@ -159,7 +159,7 @@ class MapStore(BaseStore): def __init__(self, node_groups: List[str], *args, **kwargs): super().__init__(node_groups, *args, **kwargs) - def save_nodes(self, group: str, nodes: List[DocNode]) -> None: + def try_save_nodes(self, group: str, nodes: List[DocNode]) -> None: pass def try_load_store(self) -> None: @@ -177,7 +177,6 @@ def __init__( group: self._db_client.get_or_create_collection(group) for group in node_groups } - self.embed = embed self.placeholder_length = len(embed("a")) self.try_load_store() @@ -203,7 +202,7 @@ def try_load_store(self) -> None: LOG.debug(f"build {group} nodes from chromadb: {nodes_dict.values()}") LOG.success("Successfully Built nodes from chromadb.") - def save_nodes(self, group: str, nodes: List[DocNode]) -> None: + def try_save_nodes(self, group: str, nodes: List[DocNode]) -> None: ids, embeddings, metadatas, documents = [], [], [], [] collection = self._collections.get(group) assert ( From 30e7417aef6eb61c71b29d55b7a4ca14aafa54ed Mon Sep 17 00:00:00 2001 From: yewentao256 Date: Mon, 22 Jul 2024 16:30:41 +0800 Subject: [PATCH 16/16] fix: optimize placeholder --- lazyllm/tools/rag/store.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lazyllm/tools/rag/store.py b/lazyllm/tools/rag/store.py index 4824428c..2f4cf3b3 100644 --- a/lazyllm/tools/rag/store.py +++ b/lazyllm/tools/rag/store.py @@ -177,7 +177,7 @@ def __init__( group: self._db_client.get_or_create_collection(group) for group in node_groups } - self.placeholder_length = len(embed("a")) + self._placeholder = [-1] * len(embed("a")) self.try_load_store() def try_load_store(self) -> None: @@ -212,7 +212,7 @@ def try_save_nodes(self, group: str, nodes: List[DocNode]) -> None: if node.is_saved: continue if not node.has_embedding(): - node.embedding = [-1] * self.placeholder_length + node.embedding = self._placeholder ids.append(node.uid) embeddings.append(node.embedding) metadatas.append(self._make_chroma_metadata(node))