Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MilvusStore and re-implement MapStore and ChromadbStore, support multi index for one store #322

Merged
merged 66 commits into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
66 commits
Select commit Hold shift + click to select a range
b58369f
add pymilvus pkg
lwj-st Oct 29, 2024
2fa48d0
store and index api breaking changes
ouonline Oct 29, 2024
d4217a6
s
ouonline Oct 29, 2024
827755f
s
ouonline Oct 29, 2024
b8fa14b
s
ouonline Oct 29, 2024
71161f6
s
ouonline Oct 29, 2024
5c30f5f
s
ouonline Oct 29, 2024
6e09bb3
s
ouonline Oct 29, 2024
71f2df2
s
ouonline Oct 29, 2024
7cccf6e
Merge branch 'main' into index-api
ouonline Oct 29, 2024
8befadc
s
ouonline Oct 30, 2024
93e24a4
s
ouonline Oct 30, 2024
803374a
s
ouonline Oct 30, 2024
bf03c15
s
ouonline Oct 30, 2024
138c149
s
ouonline Oct 30, 2024
9eb071d
s
ouonline Oct 31, 2024
debe8fe
s
ouonline Oct 31, 2024
d92379a
s
ouonline Oct 31, 2024
25fd4f9
s
ouonline Oct 31, 2024
2ad88e4
s
ouonline Oct 31, 2024
cd2fa3b
s
ouonline Oct 31, 2024
6ba7f51
s
ouonline Oct 31, 2024
a05996d
s
ouonline Oct 31, 2024
d87cbfd
s
ouonline Oct 31, 2024
1ec44ec
s
ouonline Oct 31, 2024
f7c5327
s
ouonline Oct 31, 2024
8b4066e
s
ouonline Oct 31, 2024
97dc0f0
s
ouonline Nov 1, 2024
1ca6ca8
s
ouonline Nov 1, 2024
1ad9af0
s
ouonline Nov 1, 2024
c2316f4
s
ouonline Nov 1, 2024
efcc5ab
s
ouonline Nov 1, 2024
7d6641a
s
ouonline Nov 1, 2024
a3155c9
s
ouonline Nov 1, 2024
805ae21
s
ouonline Nov 1, 2024
bbec237
s
ouonline Nov 1, 2024
d90333f
s
ouonline Nov 1, 2024
dcac229
s
ouonline Nov 1, 2024
a534624
s
ouonline Nov 4, 2024
aaa7f67
s
ouonline Nov 4, 2024
bfac3a8
s
ouonline Nov 4, 2024
97ac0f0
s
ouonline Nov 4, 2024
0e9f11d
s
ouonline Nov 4, 2024
b532e37
Merge branch 'main' into index-api
ouonline Nov 4, 2024
5b1f8c6
s
ouonline Nov 4, 2024
0731507
s
ouonline Nov 4, 2024
7e2aab2
s
ouonline Nov 4, 2024
91f8d4d
s
ouonline Nov 4, 2024
947ee92
s
ouonline Nov 4, 2024
0edd725
s
ouonline Nov 4, 2024
a6b1abb
s
ouonline Nov 4, 2024
cb04808
s
ouonline Nov 4, 2024
f895488
s
ouonline Nov 4, 2024
a8ac8e5
s
ouonline Nov 4, 2024
ed2a66c
s
ouonline Nov 4, 2024
c6e36da
s
ouonline Nov 4, 2024
5ae2474
s
ouonline Nov 4, 2024
38f873f
s
ouonline Nov 4, 2024
4fa60dc
review begins
ouonline Nov 5, 2024
13a9197
review2
ouonline Nov 5, 2024
91b4c4d
s
ouonline Nov 5, 2024
f0cdd57
s
ouonline Nov 6, 2024
581d30b
review3
ouonline Nov 6, 2024
3f3257f
review4
ouonline Nov 6, 2024
9cc8752
review5
ouonline Nov 7, 2024
b724bfc
review6
ouonline Nov 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 15 additions & 14 deletions lazyllm/tools/rag/chroma_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,28 @@
from .default_index import DefaultIndex
from .map_store import MapStore
import pickle
import base64

# ---------------------------------------------------------------------------- #

