Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

add chromadb store #86

Merged
merged 16 commits into from
Jul 23, 2024
12 changes: 6 additions & 6 deletions lazyllm/tools/rag/data_loaders.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from typing import List
from .store import DocNode
from .store import DocNode, LAZY_ROOT_NAME
from lazyllm import LOG


class DirectoryReader:
def __init__(self, input_files: List[str]):
self.input_files = input_files

def load_data(self, ntype: str = "root") -> List["DocNode"]:
def load_data(self, group: str = LAZY_ROOT_NAME) -> List["DocNode"]:
from llama_index.core import SimpleDirectoryReader

llama_index_docs = SimpleDirectoryReader(
Expand All @@ -17,11 +17,11 @@ def load_data(self, ntype: str = "root") -> List["DocNode"]:
for doc in llama_index_docs:
node = DocNode(
text=doc.text,
ntype=ntype,
metadata=doc.metadata,
excluded_embed_metadata_keys=doc.excluded_embed_metadata_keys,
excluded_llm_metadata_keys=doc.excluded_llm_metadata_keys,
group=group,
)
node.metadata = doc.metadata
node.excluded_embed_metadata_keys = doc.excluded_embed_metadata_keys
node.excluded_llm_metadata_keys = doc.excluded_llm_metadata_keys
nodes.append(node)
if not nodes:
LOG.warning(
Expand Down
95 changes: 58 additions & 37 deletions lazyllm/tools/rag/doc_impl.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,50 @@
from functools import partial
import ast
from functools import partial, wraps
from typing import Dict, List, Optional, Set
from lazyllm import ModuleBase, LOG
from lazyllm import ModuleBase, LOG, config, once_flag, call_once
from lazyllm.common import LazyLlmRequest
from .transform import FuncNodeTransform, SentenceSplitter
from .store import MapStore, DocNode
from .store import MapStore, DocNode, ChromadbStore, LAZY_ROOT_NAME
from .data_loaders import DirectoryReader
from .index import DefaultIndex


def embed_wrapper(func):
if not func:
return None

@wraps(func)
def wrapper(*args, **kwargs):
result = func(*args, **kwargs)
return ast.literal_eval(result)

return wrapper


class DocImplV2:
def __init__(self, embed, doc_files=Optional[List[str]], **kwargs):
super().__init__()
self.directory_reader = DirectoryReader(input_files=doc_files)
self.node_groups: Dict[str, Dict] = {}
self.node_groups: Dict[str, Dict] = {LAZY_ROOT_NAME: {}}
self.create_node_group_default()
self.store = MapStore()
self.index = DefaultIndex(embed)
self.embed = embed_wrapper(embed)
wzh1994 marked this conversation as resolved.
Show resolved Hide resolved
self.init_flag = once_flag()

def _lazy_init(self) -> None:
rag_store = config["rag_store"]
if rag_store == "map":
self.store = MapStore(node_groups=self.node_groups.keys())
elif rag_store == "chroma":
self.store = ChromadbStore(
node_groups=self.node_groups.keys(), embed=self.embed
)
else:
raise NotImplementedError(f"Not implemented store type for {rag_store}")
self.index = DefaultIndex(self.embed, self.store)
if not self.store.has_nodes(LAZY_ROOT_NAME):
docs = self.directory_reader.load_data()
self.store.add_nodes(LAZY_ROOT_NAME, docs)
LOG.debug(f"building {LAZY_ROOT_NAME} nodes: {docs}")

def create_node_group_default(self):
self.create_node_group(
Expand All @@ -38,7 +67,7 @@ def create_node_group_default(self):
)

def create_node_group(
self, name, transform, parent="_lazyllm_root", **kwargs
self, name, transform, parent=LAZY_ROOT_NAME, **kwargs
) -> None:
if name in self.node_groups:
LOG.warning(f"Duplicate group name: {name}")
Expand Down Expand Up @@ -67,25 +96,17 @@ def _dynamic_create_nodes(self, group_name) -> None:
if self.store.has_nodes(group_name):
return
transform = self._get_transform(group_name)
parent_name = node_group["parent_name"]
self._dynamic_create_nodes(parent_name)

parent_nodes = self.store.traverse_nodes(parent_name)

sub_nodes = transform(parent_nodes, group_name)
self.store.add_nodes(group_name, sub_nodes)
LOG.debug(f"building {group_name} nodes: {sub_nodes}")
parent_nodes = self._get_nodes(node_group["parent_name"])
nodes = transform(parent_nodes, group_name)
self.store.add_nodes(group_name, nodes)
LOG.debug(f"building {group_name} nodes: {nodes}")

def _get_nodes(self, group_name: str) -> List[DocNode]:
# lazy load files, if group isn't set, create the group
if not self.store.has_nodes("_lazyllm_root"):
docs = self.directory_reader.load_data()
self.store.add_nodes("_lazyllm_root", docs)
LOG.debug(f"building _lazyllm_root nodes: {docs}")
self._dynamic_create_nodes(group_name)
return self.store.traverse_nodes(group_name)

def retrieve(self, query, group_name, similarity, index, topk, similarity_kws):
call_once(self.init_flag, self._lazy_init)
if index:
assert index == "default", "we only support default index currently"
if isinstance(query, LazyLlmRequest):
Expand All @@ -94,10 +115,10 @@ def retrieve(self, query, group_name, similarity, index, topk, similarity_kws):
nodes = self._get_nodes(group_name)
return self.index.query(query, nodes, similarity, topk, **similarity_kws)

def _find_parent(self, nodes: List[DocNode], name: str) -> List[DocNode]:
def _find_parent(self, nodes: List[DocNode], group: str) -> List[DocNode]:
def recurse_parents(node: DocNode, visited: Set[DocNode]) -> None:
if node.parent:
if node.parent.ntype == name:
if node.parent.group == group:
visited.add(node.parent)
recurse_parents(node.parent, visited)

Expand All @@ -106,18 +127,18 @@ def recurse_parents(node: DocNode, visited: Set[DocNode]) -> None:
recurse_parents(node, result)
if not result:
LOG.warning(
f"We can not find any nodes for name `{name}`, please check your input"
f"We can not find any nodes for group `{group}`, please check your input"
)
LOG.debug(f"Found parent node for {name}: {result}")
LOG.debug(f"Found parent node for {group}: {result}")
return list(result)

def find_parent(self, name: str) -> List[DocNode]:
return partial(self._find_parent, name=name)
def find_parent(self, group: str) -> List[DocNode]:
return partial(self._find_parent, group=group)

def _find_children(self, nodes: List[DocNode], name: str) -> List[DocNode]:
def _find_children(self, nodes: List[DocNode], group: str) -> List[DocNode]:
def recurse_children(node: DocNode, visited: Set[DocNode]) -> bool:
if name in node.children:
visited.update(node.children[name])
if group in node.children:
visited.update(node.children[group])
return True

found_in_any_child = False
Expand All @@ -134,11 +155,11 @@ def recurse_children(node: DocNode, visited: Set[DocNode]) -> bool:
result = set()

# case when user hasn't used the group before.
_ = self._get_nodes(name)
_ = self._get_nodes(group)

for node in nodes:
if name in node.children:
result.update(node.children[name])
if group in node.children:
result.update(node.children[group])
else:
LOG.log_once(
f"Fetching children that are not in direct relationship might be slower. "
Expand All @@ -149,21 +170,21 @@ def recurse_children(node: DocNode, visited: Set[DocNode]) -> bool:
# Note: the input nodes are the same type
if not recurse_children(node, result):
LOG.warning(
f"Node {node} and its children do not contain any nodes with the name `{name}`. "
f"Node {node} and its children do not contain any nodes with the group `{group}`. "
"Skipping further search in this branch."
)
break

if not result:
LOG.warning(
f"We cannot find any nodes for name `{name}`, please check your input."
f"We cannot find any nodes for group `{group}`, please check your input."
)

LOG.debug(f"Found children nodes for {name}: {result}")
LOG.debug(f"Found children nodes for {group}: {result}")
return list(result)

def find_children(self, name: str) -> List[DocNode]:
return partial(self._find_children, name=name)
def find_children(self, group: str) -> List[DocNode]:
return partial(self._find_children, group=group)


class RetrieverV2(ModuleBase):
Expand Down
24 changes: 17 additions & 7 deletions lazyllm/tools/rag/index.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import ast
from typing import List, Callable
from .store import DocNode, BaseStore
import numpy as np


Expand All @@ -7,8 +8,9 @@ class DefaultIndex:

registered_similarity = dict()

def __init__(self, embed, **kwargs):
def __init__(self, embed: Callable, store: BaseStore, **kwargs):
self.embed = embed
self.store = store

@classmethod
def register_similarity(cls, func=None, mode=None, descend=True):
Expand All @@ -18,16 +20,24 @@ def decorator(f):

return decorator(func) if func else decorator

def query(self, query, nodes, similarity_name, topk=None, **kwargs):
def query(
self,
query: str,
nodes: List[DocNode],
similarity_name: str,
topk: int,
**kwargs,
) -> List[DocNode]:
similarity_func, mode, descend = self.registered_similarity[similarity_name]

if mode == "embedding":
assert self.embed, "Chosen similarity needs embed model."
assert len(query) > 0, "Query should not be empty."
query_embedding = ast.literal_eval(self.embed(query))
query_embedding = self.embed(query)
for node in nodes:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里没有一个异步或并行化的处理么

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

异步或者并行我理解是性能方面的提升了,需要的话可以加一个,不过想着不影响功能就先没加和测试

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我这边记下了,P0做完后还有空间的话就加上

if not node.embedding:
node.embedding = ast.literal_eval(self.embed(node.text))
if not node.has_embedding():
node.do_embedding(self.embed)
self.store.try_save_nodes(nodes[0].group, nodes)
similarities = [
(node, similarity_func(query_embedding, node.embedding, **kwargs))
for node in nodes
Expand All @@ -46,7 +56,7 @@ def query(self, query, nodes, similarity_name, topk=None, **kwargs):


@DefaultIndex.register_similarity(mode="text", descend=True)
def dummy(query, node, **kwargs):
def dummy(query: str, node, **kwargs):
return len(node.text)


Expand Down
Loading