Skip to content

Commit

Permalink
Add MilvusStore and re-implement MapStore and ChromadbStore, support multi index for one store (#322)
Browse files Browse the repository at this point in the history

Co-authored-by: lwj-st <[email protected]>
  • Loading branch information
ouonline and lwj-st authored Nov 8, 2024
1 parent 8654413 commit 5d8a157
Show file tree
Hide file tree
Showing 51 changed files with 1,523 additions and 640 deletions.
2 changes: 1 addition & 1 deletion LazyLLM-Env
Submodule LazyLLM-Env updated 1 files
+169 −1 poetry.lock
70 changes: 70 additions & 0 deletions examples/rag_map_store_with_milvus_index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# -*- coding: utf-8 -*-

import os
import lazyllm
from lazyllm import bind
import tempfile

def run(query):
    """Answer *query* with a RAG pipeline using a map store plus a Milvus
    'smart_embedding_index' backend, then clean up the temporary DB file.

    Args:
        query: the user question to answer.

    Returns:
        The LLM answer produced by the pipeline.
    """
    # Temporary file used as the Milvus-lite storage URI.
    _, store_file = tempfile.mkstemp(suffix=".db")

    # Primary store is in-memory ('map'); Milvus only backs the
    # 'smart_embedding_index'.
    milvus_store_conf = {
        'type': 'map',
        'indices': {
            'smart_embedding_index': {
                'backend': 'milvus',
                'kwargs': {
                    'uri': store_file,
                    'embedding_index_type': 'HNSW',
                    'embedding_metric_type': 'COSINE',
                },
            },
        },
    }

    documents = lazyllm.Document(dataset_path="rag_master",
                                 embed=lazyllm.TrainableModule("bge-large-zh-v1.5"),
                                 manager=False,
                                 store_conf=milvus_store_conf)

    # BUG FIX: the original used `'。'.split(s)`, which splits the literal
    # '。' by the document text and almost always yields ['。']. The intent
    # is to split the document text into sentences on '。'.
    documents.create_node_group(name="sentences",
                                transform=lambda s: s.split('。'))

    prompt = 'You will play the role of an AI Q&A assistant and complete a dialogue task.'\
        ' In this task, you need to provide your answer based on the given context and question.'

    with lazyllm.pipeline() as ppl:
        with lazyllm.parallel().sum as ppl.prl:
            ppl.prl.retriever1 = lazyllm.Retriever(doc=documents,
                                                   group_name="CoarseChunk",
                                                   similarity="bm25_chinese",
                                                   topk=3)
            ppl.prl.retriever2 = lazyllm.Retriever(doc=documents,
                                                   group_name="sentences",
                                                   similarity="cosine",
                                                   topk=3)

        ppl.reranker = lazyllm.Reranker(name='ModuleReranker',
                                        model="bge-reranker-large",
                                        topk=1,
                                        output_format='content',
                                        join=True) | bind(query=ppl.input)

        ppl.formatter = (
            lambda nodes, query: dict(context_str=nodes, query=query)
        ) | bind(query=ppl.input)

        ppl.llm = lazyllm.TrainableModule('internlm2-chat-7b').prompt(
            lazyllm.ChatPrompter(instruction=prompt, extro_keys=['context_str']))

    # Ensure the temp DB file is removed even if the pipeline raises.
    try:
        rag = lazyllm.ActionModule(ppl)
        rag.start()
        res = rag(query)
    finally:
        os.remove(store_file)

    return res

if __name__ == '__main__':
    # Demo entry point: ask a sample question and print the answer.
    answer = run('何为天道?')
    print(f'answer: {answer}')
65 changes: 65 additions & 0 deletions examples/rag_milvus_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# -*- coding: utf-8 -*-

import os
import lazyllm
from lazyllm import bind
import tempfile

def run(query):
    """Answer *query* with a RAG pipeline whose document store is Milvus,
    then clean up the temporary DB file.

    Args:
        query: the user question to answer.

    Returns:
        The LLM answer produced by the pipeline.
    """
    # Temporary file used as the Milvus-lite storage URI.
    _, store_file = tempfile.mkstemp(suffix=".db")

    milvus_store_conf = {
        'type': 'milvus',
        'kwargs': {
            'uri': store_file,
            'embedding_index_type': 'HNSW',
            'embedding_metric_type': 'COSINE',
        },
    }

    documents = lazyllm.Document(dataset_path="rag_master",
                                 embed=lazyllm.TrainableModule("bge-large-zh-v1.5"),
                                 manager=False,
                                 store_conf=milvus_store_conf)

    # BUG FIX: the original used `'。'.split(s)`, which splits the literal
    # '。' by the document text and almost always yields ['。']. The intent
    # is to split the document text into sentences on '。'.
    documents.create_node_group(name="sentences",
                                transform=lambda s: s.split('。'))

    prompt = 'You will play the role of an AI Q&A assistant and complete a dialogue task.'\
        ' In this task, you need to provide your answer based on the given context and question.'

    with lazyllm.pipeline() as ppl:
        with lazyllm.parallel().sum as ppl.prl:
            ppl.prl.retriever1 = lazyllm.Retriever(doc=documents,
                                                   group_name="CoarseChunk",
                                                   similarity="bm25_chinese",
                                                   topk=3)
            ppl.prl.retriever2 = lazyllm.Retriever(doc=documents,
                                                   group_name="sentences",
                                                   similarity="cosine",
                                                   topk=3)

        ppl.reranker = lazyllm.Reranker(name='ModuleReranker',
                                        model="bge-reranker-large",
                                        topk=1,
                                        output_format='content',
                                        join=True) | bind(query=ppl.input)

        ppl.formatter = (
            lambda nodes, query: dict(context_str=nodes, query=query)
        ) | bind(query=ppl.input)

        ppl.llm = lazyllm.TrainableModule('internlm2-chat-7b').prompt(
            lazyllm.ChatPrompter(instruction=prompt, extro_keys=['context_str']))

    # Ensure the temp DB file is removed even if the pipeline raises.
    try:
        rag = lazyllm.ActionModule(ppl)
        rag.start()
        res = rag(query)
    finally:
        os.remove(store_file)

    return res

if __name__ == '__main__':
    # Demo entry point: ask a sample question and print the answer.
    answer = run('何为天道?')
    print(f'answer: {answer}')
3 changes: 2 additions & 1 deletion lazyllm/common/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .registry import LazyLLMRegisterMetaClass, _get_base_cls_from_registry, Register
from .common import package, kwargs, arguments, LazyLLMCMD, timeout, final, ReadOnlyWrapper, DynamicDescriptor
from .common import package, kwargs, arguments, LazyLLMCMD, timeout, final, ReadOnlyWrapper, DynamicDescriptor, override
from .common import FlatList, Identity, ResultCollector, ArgsDict, CaseInsensitiveDict
from .common import ReprRule, make_repr, modify_repr
from .common import once_flag, call_once, once_wrapper, singleton, reset_on_pickle
Expand Down Expand Up @@ -39,6 +39,7 @@
'package',
'kwargs',
'arguments',
'override',

# option
'Option',
Expand Down
6 changes: 6 additions & 0 deletions lazyllm/common/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@
_F = typing.TypeVar("_F", bound=Callable[..., Any])
def final(f: _F) -> _F: return f

# typing.override is only available on Python 3.12+ (PEP 698); on older
# interpreters, fall back to a no-op decorator so `@override` markers used
# throughout the codebase stay importable everywhere.
try:
    from typing import override
except ImportError:
    def override(func: Callable) -> Callable:
        # No-op fallback: return the decorated function unchanged.
        return func


class FlatList(list):
def absorb(self, item):
Expand Down
4 changes: 2 additions & 2 deletions lazyllm/tools/rag/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from .retriever import Retriever
from .rerank import Reranker, register_reranker
from .transform import SentenceSplitter, LLMParser, NodeTransform, TransformArgs, AdaptiveTransform
from .index import register_similarity
from .store import DocNode
from .similarity import register_similarity
from .doc_node import DocNode
from .readers import (PDFReader, DocxReader, HWPReader, PPTXReader, ImageReader, IPYNBReader, EpubReader,
MarkdownReader, MboxReader, PandasCSVReader, PandasExcelReader, VideoAudioReader)
from .dataReader import SimpleDirectoryReader
Expand Down
180 changes: 180 additions & 0 deletions lazyllm/tools/rag/chroma_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
from typing import Any, Dict, List, Optional, Callable, Set
import chromadb
from lazyllm import LOG
from lazyllm.common import override
from chromadb.api.models.Collection import Collection
from .store_base import StoreBase, LAZY_ROOT_NAME
from .doc_node import DocNode
from .index_base import IndexBase
from .utils import _FileNodeIndex
from .default_index import DefaultIndex
from .map_store import MapStore
import pickle
import base64

# ---------------------------------------------------------------------------- #

class ChromadbStore(StoreBase):
    """Node store that persists DocNodes in chromadb while serving all reads
    from an in-memory MapStore mirror.

    Chromadb is used purely for durability: the embeddings written to it are
    placeholders (``[0]``), and retrieval goes through the indices registered
    in ``_name2index`` (a ``DefaultIndex`` over the MapStore by default).
    Node embeddings and extra fields are pickled and base64-encoded into
    chroma metadata, and parent links are stored as uids and re-linked to
    node objects on reload.
    """

    def __init__(self, group_embed_keys: Dict[str, Set[str]], embed: Dict[str, Callable],
                 embed_dims: Dict[str, int], dir: str, **kwargs) -> None:
        self._db_client = chromadb.PersistentClient(path=dir)
        LOG.success(f"Initialized chromadb in path: {dir}")  # fixed typo "Initialzed"
        node_groups = list(group_embed_keys.keys())
        # One chromadb collection per node group.
        self._collections: Dict[str, Collection] = {
            group: self._db_client.get_or_create_collection(group)
            for group in node_groups
        }

        # In-memory mirror that answers all get/query calls.
        self._map_store = MapStore(node_groups=node_groups, embed=embed)
        self._load_store(embed_dims)

        # NOTE(review): nodes restored by _load_store are not registered into
        # 'file_node_map' — confirm whether _FileNodeIndex needs rebuilding here.
        self._name2index = {
            'default': DefaultIndex(embed, self._map_store),
            'file_node_map': _FileNodeIndex(),
        }

    @override
    def update_nodes(self, nodes: List[DocNode]) -> None:
        """Upsert *nodes* into the in-memory mirror and persist them to chromadb."""
        self._map_store.update_nodes(nodes)
        self._save_nodes(nodes)

    @override
    def remove_nodes(self, group_name: str, uids: Optional[List[str]] = None) -> None:
        """Remove nodes from chromadb and the in-memory mirror.

        With explicit *uids*, only those documents are deleted; with no uids,
        the whole chroma collection for *group_name* is dropped.
        """
        if uids:
            self._delete_group_nodes(group_name, uids)
        else:
            self._db_client.delete_collection(name=group_name)
        return self._map_store.remove_nodes(group_name, uids)

    @override
    def get_nodes(self, group_name: str, uids: List[str] = None) -> List[DocNode]:
        """Fetch nodes from the in-memory mirror (chromadb is not consulted)."""
        return self._map_store.get_nodes(group_name, uids)

    @override
    def is_group_active(self, name: str) -> bool:
        """True if the group currently holds any nodes (delegates to MapStore)."""
        return self._map_store.is_group_active(name)

    @override
    def all_groups(self) -> List[str]:
        """Names of all node groups known to this store."""
        return self._map_store.all_groups()

    @override
    def query(self, *args, **kwargs) -> List[DocNode]:
        """Run a retrieval query through the default index."""
        return self.get_index('default').query(*args, **kwargs)

    @override
    def register_index(self, type: str, index: IndexBase) -> None:
        """Attach an extra index under name *type* (overwrites an existing one)."""
        self._name2index[type] = index

    @override
    def get_index(self, type: Optional[str] = None) -> Optional[IndexBase]:
        """Return the index registered as *type* ('default' when omitted)."""
        if type is None:
            type = 'default'
        return self._name2index.get(type)

    def _load_store(self, embed_dims: Dict[str, int]) -> None:
        """Rebuild the in-memory MapStore from data persisted in chromadb."""
        # The root group is populated first on save; an empty root means
        # there is nothing persisted at all.
        if not self._collections[LAZY_ROOT_NAME].peek(1)["ids"]:
            LOG.info("No persistent data found, skip the rebuilding phase.")  # fixed typo "phrase"
            return

        # Restore all nodes
        for group in self._collections.keys():
            results = self._peek_all_documents(group)
            nodes = self._build_nodes_from_chroma(results, embed_dims)
            self._map_store.update_nodes(nodes)

        # Rebuild relationships: parents were stored as uids, turn them back
        # into object links and register children on the parent.
        for group_name in self._map_store.all_groups():
            nodes = self._map_store.get_nodes(group_name)
            for node in nodes:
                if node.parent:
                    parent_uid = node.parent
                    parent_node = self._map_store.find_node_by_uid(parent_uid)
                    node.parent = parent_node
                    parent_node.children[node.group].append(node)
            # BUG FIX: the original logged the stale variable `group` left
            # over from the restore loop above instead of `group_name`.
            LOG.debug(f"build {group_name} nodes from chromadb: {nodes}")
        LOG.success("Successfully Built nodes from chromadb.")

    def _save_nodes(self, nodes: List[DocNode]) -> None:
        """Upsert *nodes* into the chroma collection of their group."""
        if not nodes:
            return
        # Note: It's caller's duty to make sure this batch of nodes has the same group.
        group = nodes[0].group
        ids, embeddings, metadatas, documents = [], [], [], []
        collection = self._collections.get(group)
        assert (
            collection
        ), f"Group {group} is not found in collections {self._collections}"
        for node in nodes:
            metadata = self._make_chroma_metadata(node)
            ids.append(node.uid)
            embeddings.append([0])  # we don't use chroma for retrieving
            metadatas.append(metadata)
            documents.append(node.get_text())
        if ids:
            collection.upsert(
                embeddings=embeddings,
                ids=ids,
                metadatas=metadatas,
                documents=documents,
            )
            LOG.debug(f"Saved {group} nodes {ids} to chromadb.")

    def _delete_group_nodes(self, group_name: str, uids: List[str]) -> None:
        """Delete the given uids from the group's chroma collection, if any."""
        collection = self._collections.get(group_name)
        if collection:
            collection.delete(ids=uids)

    def _build_nodes_from_chroma(self, results: Dict[str, List], embed_dims: Dict[str, int]) -> List[DocNode]:
        """Deserialize a chroma peek/get result dict into DocNode objects.

        ``parent`` is left as a uid string here; _load_store re-links it to
        the actual parent node afterwards. Sparse embeddings ({index: value}
        dicts) are expanded to dense List[float] using *embed_dims*.
        """
        nodes: List[DocNode] = []
        for i, uid in enumerate(results['ids']):
            chroma_metadata = results['metadatas'][i]

            parent = chroma_metadata['parent']
            # 'fields' is only persisted for non-root nodes (see
            # _make_chroma_metadata).
            fields = pickle.loads(base64.b64decode(chroma_metadata['fields'].encode('utf-8')))\
                if parent else None

            node = DocNode(
                uid=uid,
                text=results["documents"][i],
                group=chroma_metadata["group"],
                embedding=pickle.loads(base64.b64decode(chroma_metadata['embedding'].encode('utf-8'))),
                parent=parent,
                fields=fields,
            )

            if node.embedding:
                # convert sparse embedding to List[float]
                new_embedding_dict = {}
                for key, embedding in node.embedding.items():
                    if isinstance(embedding, dict):
                        dim = embed_dims.get(key)
                        if not dim:
                            raise ValueError(f'dim of embed [{key}] is not determined.')
                        new_embedding = [0] * dim
                        for idx, val in embedding.items():
                            new_embedding[int(idx)] = val
                        new_embedding_dict[key] = new_embedding
                    else:
                        new_embedding_dict[key] = embedding
                node.embedding = new_embedding_dict

            nodes.append(node)
        return nodes

    def _make_chroma_metadata(self, node: DocNode) -> Dict[str, Any]:
        """Serialize node bookkeeping (group, parent uid, embedding, fields)
        into a chroma-compatible flat metadata dict."""
        metadata = {
            "group": node.group,
            "parent": node.parent.uid if node.parent else "",
            # chroma metadata values must be scalars, hence pickle + base64.
            "embedding": base64.b64encode(pickle.dumps(node.embedding)).decode('utf-8'),
        }

        if node.parent:
            metadata["fields"] = base64.b64encode(pickle.dumps(node.fields)).decode('utf-8')

        return metadata

    def _peek_all_documents(self, group: str) -> Dict[str, List]:
        """Return every persisted record of *group* via a full-size peek."""
        assert group in self._collections, f"group {group} not found."
        collection = self._collections[group]
        return collection.peek(collection.count())
2 changes: 1 addition & 1 deletion lazyllm/tools/rag/component/bm25.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import List, Tuple
from ..store import DocNode
from ..doc_node import DocNode
import bm25s
import Stemmer
from lazyllm.thirdparty import jieba
Expand Down
2 changes: 1 addition & 1 deletion lazyllm/tools/rag/dataReader.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from pathlib import Path, PurePosixPath, PurePath
from fsspec import AbstractFileSystem
from lazyllm import ModuleBase, LOG
from .store import DocNode
from .doc_node import DocNode
from .readers import (ReaderBase, PDFReader, DocxReader, HWPReader, PPTXReader, ImageReader, IPYNBReader,
EpubReader, MarkdownReader, MboxReader, PandasCSVReader, PandasExcelReader, VideoAudioReader,
get_default_fs, is_default_fs)
Expand Down
3 changes: 2 additions & 1 deletion lazyllm/tools/rag/data_loaders.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import List, Optional, Dict
from .store import DocNode, LAZY_ROOT_NAME
from .doc_node import DocNode
from .store_base import LAZY_ROOT_NAME
from lazyllm import LOG
from .dataReader import SimpleDirectoryReader

Expand Down
Loading

0 comments on commit 5d8a157

Please sign in to comment.