class ChromadbStore(StoreBase):
def __init__(self, dir: str, embed: Dict[str, Callable], embed_dim: Dict[str, int], **kwargs) -> None:
def __init__(self, dir: str, node_groups: List[str], embed: Dict[str, Callable],
embed_dim: Dict[str, int], **kwargs) -> None:
self._db_client = chromadb.PersistentClient(path=dir)
LOG.success(f"Initialzed chromadb in path: {dir}")
self._collections: Dict[str, Collection] = {}
self._collections: Dict[str, Collection] = {
group: self._db_client.get_or_create_collection(group)
for group in node_groups
}

self._map_store = MapStore(node_groups=node_groups, embed=embed)
self._load_store(embed_dim)

self._name2index = {
'default': DefaultIndex(embed, self._map_store),
'file_node_map': _FileNodeIndex(),
}

self._map_store = MapStore(embed=embed)
self._load_store(embed_dim)

@override
def update_nodes(self, nodes: List[DocNode]) -> None:
self._map_store.update_nodes(nodes)
Expand All @@ -52,11 +57,6 @@ def is_group_active(self, name: str) -> bool:
def all_groups(self) -> List[str]:
return self._map_store.all_groups()

@override
def add_group(self, name: str, embed_keys: Optional[List[str]] = None) -> None:
self._collections[name] = self._db_client.get_or_create_collection(name)
self._map_store.add_group(name, embed_keys)

