Skip to content

Commit

Permalink
add chromadb store (#86)
Browse files Browse the repository at this point in the history
  • Loading branch information
yewentao256 authored Jul 23, 2024
1 parent e70f706 commit 1efc667
Show file tree
Hide file tree
Showing 7 changed files with 324 additions and 114 deletions.
12 changes: 6 additions & 6 deletions lazyllm/tools/rag/data_loaders.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from typing import List
from .store import DocNode
from .store import DocNode, LAZY_ROOT_NAME
from lazyllm import LOG


class DirectoryReader:
def __init__(self, input_files: List[str]):
self.input_files = input_files

def load_data(self, ntype: str = "root") -> List["DocNode"]:
def load_data(self, group: str = LAZY_ROOT_NAME) -> List["DocNode"]:
from llama_index.core import SimpleDirectoryReader

llama_index_docs = SimpleDirectoryReader(
Expand All @@ -17,11 +17,11 @@ def load_data(self, ntype: str = "root") -> List["DocNode"]:
for doc in llama_index_docs:
node = DocNode(
text=doc.text,
ntype=ntype,
metadata=doc.metadata,
excluded_embed_metadata_keys=doc.excluded_embed_metadata_keys,
excluded_llm_metadata_keys=doc.excluded_llm_metadata_keys,
group=group,
)
node.metadata = doc.metadata
node.excluded_embed_metadata_keys = doc.excluded_embed_metadata_keys
node.excluded_llm_metadata_keys = doc.excluded_llm_metadata_keys
nodes.append(node)
if not nodes:
LOG.warning(
Expand Down
95 changes: 58 additions & 37 deletions lazyllm/tools/rag/doc_impl.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,50 @@
from functools import partial
import ast
from functools import partial, wraps
from typing import Dict, List, Optional, Set
from lazyllm import ModuleBase, LOG
from lazyllm import ModuleBase, LOG, config, once_flag, call_once
from lazyllm.common import LazyLlmRequest
from .transform import FuncNodeTransform, SentenceSplitter
from .store import MapStore, DocNode
from .store import MapStore, DocNode, ChromadbStore, LAZY_ROOT_NAME
from .data_loaders import DirectoryReader
from .index import DefaultIndex


def embed_wrapper(func):
if not func:
return None

@wraps(func)
def wrapper(*args, **kwargs):
result = func(*args, **kwargs)
return ast.literal_eval(result)

return wrapper


class DocImplV2:
def __init__(self, embed, doc_files=Optional[List[str]], **kwargs):
super().__init__()
self.directory_reader = DirectoryReader(input_files=doc_files)
self.node_groups: Dict[str, Dict] = {}
self.node_groups: Dict[str, Dict] = {LAZY_ROOT_NAME: {}}
self.create_node_group_default()
self.store = MapStore()
self.index = DefaultIndex(embed)
self.embed = embed_wrapper(embed)
self.init_flag = once_flag()

def _lazy_init(self) -> None:
rag_store = config["rag_store"]
if rag_store == "map":
self.store = MapStore(node_groups=self.node_groups.keys())
elif rag_store == "chroma":
self.store = ChromadbStore(
node_groups=self.node_groups.keys(), embed=self.embed
)
else:
raise NotImplementedError(f"Not implemented store type for {rag_store}")
self.index = DefaultIndex(self.embed, self.store)
if not self.store.has_nodes(LAZY_ROOT_NAME):
docs = self.directory_reader.load_data()
self.store.add_nodes(LAZY_ROOT_NAME, docs)
LOG.debug(f"building {LAZY_ROOT_NAME} nodes: {docs}")

def create_node_group_default(self):
self.create_node_group(
Expand All @@ -38,7 +67,7 @@ def create_node_group_default(self):
)

def create_node_group(
self, name, transform, parent="_lazyllm_root", **kwargs
self, name, transform, parent=LAZY_ROOT_NAME, **kwargs
) -> None:
if name in self.node_groups:
LOG.warning(f"Duplicate group name: {name}")
Expand Down Expand Up @@ -67,25 +96,17 @@ def _dynamic_create_nodes(self, group_name) -> None:
if self.store.has_nodes(group_name):
return
transform = self._get_transform(group_name)
parent_name = node_group["parent_name"]
self._dynamic_create_nodes(parent_name)

parent_nodes = self.store.traverse_nodes(parent_name)

sub_nodes = transform(parent_nodes, group_name)
self.store.add_nodes(group_name, sub_nodes)
LOG.debug(f"building {group_name} nodes: {sub_nodes}")
parent_nodes = self._get_nodes(node_group["parent_name"])
nodes = transform(parent_nodes, group_name)
self.store.add_nodes(group_name, nodes)
LOG.debug(f"building {group_name} nodes: {nodes}")

def _get_nodes(self, group_name: str) -> List[DocNode]:
# lazy load files, if group isn't set, create the group
if not self.store.has_nodes("_lazyllm_root"):
docs = self.directory_reader.load_data()
self.store.add_nodes("_lazyllm_root", docs)
LOG.debug(f"building _lazyllm_root nodes: {docs}")
self._dynamic_create_nodes(group_name)
return self.store.traverse_nodes(group_name)

def retrieve(self, query, group_name, similarity, index, topk, similarity_kws):
call_once(self.init_flag, self._lazy_init)
if index:
assert index == "default", "we only support default index currently"
if isinstance(query, LazyLlmRequest):
Expand All @@ -94,10 +115,10 @@ def retrieve(self, query, group_name, similarity, index, topk, similarity_kws):
nodes = self._get_nodes(group_name)
return self.index.query(query, nodes, similarity, topk, **similarity_kws)

def _find_parent(self, nodes: List[DocNode], name: str) -> List[DocNode]:
def _find_parent(self, nodes: List[DocNode], group: str) -> List[DocNode]:
def recurse_parents(node: DocNode, visited: Set[DocNode]) -> None:
if node.parent:
if node.parent.ntype == name:
if node.parent.group == group:
visited.add(node.parent)
recurse_parents(node.parent, visited)

Expand All @@ -106,18 +127,18 @@ def recurse_parents(node: DocNode, visited: Set[DocNode]) -> None:
recurse_parents(node, result)
if not result:
LOG.warning(
f"We can not find any nodes for name `{name}`, please check your input"
f"We can not find any nodes for group `{group}`, please check your input"
)
LOG.debug(f"Found parent node for {name}: {result}")
LOG.debug(f"Found parent node for {group}: {result}")
return list(result)

def find_parent(self, name: str) -> List[DocNode]:
return partial(self._find_parent, name=name)
def find_parent(self, group: str) -> List[DocNode]:
return partial(self._find_parent, group=group)

def _find_children(self, nodes: List[DocNode], name: str) -> List[DocNode]:
def _find_children(self, nodes: List[DocNode], group: str) -> List[DocNode]:
def recurse_children(node: DocNode, visited: Set[DocNode]) -> bool:
if name in node.children:
visited.update(node.children[name])
if group in node.children:
visited.update(node.children[group])
return True

found_in_any_child = False
Expand All @@ -134,11 +155,11 @@ def recurse_children(node: DocNode, visited: Set[DocNode]) -> bool:
result = set()

# case when user hasn't used the group before.
_ = self._get_nodes(name)
_ = self._get_nodes(group)

for node in nodes:
if name in node.children:
result.update(node.children[name])
if group in node.children:
result.update(node.children[group])
else:
LOG.log_once(
f"Fetching children that are not in direct relationship might be slower. "
Expand All @@ -149,21 +170,21 @@ def recurse_children(node: DocNode, visited: Set[DocNode]) -> bool:
# Note: the input nodes are the same type
if not recurse_children(node, result):
LOG.warning(
f"Node {node} and its children do not contain any nodes with the name `{name}`. "
f"Node {node} and its children do not contain any nodes with the group `{group}`. "
"Skipping further search in this branch."
)
break

if not result:
LOG.warning(
f"We cannot find any nodes for name `{name}`, please check your input."
f"We cannot find any nodes for group `{group}`, please check your input."
)

LOG.debug(f"Found children nodes for {name}: {result}")
LOG.debug(f"Found children nodes for {group}: {result}")
return list(result)

def find_children(self, name: str) -> List[DocNode]:
return partial(self._find_children, name=name)
def find_children(self, group: str) -> List[DocNode]:
return partial(self._find_children, group=group)


class RetrieverV2(ModuleBase):
Expand Down
24 changes: 17 additions & 7 deletions lazyllm/tools/rag/index.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import ast
from typing import List, Callable
from .store import DocNode, BaseStore
import numpy as np


Expand All @@ -7,8 +8,9 @@ class DefaultIndex:

registered_similarity = dict()

def __init__(self, embed, **kwargs):
def __init__(self, embed: Callable, store: BaseStore, **kwargs):
self.embed = embed
self.store = store

@classmethod
def register_similarity(cls, func=None, mode=None, descend=True):
Expand All @@ -18,16 +20,24 @@ def decorator(f):

return decorator(func) if func else decorator

def query(self, query, nodes, similarity_name, topk=None, **kwargs):
def query(
self,
query: str,
nodes: List[DocNode],
similarity_name: str,
topk: int,
**kwargs,
) -> List[DocNode]:
similarity_func, mode, descend = self.registered_similarity[similarity_name]

if mode == "embedding":
assert self.embed, "Chosen similarity needs embed model."
assert len(query) > 0, "Query should not be empty."
query_embedding = ast.literal_eval(self.embed(query))
query_embedding = self.embed(query)
for node in nodes:
if not node.embedding:
node.embedding = ast.literal_eval(self.embed(node.text))
if not node.has_embedding():
node.do_embedding(self.embed)
self.store.try_save_nodes(nodes[0].group, nodes)
similarities = [
(node, similarity_func(query_embedding, node.embedding, **kwargs))
for node in nodes
Expand All @@ -46,7 +56,7 @@ def query(self, query, nodes, similarity_name, topk=None, **kwargs):


@DefaultIndex.register_similarity(mode="text", descend=True)
def dummy(query, node, **kwargs):
def dummy(query: str, node, **kwargs):
return len(node.text)


Expand Down
Loading

0 comments on commit 1efc667

Please sign in to comment.