@override
def query(self, *args, **kwargs) -> List[DocNode]:
return self.get_index('default').query(*args, **kwargs)
Expand Down Expand Up @@ -133,13 +133,14 @@ def _build_nodes_from_chroma(self, results: Dict[str, List], embed_dim: Dict[str
chroma_metadata = results['metadatas'][i]

parent = chroma_metadata['parent']
fields = pickle.loads(chroma_metadata['fields']) if parent else None
fields = pickle.loads(base64.b64decode(chroma_metadata['fields'].encode('utf-8')))\
if parent else None

node = DocNode(
uid=uid,
text=results["documents"][i],
group=chroma_metadata["group"],
embedding=pickle.loads(chroma_metadata['embedding']),
embedding=pickle.loads(base64.b64decode(chroma_metadata['embedding'].encode('utf-8'))),
parent=parent,
fields=fields,
)
Expand Down Expand Up @@ -168,11 +169,11 @@ def _make_chroma_metadata(self, node: DocNode) -> Dict[str, Any]:
metadata = {
"group": node.group,
"parent": node.parent.uid if node.parent else "",
"embedding": pickle.dumps(node.embedding),
"embedding": base64.b64encode(pickle.dumps(node.embedding)).decode('utf-8'),
}

if node.parent:
metadata["fields"] = pickle.dumps(node.fields)
metadata["fields"] = base64.b64encode(pickle.dumps(node.fields)).decode('utf-8')

return metadata

Expand Down
8 changes: 8 additions & 0 deletions lazyllm/tools/rag/doc_field_desc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from typing import Optional

class DocFieldDesc:
DTYPE_VARCHAR = 0

def __init__(self, data_type: int = DTYPE_VARCHAR, max_length: Optional[int] = 65535):
self.data_type = data_type
self.max_length = max_length
6 changes: 0 additions & 6 deletions lazyllm/tools/rag/doc_field_info.py

This file was deleted.

22 changes: 11 additions & 11 deletions lazyllm/tools/rag/doc_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from .doc_node import DocNode
from .data_loaders import DirectoryReader
from .utils import DocListManager
from .doc_field_info import DocFieldInfo
from .doc_field_desc import DocFieldDesc
import threading
import time

Expand All @@ -38,7 +38,7 @@ class DocImpl:

def __init__(self, embed: Dict[str, Callable], dlm: Optional[DocListManager] = None,
doc_files: Optional[str] = None, kb_group_name: Optional[str] = None,
fields_info: Dict[str, DocFieldInfo] = None, store_conf: Optional[Dict] = None):
fields_desc: Dict[str, DocFieldDesc] = None, store_conf: Optional[Dict] = None):
super().__init__()
assert (dlm is None) ^ (doc_files is None), 'Only one of dataset_path or doc_files should be provided'
self._local_file_reader: Dict[str, Callable] = {}
Expand All @@ -48,8 +48,9 @@ def __init__(self, embed: Dict[str, Callable], dlm: Optional[DocListManager] = N
self.node_groups: Dict[str, Dict] = {LAZY_ROOT_NAME: {}}
self.embed = {k: embed_wrapper(e) for k, e in embed.items()}
self._embed_dim = None
self._fields_info = fields_info
self._fields_desc = fields_desc
self.store = store_conf # NOTE: will be initialized in _lazy_init()
self._activated_embeddings = set()
ouonline marked this conversation as resolved.
Show resolved Hide resolved

@once_wrapper(reset_on_pickle=True)
def _lazy_init(self) -> None:
Expand All @@ -71,7 +72,6 @@ def _lazy_init(self) -> None:
raise ValueError(f'store type [{type(self.store)}] is not a dict.')

if not self.store.is_group_active(LAZY_ROOT_NAME):
self.store.add_group(name=LAZY_ROOT_NAME, fields_info=self._fields_info, embed_keys={})
ids, pathes = self._list_files()
root_nodes = self._reader.load_data(pathes)
self.store.update_nodes(root_nodes)
Expand All @@ -94,11 +94,14 @@ def _create_store(self, store_conf: Optional[Dict]) -> StoreBase:
raise ValueError('`kwargs` in store conf is not a dict.')

if store_type == "map":
store = MapStore(embed=self.embed, **kwargs)
store = MapStore(node_groups=self.node_groups, embed=self.embed, **kwargs)
ouonline marked this conversation as resolved.
Show resolved Hide resolved
elif store_type == "chroma":
store = ChromadbStore(embed_dim=self.embed_dim, embed=self.embed, **kwargs)
store = ChromadbStore(node_groups=self.node_groups, embed=self.embed,
embed_dim=self.embed_dim, **kwargs)
elif store_type == "milvus":
store = MilvusStore(embed=self.embed, fields_info=self._fields_info, **kwargs)
store = MilvusStore(node_groups=self.node_groups, embed=self.embed,
embed_keys=self._activated_embeddings,
fields_desc=self._fields_desc, **kwargs)
else:
raise NotImplementedError(
f"Not implemented store type for {store_type}"
Expand Down Expand Up @@ -216,7 +219,7 @@ def _add_files(self, input_files: List[str]):
if len(input_files) == 0:
return
root_nodes = self._reader.load_data(input_files)
temp_store = self._create_store("map")
temp_store = self._create_store({"type": "map"})
temp_store.update_nodes(root_nodes)
all_groups = self.store.all_groups()
LOG.info(f"add_files: Trying to merge store with {all_groups}")
Expand Down Expand Up @@ -277,9 +280,6 @@ def retrieve(self, query: str, group_name: str, similarity: str, similarity_cut_
index: str, topk: int, similarity_kws: dict, embed_keys: Optional[List[str]] = None) -> List[DocNode]:
self._lazy_init()

if not self.store.is_group_active(group_name):
self.store.add_group(group_name, embed_keys=embed_keys)

if type is None or type == 'default':
return self.store.query(query=query, group_name=group_name, similarity_name=similarity,
similarity_cut_off=similarity_cut_off, topk=topk,
Expand Down
7 changes: 3 additions & 4 deletions lazyllm/tools/rag/doc_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,12 @@ def __init__(self, uid: Optional[str] = None, text: Optional[str] = None, group:
self.parent: Optional["DocNode"] = parent
self.children: Dict[str, List["DocNode"]] = defaultdict(list)
self.is_saved: bool = False
self._docpath = None
self._lock = threading.Lock()
self._embedding_state = set()

if fields and parent:
raise ValueError('only ROOT node can set fields.')
self._fields = fields
self._fields = fields if fields else {}

@property
def root_node(self) -> Optional["DocNode"]:
Expand Down Expand Up @@ -76,12 +75,12 @@ def excluded_llm_metadata_keys(self, excluded_llm_metadata_keys: List) -> None:

@property
def docpath(self) -> str:
return self.root_node._docpath or ''
return self.root_node._fields.get('lazyllm_doc_path', '')

@docpath.setter
def docpath(self, path):
assert not self.parent, 'Only root node can set docpath'
self._docpath = str(path)
self._fields['lazyllm_doc_path'] = str(path)

def get_children_str(self) -> str:
return str(
Expand Down
25 changes: 14 additions & 11 deletions lazyllm/tools/rag/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from .doc_node import DocNode
from .store_base import LAZY_ROOT_NAME, EMBED_DEFAULT_KEY
from .utils import DocListManager
from .doc_field_info import DocFieldInfo
from .doc_field_desc import DocFieldDesc
from .web import DocWebModule
import copy
import functools

Expand All @@ -22,9 +23,9 @@ def __call__(self, cls, *args, **kw):
class Document(ModuleBase):
class _Impl(ModuleBase):
def __init__(self, dataset_path: str, embed: Optional[Union[Callable, Dict[str, Callable]]] = None,
manager: bool = False, server: bool = False, name: Optional[str] = None,
manager: Union[bool, str] = False, server: bool = False, name: Optional[str] = None,
launcher: Optional[Launcher] = None, store_conf: Optional[Dict] = None,
fields_info: Optional[Dict[str, DocFieldInfo]] = None):
fields_desc: Optional[Dict[str, DocFieldDesc]] = None):
super().__init__()
if not os.path.exists(dataset_path):
defatult_path = os.path.join(lazyllm.config["data_path"], dataset_path)
Expand All @@ -39,19 +40,21 @@ def __init__(self, dataset_path: str, embed: Optional[Union[Callable, Dict[str,
self._submodules.append(embed)
self._dlm = DocListManager(dataset_path, name).init_tables()
self._kbs = CallableDict({DocListManager.DEDAULT_GROUP_NAME:
DocImpl(embed=self._embed, dlm=self._dlm, store_conf=store_conf)})
DocImpl(embed=self._embed, dlm=self._dlm, fields_desc=fields_desc,
store_conf=store_conf)})
if manager: self._manager = ServerModule(DocManager(self._dlm))
if manager == 'ui': self._docweb = DocWebModule(doc_server=self._manager)
if server: self._kbs = ServerModule(self._kbs)
self._fields_info = fields_info
self._fields_desc = fields_desc

def add_kb_group(self, name, store_conf: Optional[Dict] = None):
def add_kb_group(self, name, fields_desc: Optional[Dict[str, DocFieldDesc]] = None,
store_conf: Optional[Dict] = None):
if isinstance(self._kbs, ServerModule):
self._kbs._impl._m[name] = DocImpl(dlm=self._dlm, embed=self._embed, kb_group_name=name,
store_conf=store_conf)
fields_desc=fields_desc, store_conf=store_conf)
else:
self._kbs[name] = DocImpl(dlm=self._dlm, embed=self._embed, kb_group_name=name,
store_conf=store_conf)
fields_desc=fields_desc, store_conf=store_conf)
self._dlm.add_kb_group(name)

def get_doc_by_kb_group(self, name):
Expand All @@ -66,14 +69,14 @@ def __call__(self, *args, **kw):
return self._kbs(*args, **kw)

def __init__(self, dataset_path: str, embed: Optional[Union[Callable, Dict[str, Callable]]] = None,
create_ui: bool = False, manager: bool = False, server: bool = False,
create_ui: bool = False, manager: Union[bool, str] = False, server: bool = False,
name: Optional[str] = None, launcher: Optional[Launcher] = None,
fields_info: Dict[str, DocFieldInfo] = None, store_conf: Optional[Dict] = None):
fields_desc: Dict[str, DocFieldDesc] = None, store_conf: Optional[Dict] = None):
super().__init__()
if create_ui:
lazyllm.LOG.warning('`create_ui` for Document is deprecated, use `manager` instead')
self._impls = Document._Impl(dataset_path, embed, create_ui or manager, server, name,
launcher, store_conf, fields_info)
launcher, store_conf, fields_desc)
self._curr_group = DocListManager.DEDAULT_GROUP_NAME

def create_kb_group(self, name: str, store_conf: Optional[Dict] = None) -> "Document":
Expand Down
10 changes: 4 additions & 6 deletions lazyllm/tools/rag/map_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@ def _remove_from_indices(name2index: Dict[str, IndexBase], uids: List[str],
index.remove(uids, group_name)

class MapStore(StoreBase):
def __init__(self, embed: Dict[str, Callable], **kwargs):
def __init__(self, node_groups: List[str], embed: Dict[str, Callable], **kwargs):
# Dict[group_name, Dict[uuid, DocNode]]
self._group2docs: Dict[str, Dict[str, DocNode]] = {}
self._group2docs: Dict[str, Dict[str, DocNode]] = {
group: {} for group in node_groups
}

self._name2index = {
'default': DefaultIndex(embed, self),
Expand Down Expand Up @@ -72,10 +74,6 @@ def is_group_active(self, name: str) -> bool:
def all_groups(self) -> List[str]:
return self._group2docs.keys()

@override
def add_group(self, name: str, embed_keys: Optional[List[str]] = None) -> None:
self._group2docs.setdefault(name, {})

@override
def query(self, *args, **kwargs) -> List[DocNode]:
return self.get_index('default').query(*args, **kwargs)
Expand Down
Loading
Loading