From b58369f087f3e1fca23231ba7f4ed8c607c0cb81 Mon Sep 17 00:00:00 2001 From: lwj-st Date: Tue, 29 Oct 2024 10:02:50 +0800 Subject: [PATCH 01/60] add pymilvus pkg --- LazyLLM-Env | 2 +- pyproject.toml | 1 + requirements.full.txt | 1 + requirements.txt | 3 ++- 4 files changed, 5 insertions(+), 2 deletions(-) diff --git a/LazyLLM-Env b/LazyLLM-Env index 514abd05..80b13f6a 160000 --- a/LazyLLM-Env +++ b/LazyLLM-Env @@ -1 +1 @@ -Subproject commit 514abd053325757cb0fc650adcc4645c848d654f +Subproject commit 80b13f6a8eb049e3712b6d53350da54f4c9286b5 diff --git a/pyproject.toml b/pyproject.toml index 5593baa7..2f2731b5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,7 @@ psutil = "^6.0.0" pypdf = "^5.0.0" pytest = "^8.3.3" numpy = "==1.26.4" +pymilvus = "^2.4.8" redis = { version = ">=5.0.4", optional = true } huggingface-hub = { version = ">=0.23.1", optional = true } pandas = { version = ">=2.2.2", optional = true } diff --git a/requirements.full.txt b/requirements.full.txt index 8c8387f3..732bab12 100644 --- a/requirements.full.txt +++ b/requirements.full.txt @@ -31,6 +31,7 @@ psutil pypdf pytest numpy==1.26.4 +pymilvus redis>=5.0.4 huggingface-hub>=0.23.1 pandas>=2.2.2 diff --git a/requirements.txt b/requirements.txt index b8010c2b..4160246d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -30,4 +30,5 @@ sqlalchemy psutil pypdf pytest -numpy==1.26.4 \ No newline at end of file +numpy==1.26.4 +pymilvus From 2fa48d07d16e692b5b62114681e515fe370a94bc Mon Sep 17 00:00:00 2001 From: ouguoyu Date: Tue, 29 Oct 2024 10:02:50 +0800 Subject: [PATCH 02/60] store and index api breaking changes --- lazyllm/__init__.py | 4 +- lazyllm/tools/__init__.py | 4 +- lazyllm/tools/rag/__init__.py | 6 +- lazyllm/tools/rag/base_index.py | 30 ++ lazyllm/tools/rag/base_store.py | 112 +++++ lazyllm/tools/rag/component/bm25.py | 2 +- lazyllm/tools/rag/dataReader.py | 2 +- lazyllm/tools/rag/data_loaders.py | 3 +- lazyllm/tools/rag/doc_impl.py | 115 ++++-- lazyllm/tools/rag/doc_node.py | 151 +++++++ lazyllm/tools/rag/document.py | 19 +- lazyllm/tools/rag/index.py | 152 +++++-- lazyllm/tools/rag/rerank.py | 2 +- lazyllm/tools/rag/retriever.py | 2 +- lazyllm/tools/rag/store.py | 385 +++++++----------- lazyllm/tools/rag/transform.py | 2 +- .../standard_test/test_reranker.py | 2 +- tests/basic_tests/test_bm25.py | 2 +- tests/basic_tests/test_doc_node.py | 2 +- tests/basic_tests/test_document.py | 46 ++- tests/basic_tests/test_index.py | 88 +++- tests/basic_tests/test_store.py | 87 +++- tests/basic_tests/test_transform.py | 2 +- tests/requirements.txt | 1 + 24 files changed, 882 insertions(+), 339 deletions(-) create mode 100644 lazyllm/tools/rag/base_index.py create mode 100644 lazyllm/tools/rag/base_store.py create mode 100644 lazyllm/tools/rag/doc_node.py diff --git a/lazyllm/__init__.py b/lazyllm/__init__.py index 5ba10262..15c9125e 100644 --- a/lazyllm/__init__.py +++ b/lazyllm/__init__.py @@ -15,7 +15,7 @@ from .client import redis_client from .tools import (Document, Reranker, Retriever, WebModule, ToolManager, FunctionCall, FunctionCallAgent, fc_register, ReactAgent, PlanAndSolveAgent, ReWOOAgent, SentenceSplitter, - LLMParser) + LLMParser, BaseStore, BaseIndex) from .docs import add_doc config.done() @@ -73,6 +73,8 @@ 'PlanAndSolveAgent', 'ReWOOAgent', 'SentenceSplitter', + 'BaseStore', + 'BaseIndex', # docs 'add_doc', diff --git a/lazyllm/tools/__init__.py b/lazyllm/tools/__init__.py index 0df3c274..52500a5d 100644 --- a/lazyllm/tools/__init__.py +++ b/lazyllm/tools/__init__.py @@ -1,4 +1,4 @@ -from 
.rag import Document, Reranker, Retriever, SentenceSplitter, LLMParser
+from .rag import Document, Reranker, Retriever, SentenceSplitter, LLMParser, BaseStore, BaseIndex
 from .webpages import WebModule
 from .agent import (
     ToolManager,
@@ -32,4 +32,6 @@
     "SqlManager",
     "SqlCall",
     "HttpTool",
+    'BaseStore',
+    'BaseIndex',
 ]
diff --git a/lazyllm/tools/rag/__init__.py b/lazyllm/tools/rag/__init__.py
index 783e4744..74df9e01 100644
--- a/lazyllm/tools/rag/__init__.py
+++ b/lazyllm/tools/rag/__init__.py
@@ -3,11 +3,13 @@
 from .rerank import Reranker, register_reranker
 from .transform import SentenceSplitter, LLMParser, NodeTransform, TransformArgs, AdaptiveTransform
 from .index import register_similarity
-from .store import DocNode
+from .doc_node import DocNode
 from .readers import (PDFReader, DocxReader, HWPReader, PPTXReader, ImageReader, IPYNBReader, EpubReader,
                       MarkdownReader, MboxReader, PandasCSVReader, PandasExcelReader, VideoAudioReader)
 from .dataReader import SimpleDirectoryReader
 from .doc_manager import DocManager, DocListManager
+from .base_store import BaseStore
+from .base_index import BaseIndex


 __all__ = [
@@ -37,4 +39,6 @@
     "SimpleDirectoryReader",
     'DocManager',
     'DocListManager',
+    'BaseStore',
+    'BaseIndex',
 ]
diff --git a/lazyllm/tools/rag/base_index.py b/lazyllm/tools/rag/base_index.py
new file mode 100644
index 00000000..543d9d15
--- /dev/null
+++ b/lazyllm/tools/rag/base_index.py
@@ -0,0 +1,30 @@
+from .doc_node import DocNode
+from abc import ABC, abstractmethod
+from typing import List, Optional
+
+class BaseIndex(ABC):
+    @abstractmethod
+    def update(self, nodes: List[DocNode]) -> None:
+        '''
+        Inserts or updates a list of `DocNode`s in this index.
+
+        Args:
+            nodes (List[DocNode]): nodes to be inserted or updated.
+        '''
+        raise NotImplementedError("not implemented yet.")
+
+    @abstractmethod
+    def remove(self, uids: List[str], group_name: Optional[str] = None) -> None:
+        '''
+        Removes the `DocNode`s specified by `uids`. If `group_name` is not None,
+        only removes those uids from that group.
+
+        Args:
+            uids (List[str]): a list of doc ids.
+            group_name (Optional[str]): name of the group.
+        '''
+        raise NotImplementedError("not implemented yet.")
+
+    @abstractmethod
+    def query(self, *args, **kwargs) -> List[DocNode]:
+        raise NotImplementedError("not implemented yet.")
diff --git a/lazyllm/tools/rag/base_store.py b/lazyllm/tools/rag/base_store.py
new file mode 100644
index 00000000..384a4873
--- /dev/null
+++ b/lazyllm/tools/rag/base_store.py
@@ -0,0 +1,112 @@
+from abc import ABC, abstractmethod
+from typing import Optional, List, Dict
+from .doc_node import DocNode
+from .base_index import BaseIndex
+
+class BaseStore(ABC):
+    @abstractmethod
+    def update_nodes(self, nodes: List[DocNode]) -> None:
+        '''
+        Inserts or updates a list of `DocNode`s in this store.
+
+        Args:
+            nodes (List[DocNode]): nodes to be inserted or updated.
+        '''
+        raise NotImplementedError("not implemented yet.")
+
+    @abstractmethod
+    def get_group_nodes(self, group_name: str, uids: Optional[List[str]] = None) -> List[DocNode]:
+        '''
+        Returns a list of the `DocNode`s specified by `uids` in the group named `group_name`.
+        All `DocNode`s in the group `group_name` are returned if `uids` is `None` or `[]`.
+
+        Args:
+            group_name (str): the name of the group.
+            uids (List[str]): a list of doc ids.
+
+        Returns:
+            List[DocNode]: the result.
+        '''
+        raise NotImplementedError("not implemented yet.")
+
+    @abstractmethod
+    def remove_group_nodes(self, group_name: str, uids: Optional[List[str]] = None) -> None:
+        '''
+        Removes the specified `DocNode`s in the group named `group_name`.
+        The whole group `group_name` is removed if `uids` is `None` or `[]`.
+
+        Args:
+            group_name (str): the name of the group.
+            uids (List[str]): a list of doc ids.
+        '''
+        raise NotImplementedError("not implemented yet.")
+
+    @abstractmethod
+    def group_is_active(self, group_name: str) -> bool:
+        '''
+        Returns `True` if the group named `group_name` exists and contains at least one `DocNode`.
+
+        Args:
+            group_name (str): the name of the group.
+
+        Returns:
+            bool: whether the group `group_name` is active.
+        '''
+        raise NotImplementedError("not implemented yet.")
+
+    @abstractmethod
+    def group_names(self) -> List[str]:
+        '''
+        Returns the group names in this store.
+
+        Returns:
+            List[str]: the result.
+        '''
+        raise NotImplementedError("not implemented yet.")
+
+    @abstractmethod
+    def register_index(self, type_name: str, index: BaseIndex) -> None:
+        '''
+        Registers `index` with type `type_name` to this store.
+
+        Args:
+            type_name (str): type of the index to be registered.
+            index (BaseIndex): the index to be registered.
+        '''
+        raise NotImplementedError("not implemented yet.")
+
+    @abstractmethod
+    def remove_index(self, type_name: str) -> None:
+        '''
+        Removes the index with type `type_name` from this store.
+
+        Args:
+            type_name (str): type of the index to be removed.
+        '''
+        raise NotImplementedError("not implemented yet.")
+
+    @abstractmethod
+    def get_index(self, type_name: str) -> Optional[BaseIndex]:
+        '''
+        Returns the index with the specified type `type_name` in this store.
+
+        Args:
+            type_name (str): type of the index to be retrieved.
+
+        Returns:
+            Optional[BaseIndex]: the index of the specified type, or `None` if not found.
+ ''' + raise NotImplementedError("not implemented yet.") + + # ----- helper functions ----- # + + @staticmethod + def _update_indices(name2index: Dict[str, BaseIndex], nodes: List[DocNode]) -> None: + for _, index in name2index.items(): + index.update(nodes) + + @staticmethod + def _remove_from_indices(name2index: Dict[str, BaseIndex], uids: List[str], + group_name: Optional[str] = None) -> None: + for _, index in name2index.items(): + index.remove(uids, group_name) diff --git a/lazyllm/tools/rag/component/bm25.py b/lazyllm/tools/rag/component/bm25.py index 171c5d97..56881869 100644 --- a/lazyllm/tools/rag/component/bm25.py +++ b/lazyllm/tools/rag/component/bm25.py @@ -1,5 +1,5 @@ from typing import List, Tuple -from ..store import DocNode +from ..doc_node import DocNode import bm25s import Stemmer from lazyllm.thirdparty import jieba diff --git a/lazyllm/tools/rag/dataReader.py b/lazyllm/tools/rag/dataReader.py index 116f3559..319525e0 100644 --- a/lazyllm/tools/rag/dataReader.py +++ b/lazyllm/tools/rag/dataReader.py @@ -14,7 +14,7 @@ from pathlib import Path, PurePosixPath, PurePath from fsspec import AbstractFileSystem from lazyllm import ModuleBase, LOG -from .store import DocNode +from .doc_node import DocNode from .readers import (ReaderBase, PDFReader, DocxReader, HWPReader, PPTXReader, ImageReader, IPYNBReader, EpubReader, MarkdownReader, MboxReader, PandasCSVReader, PandasExcelReader, VideoAudioReader, get_default_fs, is_default_fs) diff --git a/lazyllm/tools/rag/data_loaders.py b/lazyllm/tools/rag/data_loaders.py index 02a11c9a..0212fc17 100644 --- a/lazyllm/tools/rag/data_loaders.py +++ b/lazyllm/tools/rag/data_loaders.py @@ -1,5 +1,6 @@ from typing import List, Optional, Dict -from .store import DocNode, LAZY_ROOT_NAME +from .doc_node import DocNode +from .store import LAZY_ROOT_NAME from lazyllm import LOG from .dataReader import SimpleDirectoryReader diff --git a/lazyllm/tools/rag/doc_impl.py b/lazyllm/tools/rag/doc_impl.py index bf88adab..e71e585e 100644 --- a/lazyllm/tools/rag/doc_impl.py +++ b/lazyllm/tools/rag/doc_impl.py @@ -5,15 +5,41 @@ from lazyllm import LOG, config, once_wrapper from .transform import (NodeTransform, FuncNodeTransform, SentenceSplitter, LLMParser, AdaptiveTransform, make_transform, TransformArgs) -from .store import MapStore, DocNode, ChromadbStore, LAZY_ROOT_NAME, BaseStore +from .store import MapStore, DocNode, ChromadbStore, LAZY_ROOT_NAME, BaseStore, StoreWrapper from .data_loaders import DirectoryReader -from .index import DefaultIndex +from .index import DefaultIndex, BaseIndex from .utils import DocListManager import threading import time _transmap = dict(function=FuncNodeTransform, sentencesplitter=SentenceSplitter, llm=LLMParser) +class FileNodeIndex(BaseIndex): + def __init__(self): + self._file_node_map = {} + + # override + def update(self, nodes: List[DocNode]) -> None: + for node in nodes: + if node.group != LAZY_ROOT_NAME: + continue + file_name = node.metadata.get("file_name") + if file_name: + self._file_node_map[file_name] = node + + # override + def remove(self, uids: List[str], group_name: Optional[str] = None) -> None: + # group_name is ignored + left = {k: v for k, v in self._file_node_map.items() if v.uid not in uids} + self._file_node_map = left + + # override + def query(self, files: List[str]) -> List[DocNode]: + ret = [] + for file in files: + ret.append(self._file_node_map.get(file)) + return ret + def embed_wrapper(func): if not func: @@ -33,7 +59,8 @@ class DocImpl: _registered_file_reader: Dict[str, Callable] = {} def 
__init__(self, embed: Dict[str, Callable], dlm: Optional[DocListManager] = None, - doc_files: Optional[str] = None, kb_group_name: str = None): + doc_files: Optional[str] = None, kb_group_name: Optional[str] = None, + store: Optional[BaseStore] = None): super().__init__() assert (dlm is None) ^ (doc_files is None), 'Only one of dataset_path or doc_files should be provided' self._local_file_reader: Dict[str, Callable] = {} @@ -43,7 +70,18 @@ def __init__(self, embed: Dict[str, Callable], dlm: Optional[DocListManager] = N self.node_groups: Dict[str, Dict] = {LAZY_ROOT_NAME: {}} self.embed = {k: embed_wrapper(e) for k, e in embed.items()} self._embed_dim = None - self.store = None + if store: + self.store = StoreWrapper(store) + self._create_some_indices_for_store(self.store) + else: + self.store = None + + @staticmethod + def _create_file_node_index(store) -> FileNodeIndex: + index = FileNodeIndex() + for group in store.group_names(): + index.update(store.get_group_nodes(group)) + return index @once_wrapper(reset_on_pickle=True) def _lazy_init(self) -> None: @@ -54,12 +92,14 @@ def _lazy_init(self) -> None: self._embed_dim = {k: len(e('a')) for k, e in self.embed.items()} - self.store = self._get_store() - self.index = DefaultIndex(self.embed, self.store) - if not self.store.has_nodes(LAZY_ROOT_NAME): + if not self.store: + self.store = self._create_store() + self._create_some_indices_for_store(self.store) + + if not self.store.group_is_active(LAZY_ROOT_NAME): ids, pathes = self._list_files() root_nodes = self._reader.load_data(pathes) - self.store.add_nodes(root_nodes) + self.store.update_nodes(root_nodes) if self._dlm: self._dlm.update_kb_group_file_status( ids, DocListManager.Status.success, group=self._kb_group_name) LOG.debug(f"building {LAZY_ROOT_NAME} nodes: {root_nodes}") @@ -69,19 +109,25 @@ def _lazy_init(self) -> None: self._daemon.daemon = True self._daemon.start() - def _get_store(self) -> BaseStore: - rag_store_type = config["rag_store_type"] + def _create_store(self, rag_store_type: str = None) -> BaseStore: + if not rag_store_type: + rag_store_type = config["rag_store_type"] if rag_store_type == "map": store = MapStore(node_groups=self.node_groups.keys()) elif rag_store_type == "chroma": store = ChromadbStore(node_groups=self.node_groups.keys(), embed_dim=self._embed_dim) - store.try_load_store() else: raise NotImplementedError( f"Not implemented store type for {rag_store_type}" ) return store + def _create_some_indices_for_store(self, store: BaseStore): + if not store.get_index(type_name='default'): + store.register_index(type_name='default', index=DefaultIndex(self.embed, store)) + if not store.get_index(type_name='file_node_map'): + store.register_index(type_name='file_node_map', index=self._create_file_node_index(store)) + @staticmethod def _create_node_group_impl(cls, group_name, name, transform: Union[str, Callable] = None, parent: str = LAZY_ROOT_NAME, *, trans_node: bool = None, @@ -184,45 +230,47 @@ def _add_files(self, input_files: List[str]): return self._lazy_init() root_nodes = self._reader.load_data(input_files) - temp_store = self._get_store() - temp_store.add_nodes(root_nodes) - active_groups = self.store.active_groups() - LOG.info(f"add_files: Trying to merge store with {active_groups}") - for group in active_groups: + temp_store = self._create_store("map") + temp_store.update_nodes(root_nodes) + group_names = self.store.group_names() + LOG.info(f"add_files: Trying to merge store with {group_names}") + for group in group_names: + if not 
self.store.group_is_active(group): + continue # Duplicate group will be discarded automatically nodes = self._get_nodes(group, temp_store) - self.store.add_nodes(nodes) + self.store.update_nodes(nodes) LOG.debug(f"Merge {group} with {nodes}") def _delete_files(self, input_files: List[str]) -> None: self._lazy_init() - docs = self.store.get_nodes_by_files(input_files) + docs = self.store.get_index(type_name='file_node_map').query(input_files) LOG.info(f"delete_files: removing documents {input_files} and nodes {docs}") if len(docs) == 0: return self._delete_nodes_recursively(docs) def _delete_nodes_recursively(self, root_nodes: List[DocNode]) -> None: - nodes_to_delete = defaultdict(list) - nodes_to_delete[LAZY_ROOT_NAME] = root_nodes + uids_to_delete = defaultdict(list) + uids_to_delete[LAZY_ROOT_NAME] = [node.uid for node in root_nodes] # Gather all nodes to be deleted including their children def gather_children(node: DocNode): for children_group, children_list in node.children.items(): for child in children_list: - nodes_to_delete[children_group].append(child) + uids_to_delete[children_group].append(child.uid) gather_children(child) for node in root_nodes: gather_children(node) # Delete nodes in all groups - for group, node_uids in nodes_to_delete.items(): - self.store.remove_nodes(node_uids) + for group, node_uids in uids_to_delete.items(): + self.store.remove_group_nodes(group, node_uids) LOG.debug(f"Removed nodes from group {group} for node IDs: {node_uids}") def _dynamic_create_nodes(self, group_name: str, store: BaseStore) -> None: - if store.has_nodes(group_name): + if store.group_is_active(group_name): return node_group = self.node_groups.get(group_name) if node_group is None: @@ -232,23 +280,26 @@ def _dynamic_create_nodes(self, group_name: str, store: BaseStore) -> None: transform = AdaptiveTransform(t) if isinstance(t, list) or t.pattern else make_transform(t) parent_nodes = self._get_nodes(node_group["parent"], store) nodes = transform.batch_forward(parent_nodes, group_name) - store.add_nodes(nodes) + store.update_nodes(nodes) LOG.debug(f"building {group_name} nodes: {nodes}") def _get_nodes(self, group_name: str, store: Optional[BaseStore] = None) -> List[DocNode]: store = store or self.store self._dynamic_create_nodes(group_name, store) - return store.traverse_nodes(group_name) + return store.get_group_nodes(group_name) def retrieve(self, query: str, group_name: str, similarity: str, similarity_cut_off: Union[float, Dict[str, float]], index: str, topk: int, similarity_kws: dict, embed_keys: Optional[List[str]] = None) -> List[DocNode]: self._lazy_init() - if index: - assert index == "default", "we only support default index currently" - nodes = self._get_nodes(group_name) - return self.index.query( - query, nodes, similarity, similarity_cut_off, topk, embed_keys, **similarity_kws - ) + + index_instance = self.store.get_index(type_name=index) + if not index_instance: + raise NotImplementedError(f"index type '{index}' is not supported currently.") + + self._dynamic_create_nodes(group_name, self.store) + return index_instance.query(query=query, group_name=group_name, similarity_name=similarity, + similarity_cut_off=similarity_cut_off, topk=topk, + embed_keys=embed_keys, **similarity_kws) @staticmethod def find_parent(nodes: List[DocNode], group: str) -> List[DocNode]: diff --git a/lazyllm/tools/rag/doc_node.py b/lazyllm/tools/rag/doc_node.py new file mode 100644 index 00000000..919cde5d --- /dev/null +++ b/lazyllm/tools/rag/doc_node.py @@ -0,0 +1,151 @@ +from typing import 
Optional, Dict, Any, Union, Callable, List
+from enum import Enum, auto
+from collections import defaultdict
+from lazyllm import config
+import uuid
+import threading
+import time
+
+class MetadataMode(str, Enum):
+    ALL = auto()
+    EMBED = auto()
+    LLM = auto()
+    NONE = auto()
+
+
+class DocNode:
+    def __init__(self, uid: Optional[str] = None, text: Optional[str] = None, group: Optional[str] = None,
+                 embedding: Optional[Dict[str, List[float]]] = None, parent: Optional["DocNode"] = None,
+                 metadata: Optional[Dict[str, Any]] = None, classfication: Optional[str] = None):
+        self.uid: str = uid if uid else str(uuid.uuid4())
+        self.text: Optional[str] = text
+        self.group: Optional[str] = group
+        self.embedding: Optional[Dict[str, List[float]]] = embedding or None
+        self._metadata: Dict[str, Any] = metadata or {}
+        # Metadata keys that are excluded from text for the embed model.
+        self._excluded_embed_metadata_keys: List[str] = []
+        # Metadata keys that are excluded from text for the LLM.
+        self._excluded_llm_metadata_keys: List[str] = []
+        self.parent: Optional["DocNode"] = parent
+        self.children: Dict[str, List["DocNode"]] = defaultdict(list)
+        self.is_saved: bool = False
+        self._docpath = None
+        self._lock = threading.Lock()
+        self._embedding_state = set()
+        # the store will create an index cache for classification to speed up retrieval
+        self._classfication = classfication
+
+    @property
+    def root_node(self) -> Optional["DocNode"]:
+        root = self.parent
+        while root and root.parent:
+            root = root.parent
+        return root or self
+
+    @property
+    def metadata(self) -> Dict:
+        return self.root_node._metadata
+
+    @metadata.setter
+    def metadata(self, metadata: Dict) -> None:
+        self._metadata = metadata
+
+    @property
+    def excluded_embed_metadata_keys(self) -> List:
+        return self.root_node._excluded_embed_metadata_keys
+
+    @excluded_embed_metadata_keys.setter
+    def excluded_embed_metadata_keys(self, excluded_embed_metadata_keys: List) -> None:
+        self._excluded_embed_metadata_keys = excluded_embed_metadata_keys
+
+    @property
+    def excluded_llm_metadata_keys(self) -> List:
+        return self.root_node._excluded_llm_metadata_keys
+
+    @excluded_llm_metadata_keys.setter
+    def excluded_llm_metadata_keys(self, excluded_llm_metadata_keys: List) -> None:
+        self._excluded_llm_metadata_keys = excluded_llm_metadata_keys
+
+    @property
+    def docpath(self) -> str:
+        return self.root_node._docpath or ''
+
+    @docpath.setter
+    def docpath(self, path):
+        assert not self.parent, 'Only root node can set docpath'
+        self._docpath = str(path)
+
+    def get_children_str(self) -> str:
+        return str(
+            {key: [node.uid for node in nodes] for key, nodes in self.children.items()}
+        )
+
+    def get_parent_id(self) -> str:
+        return self.parent.uid if self.parent else ""
+
+    def __str__(self) -> str:
+        return (
+            f"DocNode(id: {self.uid}, group: {self.group}, text: {self.get_text()}) parent: {self.get_parent_id()}, "
+            f"children: {self.get_children_str()}"
+        )
+
+    def __repr__(self) -> str:
+        return str(self) if config["debug"] else f'<Node id={self.uid}>'
+
+    def __eq__(self, other):
+        if isinstance(other, DocNode):
+            return self.uid == other.uid
+        return False
+
+    def __hash__(self):
+        return hash(self.uid)
+
+    def has_missing_embedding(self, embed_keys: Union[str, List[str]]) -> List[str]:
+        if isinstance(embed_keys, str): embed_keys = [embed_keys]
+        assert len(embed_keys) > 0, "The embed_keys to be checked must be passed in."
+ if self.embedding is None: return embed_keys + return [k for k in embed_keys if k not in self.embedding] + + def do_embedding(self, embed: Dict[str, Callable]) -> None: + generate_embed = {k: e(self.get_text(MetadataMode.EMBED)) for k, e in embed.items()} + with self._lock: + self.embedding = self.embedding or {} + self.embedding = {**self.embedding, **generate_embed} + self.is_saved = False + + def check_embedding_state(self, embed_key: str) -> None: + while True: + with self._lock: + if not self.has_missing_embedding(embed_key): + self._embedding_state.discard(embed_key) + break + time.sleep(1) + + def get_content(self) -> str: + return self.get_text(MetadataMode.LLM) + + def get_metadata_str(self, mode: MetadataMode = MetadataMode.ALL) -> str: + """Metadata info string.""" + if mode == MetadataMode.NONE: + return "" + + metadata_keys = set(self.metadata.keys()) + if mode == MetadataMode.LLM: + for key in self.excluded_llm_metadata_keys: + if key in metadata_keys: + metadata_keys.remove(key) + elif mode == MetadataMode.EMBED: + for key in self.excluded_embed_metadata_keys: + if key in metadata_keys: + metadata_keys.remove(key) + + return "\n".join([f"{key}: {self.metadata[key]}" for key in metadata_keys]) + + def get_text(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str: + metadata_str = self.get_metadata_str(metadata_mode).strip() + if not metadata_str: + return self.text if self.text else "" + return f"{metadata_str}\n\n{self.text}".strip() + + def to_dict(self) -> Dict: + return dict(text=self.text, embedding=self.embedding, metadata=self.metadata) diff --git a/lazyllm/tools/rag/document.py b/lazyllm/tools/rag/document.py index e8345e5d..1e08e912 100644 --- a/lazyllm/tools/rag/document.py +++ b/lazyllm/tools/rag/document.py @@ -6,7 +6,7 @@ from .doc_manager import DocManager from .doc_impl import DocImpl -from .store import LAZY_ROOT_NAME, EMBED_DEFAULT_KEY, DocNode +from .store import LAZY_ROOT_NAME, EMBED_DEFAULT_KEY, DocNode, BaseStore from .utils import DocListManager import copy import functools @@ -15,7 +15,8 @@ class Document(ModuleBase): class _Impl(ModuleBase): def __init__(self, dataset_path: str, embed: Optional[Union[Callable, Dict[str, Callable]]] = None, - manager: bool = False, server: bool = False, name: Optional[str] = None, launcher=None): + manager: bool = False, server: bool = False, name: Optional[str] = None, launcher=None, + store: BaseStore = None): super().__init__() if not os.path.exists(dataset_path): defatult_path = os.path.join(lazyllm.config["data_path"], dataset_path) @@ -29,27 +30,27 @@ def __init__(self, dataset_path: str, embed: Optional[Union[Callable, Dict[str, if isinstance(embed, ModuleBase): self._submodules.append(embed) self._dlm = DocListManager(dataset_path, name).init_tables() - self._kbs = {DocListManager.DEDAULT_GROUP_NAME: DocImpl(embed=self._embed, dlm=self._dlm)} + self._kbs = {DocListManager.DEDAULT_GROUP_NAME: DocImpl(embed=self._embed, dlm=self._dlm, store=store)} if manager: self._manager = DocManager(self._dlm) if server: self._doc = ServerModule(self._doc) - def add_kb_group(self, name): - self._kbs[name] = DocImpl(dlm=self._dlm, embed=self._embed, kb_group_name=name) + def add_kb_group(self, name, store: BaseStore): + self._kbs[name] = DocImpl(dlm=self._dlm, embed=self._embed, kb_group_name=name, store=store) self._dlm.add_kb_group(name) def get_doc_by_kb_group(self, name): return self._kbs[name] def __init__(self, dataset_path: str, embed: Optional[Union[Callable, Dict[str, Callable]]] = None, create_ui: bool = 
False, manager: bool = False, server: bool = False, - name: Optional[str] = None, launcher=None): + name: Optional[str] = None, launcher=None, store: BaseStore = None): super().__init__() if create_ui: lazyllm.LOG.warning('`create_ui` for Document is deprecated, use `manager` instead') - self._impls = Document._Impl(dataset_path, embed, create_ui or manager, server, name, launcher) + self._impls = Document._Impl(dataset_path, embed, create_ui or manager, server, name, launcher, store) self._curr_group = DocListManager.DEDAULT_GROUP_NAME - def create_kb_group(self, name: str) -> "Document": - self._impls.add_kb_group(name) + def create_kb_group(self, name: str, store: BaseStore) -> "Document": + self._impls.add_kb_group(name, store) doc = copy.copy(self) doc._curr_group = name return doc diff --git a/lazyllm/tools/rag/index.py b/lazyllm/tools/rag/index.py index d9f6ad24..c9454f17 100644 --- a/lazyllm/tools/rag/index.py +++ b/lazyllm/tools/rag/index.py @@ -1,10 +1,15 @@ import concurrent import os from typing import List, Callable, Optional, Dict, Union, Tuple -from .store import DocNode, BaseStore +from .doc_node import DocNode +from .base_store import BaseStore +from .base_index import BaseIndex import numpy as np from .component.bm25 import BM25 from lazyllm import LOG, config, ThreadPoolExecutor +import pymilvus + +# ---------------------------------------------------------------------------- # # min(32, (os.cpu_count() or 1) + 4) is the default number of workers for ThreadPoolExecutor config.add( @@ -14,8 +19,33 @@ "MAX_EMBEDDING_WORKERS", ) +# ---------------------------------------------------------------------------- # + +def parallel_do_embedding(embed: Dict[str, Callable], nodes: List[DocNode]) -> List[DocNode]: + ''' + returns a list of modified nodes + ''' + modified_nodes = [] + with ThreadPoolExecutor(config["max_embedding_workers"]) as executor: + futures = [] + for node in nodes: + miss_keys = node.has_missing_embedding(embed.keys()) + if not miss_keys: + continue + modified_nodes.append(node) + for k in miss_keys: + with node._lock: + if node.has_missing_embedding(k): + future = executor.submit(node.do_embedding, {k: embed[k]}) \ + if k not in node._embedding_state else executor.submit(node.check_embedding_state, k) + node._embedding_state.add(k) + futures.append(future) + if len(futures) > 0: + for future in concurrent.futures.as_completed(futures): + future.result() + return modified_nodes -class DefaultIndex: +class DefaultIndex(BaseIndex): """Default Index, registered for similarity functions""" registered_similarity = dict() @@ -55,29 +85,19 @@ def wrapper(query, nodes, **kwargs): return decorator(func) if func else decorator - def _parallel_do_embedding(self, nodes: List[DocNode]) -> List[DocNode]: - with ThreadPoolExecutor(config["max_embedding_workers"]) as executor: - futures = [] - for node in nodes: - miss_keys = node.has_missing_embedding(self.embed.keys()) - if not miss_keys: - continue - for k in miss_keys: - with node._lock: - if node.has_missing_embedding(k): - future = executor.submit(node.do_embedding, {k: self.embed[k]}) \ - if k not in node._embedding_state else executor.submit(node.check_embedding_state, k) - node._embedding_state.add(k) - futures.append(future) - if len(futures) > 0: - for future in concurrent.futures.as_completed(futures): - future.result() - return nodes + # override + def update(self, nodes: List[DocNode]) -> None: + pass + + # override + def remove(self, uids: List[str], group_name: Optional[str] = None) -> None: + pass + # override 
def query(
         self,
         query: str,
-        nodes: List[DocNode],
+        group_name: str,
         similarity_name: str,
         similarity_cut_off: Union[float, Dict[str, float]],
         topk: int,
@@ -91,12 +111,13 @@ def query(
     )

         similarity_func, mode, descend = self.registered_similarity[similarity_name]
+        nodes = self.store.get_group_nodes(group_name)

         if mode == "embedding":
             assert self.embed, "Chosen similarity needs embed model."
             assert len(query) > 0, "Query should not be empty."
             query_embedding = {k: self.embed[k](query) for k in (embed_keys or self.embed.keys())}
-            nodes = self._parallel_do_embedding(nodes)
-            self.store.try_save_nodes(nodes)
+            modified_nodes = parallel_do_embedding(self.embed, nodes)
+            self.store.update_nodes(modified_nodes)
             similarities = similarity_func(query_embedding, nodes, topk=topk, **kwargs)
         elif mode == "text":
             similarities = similarity_func(query, nodes, topk=topk, **kwargs)
@@ -150,3 +171,88 @@ def register_similarity(
     batch: bool = False,
 ) -> Callable:
     return DefaultIndex.register_similarity(func, mode, descend, batch)
+
+# ---------------------------------------------------------------------------- #
+
+class MilvusIndex(BaseIndex):
+    class Field:
+        def __init__(self, name: str, data_type: pymilvus.DataType, index_type: str,
+                     metric_type: str, index_params: Optional[dict] = None, dim: Optional[int] = None):
+            self.name = name
+            self.data_type = data_type
+            self.index_type = index_type
+            self.metric_type = metric_type
+            # avoid sharing a mutable default argument across instances
+            self.index_params = index_params if index_params is not None else {}
+            self.dim = dim
+
+    def __init__(self, embed: Dict[str, Callable],
+                 group_fields: Dict[str, List['MilvusIndex.Field']],
+                 uri: str, full_data_store: BaseStore):
+        self._embed = embed
+        self._full_data_store = full_data_store
+
+        self._primary_key = 'uid'
+        self._client = pymilvus.MilvusClient(uri=uri)
+
+        for group_name, field_list in group_fields.items():
+            if group_name in self._client.list_collections():
+                continue
+
+            schema = self._client.create_schema(auto_id=False, enable_dynamic_field=False)
+            schema.add_field(
+                field_name=self._primary_key,
+                datatype=pymilvus.DataType.VARCHAR,
+                max_length=128,
+                is_primary=True,
+            )
+            for field in field_list:
+                schema.add_field(
+                    field_name=field.name,
+                    datatype=field.data_type,
+                    dim=field.dim)
+
+            index_params = self._client.prepare_index_params()
+            for field in field_list:
+                index_params.add_index(field_name=field.name, index_type=field.index_type,
+                                       metric_type=field.metric_type, params=field.index_params)
+
+            self._client.create_collection(collection_name=group_name, schema=schema,
+                                           index_params=index_params)
+
+    # override
+    def update(self, nodes: List[DocNode]) -> None:
+        parallel_do_embedding(self._embed, nodes)
+        for node in nodes:
+            data = node.embedding.copy()
+            data[self._primary_key] = node.uid
+            self._client.upsert(collection_name=node.group, data=data)
+
+    # override
+    def remove(self, uids: List[str], group_name: Optional[str] = None) -> None:
+        if group_name:
+            self._client.delete(collection_name=group_name,
+                                filter=f'{self._primary_key} in {uids}')
+        else:
+            for group_name in self._client.list_collections():
+                self._client.delete(collection_name=group_name,
+                                    filter=f'{self._primary_key} in {uids}')
+
+    # override
+    def query(self,
+              query: str,
+              group_name: str,
+              embed_keys: Optional[List[str]] = None,
+              topk: int = 10,
+              **kwargs) -> List[DocNode]:
+        uids = set()
+        for embed_name in (embed_keys or self._embed.keys()):
+            embed_func = self._embed.get(embed_name)
+            query_embedding = embed_func(query)
+            results = self._client.search(collection_name=group_name, data=[query_embedding],
+                                          limit=topk, anns_field=embed_name)
+        if len(results) > 0:
+            # search() was given a single query vector, so `results` holds exactly one result set
+            for result in results[0]:
+                uids.add(result['id'])
+
+        return self._full_data_store.get_group_nodes(group_name, list(uids))
diff --git a/lazyllm/tools/rag/rerank.py b/lazyllm/tools/rag/rerank.py
index 875243af..43b97623 100644
--- a/lazyllm/tools/rag/rerank.py
+++ b/lazyllm/tools/rag/rerank.py
@@ -3,7 +3,7 @@
 import lazyllm
 from lazyllm import ModuleBase, LOG
-from lazyllm.tools.rag.store import DocNode, MetadataMode
+from .doc_node import DocNode, MetadataMode
 from .retriever import _PostProcess
diff --git a/lazyllm/tools/rag/retriever.py b/lazyllm/tools/rag/retriever.py
index b22f6cdb..0acf2883 100644
--- a/lazyllm/tools/rag/retriever.py
+++ b/lazyllm/tools/rag/retriever.py
@@ -1,5 +1,5 @@
 from lazyllm import ModuleBase, pipeline, once_wrapper
-from .store import DocNode
+from .doc_node import DocNode
 from .document import Document, DocImpl
 from typing import List, Optional, Union, Dict
diff --git a/lazyllm/tools/rag/store.py b/lazyllm/tools/rag/store.py
index 702605f0..a59c60f8 100644
--- a/lazyllm/tools/rag/store.py
+++ b/lazyllm/tools/rag/store.py
@@ -1,247 +1,136 @@
-from abc import ABC, abstractmethod
-from collections import defaultdict
-from enum import Enum, auto
-import uuid
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional
 import chromadb
 from lazyllm import LOG, config
 from chromadb.api.models.Collection import Collection
-import threading
+from .base_store import BaseStore
+from .base_index import BaseIndex
+from .doc_node import DocNode
 import json
-import time

+# ---------------------------------------------------------------------------- #

 LAZY_ROOT_NAME = "lazyllm_root"
 EMBED_DEFAULT_KEY = '__default__'
 config.add("rag_store_type", str, "map", "RAG_STORE_TYPE")  # "map", "chroma"
 config.add("rag_persistent_path", str, "./lazyllm_chroma", "RAG_PERSISTENT_PATH")

+# ---------------------------------------------------------------------------- #

-class MetadataMode(str, Enum):
-    ALL = auto()
-    EMBED = auto()
-    LLM = auto()
-    NONE = auto()
-
-
-class DocNode:
-    def __init__(self, uid: Optional[str] = None, text: Optional[str] = None, group: Optional[str] = None,
-                 embedding: Optional[Dict[str, List[float]]] = None, parent: Optional["DocNode"] = None,
-                 metadata: Optional[Dict[str, Any]] = None, classfication: Optional[str] = None):
-        self.uid: str = uid if uid else str(uuid.uuid4())
-        self.text: Optional[str] = text
-        self.group: Optional[str] = group
-        self.embedding: Optional[Dict[str, List[float]]] = embedding or None
-        self._metadata: Dict[str, Any] = metadata or {}
-        # Metadata keys that are excluded from text for the embed model.
-        self._excluded_embed_metadata_keys: List[str] = []
-        # Metadata keys that are excluded from text for the LLM.
- self._excluded_llm_metadata_keys: List[str] = [] - self.parent: Optional["DocNode"] = parent - self.children: Dict[str, List["DocNode"]] = defaultdict(list) - self.is_saved: bool = False - self._docpath = None - self._lock = threading.Lock() - self._embedding_state = set() - # store will create index cache for classfication to speed up retrieve - self._classfication = classfication - - @property - def root_node(self) -> Optional["DocNode"]: - root = self.parent - while root and root.parent: - root = root.parent - return root or self - - @property - def metadata(self) -> Dict: - return self.root_node._metadata - - @metadata.setter - def metadata(self, metadata: Dict) -> None: - self._metadata = metadata - - @property - def excluded_embed_metadata_keys(self) -> List: - return self.root_node._excluded_embed_metadata_keys - - @excluded_embed_metadata_keys.setter - def excluded_embed_metadata_keys(self, excluded_embed_metadata_keys: List) -> None: - self._excluded_embed_metadata_keys = excluded_embed_metadata_keys - - @property - def excluded_llm_metadata_keys(self) -> List: - return self.root_node._excluded_llm_metadata_keys - - @excluded_llm_metadata_keys.setter - def excluded_llm_metadata_keys(self, excluded_llm_metadata_keys: List) -> None: - self._excluded_llm_metadata_keys = excluded_llm_metadata_keys - - @property - def docpath(self) -> str: - return self.root_node._docpath or '' - - @docpath.setter - def docpath(self, path): - assert not self.parent, 'Only root node can set docpath' - self._docpath = str(path) - - def get_children_str(self) -> str: - return str( - {key: [node.uid for node in nodes] for key, nodes in self.children.items()} - ) - - def get_parent_id(self) -> str: - return self.parent.uid if self.parent else "" - - def __str__(self) -> str: - return ( - f"DocNode(id: {self.uid}, group: {self.group}, text: {self.get_text()}) parent: {self.get_parent_id()}, " - f"children: {self.get_children_str()}" - ) - - def __repr__(self) -> str: - return str(self) if config["debug"] else f'' - - def __eq__(self, other): - if isinstance(other, DocNode): - return self.uid == other.uid - return False - - def __hash__(self): - return hash(self.uid) - - def has_missing_embedding(self, embed_keys: Union[str, List[str]]) -> List[str]: - if isinstance(embed_keys, str): embed_keys = [embed_keys] - assert len(embed_keys) > 0, "The ebmed_keys to be checked must be passed in." 
- if self.embedding is None: return embed_keys - return [k for k in embed_keys if k not in self.embedding] - - def do_embedding(self, embed: Dict[str, Callable]) -> None: - generate_embed = {k: e(self.get_text(MetadataMode.EMBED)) for k, e in embed.items()} - with self._lock: - self.embedding = self.embedding or {} - self.embedding = {**self.embedding, **generate_embed} - self.is_saved = False - - def check_embedding_state(self, embed_key: str) -> None: - while True: - with self._lock: - if not self.has_missing_embedding(embed_key): - self._embedding_state.discard(embed_key) - break - time.sleep(1) - - def get_content(self) -> str: - return self.get_text(MetadataMode.LLM) - - def get_metadata_str(self, mode: MetadataMode = MetadataMode.ALL) -> str: - """Metadata info string.""" - if mode == MetadataMode.NONE: - return "" - - metadata_keys = set(self.metadata.keys()) - if mode == MetadataMode.LLM: - for key in self.excluded_llm_metadata_keys: - if key in metadata_keys: - metadata_keys.remove(key) - elif mode == MetadataMode.EMBED: - for key in self.excluded_embed_metadata_keys: - if key in metadata_keys: - metadata_keys.remove(key) - - return "\n".join([f"{key}: {self.metadata[key]}" for key in metadata_keys]) - - def get_text(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str: - metadata_str = self.get_metadata_str(metadata_mode).strip() - if not metadata_str: - return self.text if self.text else "" - return f"{metadata_str}\n\n{self.text}".strip() - - def to_dict(self) -> Dict: - return dict(text=self.text, embedding=self.embedding, metadata=self.metadata) - - -class BaseStore(ABC): - def __init__(self, node_groups: List[str]) -> None: - self._store: Dict[str, Dict[str, DocNode]] = { - group: {} for group in node_groups - } - self._file_node_map = {} - - def _add_nodes(self, nodes: List[DocNode]) -> None: - for node in nodes: - if node.group == LAZY_ROOT_NAME and "file_name" in node.metadata: - self._file_node_map[node.metadata["file_name"]] = node - self._store[node.group][node.uid] = node +class StoreWrapper(BaseStore): + def __init__(self, store: BaseStore): + self._store = store + self._name2index = {} - def add_nodes(self, nodes: List[DocNode]) -> None: - self._add_nodes(nodes) - self.try_save_nodes(nodes) + def update_nodes(self, nodes: List[DocNode]) -> None: + self._store.update_nodes(nodes) + self._update_indices(self._name2index, nodes) - def has_nodes(self, group: str) -> bool: - return len(self._store[group]) > 0 + def get_group_nodes(self, group_name: str, uids: List[str] = None) -> List[DocNode]: + return self._store.get_group_nodes(group_name, uids) - def get_node(self, group: str, node_id: str) -> Optional[DocNode]: - return self._store.get(group, {}).get(node_id) + def remove_group_nodes(self, group_name: str, uids: List[str] = None) -> None: + self._store.remove_group_nodes(group_name, uids) + self._remove_from_indices(self._name2index, uids, group_name) - def traverse_nodes(self, group: str) -> List[DocNode]: - return list(self._store.get(group, {}).values()) + def group_is_active(self, group_name: str) -> bool: + return self._store.group_is_active(group_name) - @abstractmethod - def try_save_nodes(self, nodes: List[DocNode]) -> None: - # try save nodes to persistent source - raise NotImplementedError("Not implemented yet.") + def group_names(self) -> List[str]: + return self._store.group_names() - @abstractmethod - def try_load_store(self) -> None: - # try load nodes from persistent source - raise NotImplementedError("Not implemented yet.") + def 
register_index(self, type_name: str, index: BaseIndex) -> None: + self._name2index[type_name] = index - @abstractmethod - def try_remove_nodes(self, nodes: List[DocNode]) -> None: - # try remove nodes in persistent source - raise NotImplementedError("Not implemented yet.") + def remove_index(self, type_name: str) -> None: + self._name2index.pop(type_name, None) - def active_groups(self) -> List: - return [group for group, nodes in self._store.items() if nodes] - - def _remove_nodes(self, nodes: List[DocNode]) -> None: - for node in nodes: - assert node.group in self._store, f"Unexpected node group {node.group}" - self._store[node.group].pop(node.uid, None) - - def remove_nodes(self, nodes: List[DocNode]) -> None: - self._remove_nodes(nodes) - self.try_remove_nodes(nodes) - - def get_nodes_by_files(self, files: List[str]) -> List[DocNode]: - nodes = [] - for file in files: - if file in self._file_node_map: - nodes.append(self._file_node_map[file]) - return nodes + def get_index(self, type_name: str) -> Optional[BaseIndex]: + index = self._store.get_index(type_name) + if not index: + index = self._name2index.get(type_name) + return index +# ---------------------------------------------------------------------------- # class MapStore(BaseStore): - def __init__(self, node_groups: List[str], *args, **kwargs): - super().__init__(node_groups, *args, **kwargs) - - def try_save_nodes(self, nodes: List[DocNode]) -> None: - pass - - def try_load_store(self) -> None: - pass - - def try_remove_nodes(self, nodes: List[DocNode]) -> None: - pass + def __init__(self, node_groups: List[str]): + # Dict[group_name, Dict[uuid, DocNode]] + self._group2docs: Dict[str, Dict[str, DocNode]] = { + group: {} for group in node_groups + } + self._name2index = {} + # override + def update_nodes(self, nodes: List[DocNode]) -> None: + for node in nodes: + self._group2docs[node.group][node.uid] = node + + self._update_indices(self._name2index, nodes) + + # override + def get_group_nodes(self, group_name: str, uids: List[str] = None) -> List[DocNode]: + docs = self._group2docs.get(group_name) + if not docs: + return [] + + if not uids: + return list(docs.values()) + + ret = [] + for uid in uids: + doc = docs.get(uid) + if doc: + ret.append(doc) + return ret + + # override + def remove_group_nodes(self, group_name: str, uids: List[str] = None) -> None: + if uids: + docs = self._group2docs.get(group_name) + if docs: + self._remove_from_indices(self._name2index, uids) + for uid in uids: + docs.pop(uid, None) + else: + docs = self._group2docs.pop(group_name, None) + if docs: + self._remove_from_indices(self._name2index, [doc.uid for doc in docs]) + + # override + def group_is_active(self, group_name: str) -> bool: + docs = self._group2docs.get(group_name) + return True if docs else False + + # override + def group_names(self) -> List[str]: + return self._group2docs.keys() + + # override + def register_index(self, type_name: str, index: BaseIndex) -> None: + self._name2index[type_name] = index + + # override + def remove_index(self, type_name: str) -> None: + self._name2index.pop(type_name, None) + + # override + def get_index(self, type_name: str) -> Optional[BaseIndex]: + return self._name2index.get(type_name) + + def find_node_by_uid(self, uid: str) -> Optional[DocNode]: + for docs in self._group2docs.values(): + doc = docs.get(uid) + if doc: + return doc + return None + +# ---------------------------------------------------------------------------- # class ChromadbStore(BaseStore): def __init__( - self, node_groups: 
List[str], embed_dim: Dict[str, int], *args, **kwargs
+        self, node_groups: List[str], embed_dim: Dict[str, int]
     ) -> None:
-        super().__init__(node_groups, *args, **kwargs)
+        self._map_store = MapStore(node_groups)
         self._db_client = chromadb.PersistentClient(path=config["rag_persistent_path"])
         LOG.success(f"Initialized chromadb in path: {config['rag_persistent_path']}")
         self._collections: Dict[str, Collection] = {
@@ -250,7 +139,44 @@ def __init__(
         }
         self._embed_dim = embed_dim

-    def try_load_store(self) -> None:
+    # override
+    def update_nodes(self, nodes: List[DocNode]) -> None:
+        self._map_store.update_nodes(nodes)
+        self._save_nodes(nodes)
+
+    # override
+    def get_group_nodes(self, group_name: str, uids: List[str] = None) -> List[DocNode]:
+        return self._map_store.get_group_nodes(group_name, uids)
+
+    # override
+    def remove_group_nodes(self, group_name: str, uids: List[str]) -> None:
+        if uids:
+            self._delete_group_nodes(group_name, uids)
+        else:
+            self._db_client.delete_collection(name=group_name)
+        return self._map_store.remove_group_nodes(group_name, uids)
+
+    # override
+    def group_is_active(self, group_name: str) -> bool:
+        return self._map_store.group_is_active(group_name)
+
+    # override
+    def group_names(self) -> List[str]:
+        return self._map_store.group_names()
+
+    # override
+    def register_index(self, type_name: str, index: BaseIndex) -> None:
+        self._map_store.register_index(type_name, index)
+
+    # override
+    def remove_index(self, type_name: str) -> None:
+        self._map_store.remove_index(type_name)
+
+    # override
+    def get_index(self, type_name: str) -> Optional[BaseIndex]:
+        return self._map_store.get_index(type_name)
+
+    def _load_store(self) -> None:
         if not self._collections[LAZY_ROOT_NAME].peek(1)["ids"]:
             LOG.info("No persistent data found, skip the rebuilding phase.")
             return

         # Load data to store
         for group in self._collections.keys():
             results = self._peek_all_documents(group)
             nodes = self._build_nodes_from_chroma(results)
-            self._add_nodes(nodes)
+            self._map_store.update_nodes(nodes)

         # Rebuild relationships
-        for group, nodes_dict in self._store.items():
-            for node in nodes_dict.values():
+        for group_name in self._map_store.group_names():
+            nodes = self._map_store.get_group_nodes(group_name)
+            for node in nodes:
                 if node.parent:
                     parent_uid = node.parent
-                    parent_node = self._find_node_by_uid(parent_uid)
+                    parent_node = self._map_store.find_node_by_uid(parent_uid)
                     node.parent = parent_node
                     parent_node.children[node.group].append(node)
-            LOG.debug(f"build {group} nodes from chromadb: {nodes_dict.values()}")
+            LOG.debug(f"build {group_name} nodes from chromadb: {nodes}")
         LOG.success("Successfully built nodes from chromadb.")

-    def try_save_nodes(self, nodes: List[DocNode]) -> None:
+    def _save_nodes(self, nodes: List[DocNode]) -> None:
         if not nodes:
             return
         # Note: It's the caller's duty to make sure this batch of nodes has the same group.
@@ -301,14 +228,10 @@ def try_save_nodes(self, nodes: List[DocNode]) -> None: ) LOG.debug(f"Saved {group} nodes {ids} to chromadb.") - def try_remove_nodes(self, nodes: List[DocNode]) -> None: - pass - - def _find_node_by_uid(self, uid: str) -> Optional[DocNode]: - for nodes_by_category in self._store.values(): - if uid in nodes_by_category: - return nodes_by_category[uid] - raise ValueError(f"UID {uid} not found in store.") + def _delete_group_nodes(self, group_name: str, uids: List[str]) -> None: + collection = self._collections.get(group_name) + if collection: + collection.delete(ids=uids) def _build_nodes_from_chroma(self, results: Dict[str, List]) -> List[DocNode]: nodes: List[DocNode] = [] diff --git a/lazyllm/tools/rag/transform.py b/lazyllm/tools/rag/transform.py index 38a97bd2..62b516a8 100644 --- a/lazyllm/tools/rag/transform.py +++ b/lazyllm/tools/rag/transform.py @@ -12,7 +12,7 @@ import nltk import tiktoken -from .store import DocNode, MetadataMode +from .doc_node import DocNode, MetadataMode from lazyllm import LOG, TrainableModule, ThreadPoolExecutor diff --git a/tests/advanced_tests/standard_test/test_reranker.py b/tests/advanced_tests/standard_test/test_reranker.py index 08f9d072..b705e496 100644 --- a/tests/advanced_tests/standard_test/test_reranker.py +++ b/tests/advanced_tests/standard_test/test_reranker.py @@ -1,7 +1,7 @@ import unittest import os import lazyllm -from lazyllm.tools.rag.store import DocNode +from lazyllm.tools.rag.doc_node import DocNode from lazyllm.tools.rag.rerank import Reranker, register_reranker diff --git a/tests/basic_tests/test_bm25.py b/tests/basic_tests/test_bm25.py index 1fc9303f..0172e73a 100644 --- a/tests/basic_tests/test_bm25.py +++ b/tests/basic_tests/test_bm25.py @@ -1,6 +1,6 @@ import unittest from lazyllm.tools.rag.component.bm25 import BM25 -from lazyllm.tools.rag.store import DocNode +from lazyllm.tools.rag.doc_node import DocNode import numpy as np diff --git a/tests/basic_tests/test_doc_node.py b/tests/basic_tests/test_doc_node.py index 5c92018c..e49aff18 100644 --- a/tests/basic_tests/test_doc_node.py +++ b/tests/basic_tests/test_doc_node.py @@ -1,5 +1,5 @@ from unittest.mock import MagicMock -from lazyllm.tools.rag.store import DocNode, MetadataMode +from lazyllm.tools.rag.doc_node import DocNode, MetadataMode class TestDocNode: diff --git a/tests/basic_tests/test_document.py b/tests/basic_tests/test_document.py index 83f8b3a6..47daa303 100644 --- a/tests/basic_tests/test_document.py +++ b/tests/basic_tests/test_document.py @@ -1,7 +1,8 @@ import lazyllm -from lazyllm.tools.rag.doc_impl import DocImpl +from lazyllm.tools.rag.doc_impl import DocImpl, FileNodeIndex from lazyllm.tools.rag.transform import SentenceSplitter -from lazyllm.tools.rag.store import DocNode, LAZY_ROOT_NAME +from lazyllm.tools.rag.store import LAZY_ROOT_NAME +from lazyllm.tools.rag.doc_node import DocNode from lazyllm.tools.rag import Document, Retriever, TransformArgs, AdaptiveTransform from lazyllm.launcher import cleanup from unittest.mock import MagicMock @@ -49,7 +50,7 @@ def test_retrieve(self): group_name="FineChunk", similarity="bm25", similarity_cut_off=-100, - index=None, + index='default', topk=1, similarity_kws={}, ) @@ -59,16 +60,16 @@ def test_retrieve(self): def test_add_files(self): assert self.doc_impl.store is None self.doc_impl._lazy_init() - assert len(self.doc_impl.store.traverse_nodes(LAZY_ROOT_NAME)) == 1 + assert len(self.doc_impl.store.get_group_nodes(LAZY_ROOT_NAME)) == 1 new_doc = DocNode(text="new dummy text", 
group=LAZY_ROOT_NAME) new_doc.metadata = {"file_name": "new_file.txt"} self.mock_directory_reader.load_data.return_value = [new_doc] self.doc_impl._add_files(["new_file.txt"]) - assert len(self.doc_impl.store.traverse_nodes(LAZY_ROOT_NAME)) == 2 + assert len(self.doc_impl.store.get_group_nodes(LAZY_ROOT_NAME)) == 2 def test_delete_files(self): self.doc_impl._delete_files(["dummy_file.txt"]) - assert len(self.doc_impl.store.traverse_nodes(LAZY_ROOT_NAME)) == 0 + assert len(self.doc_impl.store.get_group_nodes(LAZY_ROOT_NAME)) == 0 class TestDocument(unittest.TestCase): @@ -152,5 +153,38 @@ def test_multi_embedding_with_document(self): assert len(nodes3) == 3 +class TestFileNodeIndex(unittest.TestCase): + def setUp(self): + self.index = FileNodeIndex() + self.node1 = DocNode(uid='1', group=LAZY_ROOT_NAME, metadata={"file_name": "d1"}) + self.node2 = DocNode(uid='2', group=LAZY_ROOT_NAME, metadata={"file_name": "d2"}) + self.files = [self.node1.metadata['file_name'], self.node1.metadata['file_name']] + + def test_update(self): + self.index.update([self.node1, self.node2]) + + nodes = self.index.query(self.files) + assert len(nodes) == len(self.files) + + ret = [node.metadata['file_name'] for node in nodes] + assert set(ret) == set(self.files) + + def test_remove(self): + self.index.update([self.node1, self.node2]) + + self.index.remove([self.node2.uid]) + ret = self.index.query([self.node2.metadata['file_name']]) + assert len(ret) == 1 + assert ret[0] is None + + def test_query(self): + self.index.update([self.node1, self.node2]) + ret = self.index.query([self.node2.metadata['file_name']]) + assert len(ret) == 1 + assert ret[0] is self.node2 + ret = self.index.query([self.node1.metadata['file_name']]) + assert len(ret) == 1 + assert ret[0] is self.node1 + if __name__ == "__main__": unittest.main() diff --git a/tests/basic_tests/test_index.py b/tests/basic_tests/test_index.py index 8a5e6178..b4e0804a 100644 --- a/tests/basic_tests/test_index.py +++ b/tests/basic_tests/test_index.py @@ -1,8 +1,16 @@ +import os import time import unittest +import tempfile +import pymilvus from unittest.mock import MagicMock -from lazyllm.tools.rag.store import DocNode, MapStore -from lazyllm.tools.rag.index import DefaultIndex, register_similarity +from lazyllm.tools.rag.store import MapStore, LAZY_ROOT_NAME +from lazyllm.tools.rag.doc_node import DocNode +from lazyllm.tools.rag.index import ( + DefaultIndex, + register_similarity, + MilvusIndex, MilvusEmbeddingField, + parallel_do_embedding) class TestDefaultIndex(unittest.TestCase): @@ -10,7 +18,7 @@ def setUp(self): self.mock_embed = MagicMock(side_effect=self.delayed_embed) self.mock_embed1 = MagicMock(return_value=[0, 1, 0]) self.mock_embed2 = MagicMock(return_value=[0, 0, 1]) - self.mock_store = MagicMock(spec=MapStore) + self.mock_store = MapStore(node_groups=['group1']) # Create instance of DefaultIndex self.index = DefaultIndex(embed={"default": self.mock_embed, @@ -19,13 +27,14 @@ def setUp(self): store=self.mock_store) # Create mock DocNodes - self.doc_node_1 = DocNode("text1") + self.doc_node_1 = DocNode(uid="text1", group="group1") self.doc_node_1.embedding = {"default": [1, 0, 0], "test1": [1, 0, 0], "test2": [1, 0, 0]} - self.doc_node_2 = DocNode("text2") + self.doc_node_2 = DocNode(uid="text2", group="group1") self.doc_node_2.embedding = {"default": [0, 1, 0], "test1": [0, 1, 0], "test2": [0, 1, 0]} - self.doc_node_3 = DocNode("text3") + self.doc_node_3 = DocNode(uid="text3", group="group1") self.doc_node_3.embedding = {"default": [0, 0, 1], 
"test1": [0, 0, 1], "test2": [0, 0, 1]} self.nodes = [self.doc_node_1, self.doc_node_2, self.doc_node_3] + self.mock_store.update_nodes(self.nodes) # used by index def delayed_embed(self, text): time.sleep(3) @@ -45,7 +54,7 @@ def custom_similarity(query, nodes, **kwargs): def test_query_cosine_similarity(self): results = self.index.query( query="test", - nodes=self.nodes, + group_name="group1", similarity_name="cosine", similarity_cut_off=0.0, topk=2, @@ -59,7 +68,7 @@ def test_invalid_similarity_name(self): with self.assertRaises(ValueError): self.index.query( query="test", - nodes=self.nodes, + group_name="group1", similarity_name="invalid_similarity", similarity_cut_off=0.0, topk=2, @@ -70,13 +79,13 @@ def test_parallel_do_embedding(self): for node in self.nodes: node.has_embedding = MagicMock(return_value=False) start_time = time.time() - self.index._parallel_do_embedding(self.nodes) + parallel_do_embedding(self.index.embed, self.nodes) assert time.time() - start_time < 4, "Parallel not used!" def test_query_multi_embed_similarity(self): results = self.index.query( query="test", - nodes=self.nodes, + group_name="group1", similarity_name="cosine", similarity_cut_off={"default": 0.8, "test1": 0.8, "test2": 0.8}, topk=2, @@ -88,7 +97,7 @@ def test_query_multi_embed_similarity(self): def test_query_multi_embed_one_thresholds(self): results = self.index.query( query="test", - nodes=self.nodes, + group_name="group1", similarity_name="cosine", similarity_cut_off=0.8, embed_keys=["default", "test1"], @@ -98,5 +107,62 @@ def test_query_multi_embed_one_thresholds(self): self.assertEqual(len(results), 1) self.assertIn(self.doc_node_2, results) +class TestMilvusIndex(unittest.TestCase): + def setUp(self): + embedding_fields = [ + MilvusEmbeddingField(name="vec1", dim=3, data_type=pymilvus.DataType.FLOAT_VECTOR, + index_type="HNSW", metric_type="IP"), + MilvusEmbeddingField(name="vec2", dim=5, data_type=pymilvus.DataType.FLOAT_VECTOR, + index_type="HNSW", metric_type="IP"), + ] + group_embedding_fields = { + "group1": embedding_fields, + "group2": embedding_fields, + } + + self.mock_embed = { + 'vec1': MagicMock(return_value=[1.0, 2.0, 3.0]), + 'vec2': MagicMock(return_value=[400.0, 500.0, 600.0, 700.0, 800.0]), + } + + self.node_groups = [LAZY_ROOT_NAME, "group1", "group2"] + _, self.store_file = tempfile.mkstemp(suffix=".db") + + self.map_store = MapStore(self.node_groups) + self.index = MilvusIndex(embed=self.mock_embed, + group_embedding_fields=group_embedding_fields, + uri=self.store_file, full_data_store=self.map_store) + self.map_store.register_index(type='milvus', index=self.index) + + self.node1 = DocNode(uid="1", text="text1", group="group1", parent=None, + embedding={"vec1": [1.0, 2.0, 3.0], "vec2": [4.0, 5.0, 6.0, 7.0, 8.0]}) + self.node2 = DocNode(uid="2", text="text2", group="group1", parent=self.node1, + embedding={"vec1": [100.0, 200.0, 300.0], "vec2": [400.0, 500.0, 600.0, 700.0, 800.0]}) + + def tearDown(self): + os.remove(self.store_file) + + def test_update_and_query(self): + self.map_store.update_nodes([self.node1]) + ret = self.index.query(query='text1', group_name='group1', embed_keys=['vec2'], topk=1) + self.assertEqual(len(ret), 1) + self.assertEqual(ret[0].uid, self.node1.uid) + + self.map_store.update_nodes([self.node2]) + ret = self.index.query(query='text2', group_name='group1', embed_keys=['vec2'], topk=1) + self.assertEqual(len(ret), 1) + self.assertEqual(ret[0].uid, self.node2.uid) + + def test_remove_and_query(self): + self.map_store.update_nodes([self.node1, 
self.node2]) + ret = self.index.query(query='test', group_name='group1', embed_keys=['vec2'], topk=1) + self.assertEqual(len(ret), 1) + self.assertEqual(ret[0].uid, self.node2.uid) + + self.map_store.remove_group_nodes("group1", [self.node2.uid]) + ret = self.index.query(query='test', group_name='group1', embed_keys=['vec2'], topk=1) + self.assertEqual(len(ret), 1) + self.assertEqual(ret[0].uid, self.node1.uid) + if __name__ == "__main__": unittest.main() diff --git a/tests/basic_tests/test_store.py b/tests/basic_tests/test_store.py index 388eb30b..40703014 100644 --- a/tests/basic_tests/test_store.py +++ b/tests/basic_tests/test_store.py @@ -2,7 +2,8 @@ import shutil import unittest import lazyllm -from lazyllm.tools.rag.store import DocNode, ChromadbStore, LAZY_ROOT_NAME +from lazyllm.tools.rag.store import MapStore, ChromadbStore, LAZY_ROOT_NAME +from lazyllm.tools.rag.doc_node import DocNode def clear_directory(directory_path): @@ -25,7 +26,7 @@ def setUp(self): self.node_groups = [LAZY_ROOT_NAME, "group1", "group2"] self.embed_dim = {"default": 3} self.store = ChromadbStore(self.node_groups, self.embed_dim) - self.store.add_nodes( + self.store.update_nodes( [DocNode(uid="1", text="text1", group=LAZY_ROOT_NAME, parent=None)], ) @@ -36,31 +37,35 @@ def tearDownClass(cls): def test_initialization(self): self.assertEqual(set(self.store._collections.keys()), set(self.node_groups)) - def test_add_and_traverse_nodes(self): + def test_update_nodes(self): node1 = DocNode(uid="1", text="text1", group="group1") node2 = DocNode(uid="2", text="text2", group="group2") - self.store.add_nodes([node1, node2]) - nodes = self.store.traverse_nodes("group1") + self.store.update_nodes([node1, node2]) + collection = self.store._collections["group1"] + self.assertEqual(set(collection.peek(collection.count())["ids"]), set(["1", "2"])) + nodes = self.store.get_group_nodes("group1") self.assertEqual(nodes, [node1]) - def test_save_nodes(self): + def test_remove_group_nodes(self): node1 = DocNode(uid="1", text="text1", group="group1") node2 = DocNode(uid="2", text="text2", group="group2") - self.store.add_nodes([node1, node2]) + self.store.update_nodes([node1, node2]) collection = self.store._collections["group1"] self.assertEqual(collection.peek(collection.count())["ids"], ["1", "2"]) + self.store.remove_group_nodes("group1", "1") + self.assertEqual(collection.peek(collection.count())["ids"], ["2"]) - def test_try_load_store(self): + def test_load_store(self): # Set up initial data to be loaded node1 = DocNode(uid="1", text="text1", group="group1", parent=None) node2 = DocNode(uid="2", text="text2", group="group1", parent=node1) - self.store.add_nodes([node1, node2]) + self.store.update_nodes([node1, node2]) # Reset store and load from "persistent" storage - self.store._store = {group: {} for group in self.node_groups} - self.store.try_load_store() + self.store._map_store._group2docs = {group: {} for group in self.node_groups} + self.store._load_store() - nodes = self.store.traverse_nodes("group1") + nodes = self.store.get_group_nodes("group1") self.assertEqual(len(nodes), 2) self.assertEqual(nodes[0].uid, "1") self.assertEqual(nodes[1].uid, "2") @@ -85,6 +90,60 @@ def test_insert_dict_as_sparse_embedding(self): for uid, node in nodes_dict.items(): assert node.embedding['default'] == orig_embedding_dict.get(uid) + def test_group_names(self): + self.assertEqual(set(self.store.group_names()), set(self.node_groups)) + + def test_group_others(self): + node1 = DocNode(uid="1", text="text1", group="group1", 
parent=None)
+        node2 = DocNode(uid="2", text="text2", group="group1", parent=node1)
+        self.store.update_nodes([node1, node2])
+        self.assertEqual(self.store.group_is_active("group1"), True)
+        self.assertEqual(self.store.group_is_active("group2"), False)
+
+class TestMapStore(unittest.TestCase):
+    def setUp(self):
+        self.node_groups = [LAZY_ROOT_NAME, "group1", "group2"]
+        self.store = MapStore(self.node_groups)
+        self.node1 = DocNode(uid="1", text="text1", group="group1", parent=None)
+        self.node2 = DocNode(uid="2", text="text2", group="group1", parent=self.node1)
+
+    def test_update_nodes(self):
+        self.store.update_nodes([self.node1, self.node2])
+        nodes = self.store.get_group_nodes("group1")
+        self.assertEqual(len(nodes), 2)
+        self.assertEqual(nodes[0].uid, "1")
+        self.assertEqual(nodes[1].uid, "2")
+        self.assertEqual(nodes[1].parent.uid, "1")
+
+    def test_get_group_nodes(self):
+        self.store.update_nodes([self.node1, self.node2])
+        n1 = self.store.get_group_nodes("group1", ["1"])[0]
+        self.assertEqual(n1.text, self.node1.text)
+        n2 = self.store.get_group_nodes("group1", ["2"])[0]
+        self.assertEqual(n2.text, self.node2.text)
+        ids = set([self.node1.uid, self.node2.uid])
+        docs = self.store.get_group_nodes("group1")
+        self.assertEqual(ids, set([doc.uid for doc in docs]))
+
+    def test_remove_group_nodes(self):
+        self.store.update_nodes([self.node1, self.node2])
+
+        n1 = self.store.get_group_nodes("group1", ["1"])[0]
+        assert n1.text == self.node1.text
+        self.store.remove_group_nodes("group1", ["1"])
+        n1 = self.store.get_group_nodes("group1", ["1"])
+        assert not n1
+
+        n2 = self.store.get_group_nodes("group1", ["2"])[0]
+        assert n2.text == self.node2.text
+        self.store.remove_group_nodes("group1", ["2"])
+        n2 = self.store.get_group_nodes("group1", ["2"])
+        assert not n2
+
+    def test_group_names(self):
+        self.assertEqual(set(self.store.group_names()), set(self.node_groups))
 
-if __name__ == "__main__":
-    unittest.main()
+    def test_group_others(self):
+        self.store.update_nodes([self.node1, self.node2])
+        self.assertEqual(self.store.group_is_active("group1"), True)
+        self.assertEqual(self.store.group_is_active("group2"), False)
diff --git a/tests/basic_tests/test_transform.py b/tests/basic_tests/test_transform.py
index 5b6a8f29..47f60d60 100644
--- a/tests/basic_tests/test_transform.py
+++ b/tests/basic_tests/test_transform.py
@@ -1,6 +1,6 @@
 import lazyllm
 from lazyllm.tools.rag.transform import SentenceSplitter
-from lazyllm.tools.rag.store import DocNode
+from lazyllm.tools.rag.doc_node import DocNode
 
 
 class TestSentenceSplitter:
diff --git a/tests/requirements.txt b/tests/requirements.txt
index 90ba7fd6..5a33e8db 100644
--- a/tests/requirements.txt
+++ b/tests/requirements.txt
@@ -3,3 +3,4 @@ docx2txt
 olefile
 pytest-rerunfailures
 pytest-order
+pymilvus

From d4217a63b29487be11228bceaa2d3b392914a928 Mon Sep 17 00:00:00 2001
From: ouguoyu
Date: Tue, 29 Oct 2024 14:30:04 +0800
Subject: [PATCH 03/60] rename BaseStore/BaseIndex to StoreBase/IndexBase and
 get_group_nodes to get_nodes

---
 lazyllm/__init__.py                      |  6 +--
 lazyllm/tools/__init__.py                |  6 +--
 lazyllm/tools/rag/__init__.py            |  8 ++--
 lazyllm/tools/rag/doc_impl.py            | 20 +++++-----
 lazyllm/tools/rag/doc_node.py            |  4 +-
 lazyllm/tools/rag/document.py            | 10 ++---
 lazyllm/tools/rag/index.py               | 16 ++++----
 .../rag/{base_index.py => index_base.py} |  2 +-
 lazyllm/tools/rag/store.py               | 38 +++++++++----------
 .../rag/{base_store.py => store_base.py} | 18 ++++-----
 tests/basic_tests/test_document.py       |  6 +--
 tests/basic_tests/test_store.py          | 20 +++++-----
 12 files changed, 76 insertions(+), 78 deletions(-)
 rename
lazyllm/tools/rag/{base_index.py => index_base.py} (97%) rename lazyllm/tools/rag/{base_store.py => store_base.py} (85%) diff --git a/lazyllm/__init__.py b/lazyllm/__init__.py index 15c9125e..72fce42d 100644 --- a/lazyllm/__init__.py +++ b/lazyllm/__init__.py @@ -15,7 +15,7 @@ from .client import redis_client from .tools import (Document, Reranker, Retriever, WebModule, ToolManager, FunctionCall, FunctionCallAgent, fc_register, ReactAgent, PlanAndSolveAgent, ReWOOAgent, SentenceSplitter, - LLMParser, BaseStore, BaseIndex) + LLMParser, StoreBase, IndexBase) from .docs import add_doc config.done() @@ -73,8 +73,8 @@ 'PlanAndSolveAgent', 'ReWOOAgent', 'SentenceSplitter', - 'BaseStore', - 'BaseIndex', + 'StoreBase', + 'IndexBase', # docs 'add_doc', diff --git a/lazyllm/tools/__init__.py b/lazyllm/tools/__init__.py index 52500a5d..31eb249f 100644 --- a/lazyllm/tools/__init__.py +++ b/lazyllm/tools/__init__.py @@ -1,4 +1,4 @@ -from .rag import Document, Reranker, Retriever, SentenceSplitter, LLMParser, BaseStore, BaseIndex +from .rag import Document, Reranker, Retriever, SentenceSplitter, LLMParser, StoreBase, IndexBase from .webpages import WebModule from .agent import ( ToolManager, @@ -32,6 +32,6 @@ "SqlManager", "SqlCall", "HttpTool", - 'BaseStore', - 'BaseIndex', + 'StoreBase', + 'IndexBase', ] diff --git a/lazyllm/tools/rag/__init__.py b/lazyllm/tools/rag/__init__.py index 74df9e01..1b438fb4 100644 --- a/lazyllm/tools/rag/__init__.py +++ b/lazyllm/tools/rag/__init__.py @@ -8,8 +8,8 @@ MarkdownReader, MboxReader, PandasCSVReader, PandasExcelReader, VideoAudioReader) from .dataReader import SimpleDirectoryReader from .doc_manager import DocManager, DocListManager -from .base_store import BaseStore -from .base_index import BaseIndex +from .store_base import StoreBase +from .index_base import IndexBase __all__ = [ @@ -39,6 +39,6 @@ "SimpleDirectoryReader", 'DocManager', 'DocListManager', - 'BaseStore', - 'BaseIndex', + 'StoreBase', + 'IndexBase', ] diff --git a/lazyllm/tools/rag/doc_impl.py b/lazyllm/tools/rag/doc_impl.py index e71e585e..75c25995 100644 --- a/lazyllm/tools/rag/doc_impl.py +++ b/lazyllm/tools/rag/doc_impl.py @@ -5,16 +5,16 @@ from lazyllm import LOG, config, once_wrapper from .transform import (NodeTransform, FuncNodeTransform, SentenceSplitter, LLMParser, AdaptiveTransform, make_transform, TransformArgs) -from .store import MapStore, DocNode, ChromadbStore, LAZY_ROOT_NAME, BaseStore, StoreWrapper +from .store import MapStore, DocNode, ChromadbStore, LAZY_ROOT_NAME, StoreBase, StoreWrapper from .data_loaders import DirectoryReader -from .index import DefaultIndex, BaseIndex +from .index import DefaultIndex, IndexBase from .utils import DocListManager import threading import time _transmap = dict(function=FuncNodeTransform, sentencesplitter=SentenceSplitter, llm=LLMParser) -class FileNodeIndex(BaseIndex): +class FileNodeIndex(IndexBase): def __init__(self): self._file_node_map = {} @@ -60,7 +60,7 @@ class DocImpl: def __init__(self, embed: Dict[str, Callable], dlm: Optional[DocListManager] = None, doc_files: Optional[str] = None, kb_group_name: Optional[str] = None, - store: Optional[BaseStore] = None): + store: Optional[StoreBase] = None): super().__init__() assert (dlm is None) ^ (doc_files is None), 'Only one of dataset_path or doc_files should be provided' self._local_file_reader: Dict[str, Callable] = {} @@ -80,7 +80,7 @@ def __init__(self, embed: Dict[str, Callable], dlm: Optional[DocListManager] = N def _create_file_node_index(store) -> FileNodeIndex: index = 
FileNodeIndex() for group in store.group_names(): - index.update(store.get_group_nodes(group)) + index.update(store.get_nodes(group)) return index @once_wrapper(reset_on_pickle=True) @@ -109,7 +109,7 @@ def _lazy_init(self) -> None: self._daemon.daemon = True self._daemon.start() - def _create_store(self, rag_store_type: str = None) -> BaseStore: + def _create_store(self, rag_store_type: str = None) -> StoreBase: if not rag_store_type: rag_store_type = config["rag_store_type"] if rag_store_type == "map": @@ -122,7 +122,7 @@ def _create_store(self, rag_store_type: str = None) -> BaseStore: ) return store - def _create_some_indices_for_store(self, store: BaseStore): + def _create_some_indices_for_store(self, store: StoreBase): if not store.get_index(type_name='default'): store.register_index(type_name='default', index=DefaultIndex(self.embed, store)) if not store.get_index(type_name='file_node_map'): @@ -269,7 +269,7 @@ def gather_children(node: DocNode): self.store.remove_group_nodes(group, node_uids) LOG.debug(f"Removed nodes from group {group} for node IDs: {node_uids}") - def _dynamic_create_nodes(self, group_name: str, store: BaseStore) -> None: + def _dynamic_create_nodes(self, group_name: str, store: StoreBase) -> None: if store.group_is_active(group_name): return node_group = self.node_groups.get(group_name) @@ -283,10 +283,10 @@ def _dynamic_create_nodes(self, group_name: str, store: BaseStore) -> None: store.update_nodes(nodes) LOG.debug(f"building {group_name} nodes: {nodes}") - def _get_nodes(self, group_name: str, store: Optional[BaseStore] = None) -> List[DocNode]: + def _get_nodes(self, group_name: str, store: Optional[StoreBase] = None) -> List[DocNode]: store = store or self.store self._dynamic_create_nodes(group_name, store) - return store.get_group_nodes(group_name) + return store.get_nodes(group_name) def retrieve(self, query: str, group_name: str, similarity: str, similarity_cut_off: Union[float, Dict[str, float]], index: str, topk: int, similarity_kws: dict, embed_keys: Optional[List[str]] = None) -> List[DocNode]: diff --git a/lazyllm/tools/rag/doc_node.py b/lazyllm/tools/rag/doc_node.py index 919cde5d..2c5021ac 100644 --- a/lazyllm/tools/rag/doc_node.py +++ b/lazyllm/tools/rag/doc_node.py @@ -16,7 +16,7 @@ class MetadataMode(str, Enum): class DocNode: def __init__(self, uid: Optional[str] = None, text: Optional[str] = None, group: Optional[str] = None, embedding: Optional[Dict[str, List[float]]] = None, parent: Optional["DocNode"] = None, - metadata: Optional[Dict[str, Any]] = None, classfication: Optional[str] = None): + metadata: Optional[Dict[str, Any]] = None): self.uid: str = uid if uid else str(uuid.uuid4()) self.text: Optional[str] = text self.group: Optional[str] = group @@ -32,8 +32,6 @@ def __init__(self, uid: Optional[str] = None, text: Optional[str] = None, group: self._docpath = None self._lock = threading.Lock() self._embedding_state = set() - # store will create index cache for classfication to speed up retrieve - self._classfication = classfication @property def root_node(self) -> Optional["DocNode"]: diff --git a/lazyllm/tools/rag/document.py b/lazyllm/tools/rag/document.py index 1e08e912..7a24f021 100644 --- a/lazyllm/tools/rag/document.py +++ b/lazyllm/tools/rag/document.py @@ -6,7 +6,7 @@ from .doc_manager import DocManager from .doc_impl import DocImpl -from .store import LAZY_ROOT_NAME, EMBED_DEFAULT_KEY, DocNode, BaseStore +from .store import LAZY_ROOT_NAME, EMBED_DEFAULT_KEY, DocNode, StoreBase from .utils import DocListManager import copy 
import functools @@ -16,7 +16,7 @@ class Document(ModuleBase): class _Impl(ModuleBase): def __init__(self, dataset_path: str, embed: Optional[Union[Callable, Dict[str, Callable]]] = None, manager: bool = False, server: bool = False, name: Optional[str] = None, launcher=None, - store: BaseStore = None): + store: StoreBase = None): super().__init__() if not os.path.exists(dataset_path): defatult_path = os.path.join(lazyllm.config["data_path"], dataset_path) @@ -34,7 +34,7 @@ def __init__(self, dataset_path: str, embed: Optional[Union[Callable, Dict[str, if manager: self._manager = DocManager(self._dlm) if server: self._doc = ServerModule(self._doc) - def add_kb_group(self, name, store: BaseStore): + def add_kb_group(self, name, store: StoreBase): self._kbs[name] = DocImpl(dlm=self._dlm, embed=self._embed, kb_group_name=name, store=store) self._dlm.add_kb_group(name) @@ -42,14 +42,14 @@ def get_doc_by_kb_group(self, name): return self._kbs[name] def __init__(self, dataset_path: str, embed: Optional[Union[Callable, Dict[str, Callable]]] = None, create_ui: bool = False, manager: bool = False, server: bool = False, - name: Optional[str] = None, launcher=None, store: BaseStore = None): + name: Optional[str] = None, launcher=None, store: StoreBase = None): super().__init__() if create_ui: lazyllm.LOG.warning('`create_ui` for Document is deprecated, use `manager` instead') self._impls = Document._Impl(dataset_path, embed, create_ui or manager, server, name, launcher, store) self._curr_group = DocListManager.DEDAULT_GROUP_NAME - def create_kb_group(self, name: str, store: BaseStore) -> "Document": + def create_kb_group(self, name: str, store: StoreBase) -> "Document": self._impls.add_kb_group(name, store) doc = copy.copy(self) doc._curr_group = name diff --git a/lazyllm/tools/rag/index.py b/lazyllm/tools/rag/index.py index c9454f17..643323d1 100644 --- a/lazyllm/tools/rag/index.py +++ b/lazyllm/tools/rag/index.py @@ -2,8 +2,8 @@ import os from typing import List, Callable, Optional, Dict, Union, Tuple from .doc_node import DocNode -from .base_store import BaseStore -from .base_index import BaseIndex +from .store_base import StoreBase +from .index_base import IndexBase import numpy as np from .component.bm25 import BM25 from lazyllm import LOG, config, ThreadPoolExecutor @@ -45,12 +45,12 @@ def parallel_do_embedding(embed: Dict[str, Callable], nodes: List[DocNode]) -> L future.result() return modified_nodes -class DefaultIndex(BaseIndex): +class DefaultIndex(IndexBase): """Default Index, registered for similarity functions""" registered_similarity = dict() - def __init__(self, embed: Dict[str, Callable], store: BaseStore, **kwargs): + def __init__(self, embed: Dict[str, Callable], store: StoreBase, **kwargs): self.embed = embed self.store = store @@ -111,7 +111,7 @@ def query( ) similarity_func, mode, descend = self.registered_similarity[similarity_name] - nodes = self.store.get_group_nodes(group_name) + nodes = self.store.get_nodes(group_name) if mode == "embedding": assert self.embed, "Chosen similarity needs embed model." assert len(query) > 0, "Query should not be empty." 
@@ -174,7 +174,7 @@ def register_similarity( # ---------------------------------------------------------------------------- # -class MilvusIndex(BaseIndex): +class MilvusIndex(IndexBase): class Field: def __init__(self, name: str, data_type: pymilvus.DataType, index_type: str, metric_type: str, index_params={}, dim: Optional[int] = None): @@ -187,7 +187,7 @@ def __init__(self, name: str, data_type: pymilvus.DataType, index_type: str, def __init__(self, embed: Dict[str, Callable], group_fields: Dict[str, List[MilvusIndex.Field]], - uri: str, full_data_store: BaseStore): + uri: str, full_data_store: StoreBase): self._embed = embed self._full_data_store = full_data_store @@ -255,4 +255,4 @@ def query(self, for result in results[0]: uids.update(result['id']) - return self._full_data_store.get_group_nodes(group_name, list(uids)) + return self._full_data_store.get_nodes(group_name, list(uids)) diff --git a/lazyllm/tools/rag/base_index.py b/lazyllm/tools/rag/index_base.py similarity index 97% rename from lazyllm/tools/rag/base_index.py rename to lazyllm/tools/rag/index_base.py index 543d9d15..a85f26d5 100644 --- a/lazyllm/tools/rag/base_index.py +++ b/lazyllm/tools/rag/index_base.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from typing import List, Optional -class BaseIndex(ABC): +class IndexBase(ABC): @abstractmethod def update(nodes: List[DocNode]) -> None: ''' diff --git a/lazyllm/tools/rag/store.py b/lazyllm/tools/rag/store.py index a59c60f8..df66a765 100644 --- a/lazyllm/tools/rag/store.py +++ b/lazyllm/tools/rag/store.py @@ -2,8 +2,8 @@ import chromadb from lazyllm import LOG, config from chromadb.api.models.Collection import Collection -from .base_store import BaseStore -from .base_index import BaseIndex +from .store_base import StoreBase +from .index_base import IndexBase from .doc_node import DocNode import json @@ -16,8 +16,8 @@ # ---------------------------------------------------------------------------- # -class StoreWrapper(BaseStore): - def __init__(self, store: BaseStore): +class StoreWrapper(StoreBase): + def __init__(self, store: StoreBase): self._store = store self._name2index = {} @@ -25,8 +25,8 @@ def update_nodes(self, nodes: List[DocNode]) -> None: self._store.update_nodes(nodes) self._update_indices(self._name2index, nodes) - def get_group_nodes(self, group_name: str, uids: List[str] = None) -> List[DocNode]: - return self._store.get_group_nodes(group_name, uids) + def get_nodes(self, group_name: str, uids: List[str] = None) -> List[DocNode]: + return self._store.get_nodes(group_name, uids) def remove_group_nodes(self, group_name: str, uids: List[str] = None) -> None: self._store.remove_group_nodes(group_name, uids) @@ -38,13 +38,13 @@ def group_is_active(self, group_name: str) -> bool: def group_names(self) -> List[str]: return self._store.group_names() - def register_index(self, type_name: str, index: BaseIndex) -> None: + def register_index(self, type_name: str, index: IndexBase) -> None: self._name2index[type_name] = index def remove_index(self, type_name: str) -> None: self._name2index.pop(type_name, None) - def get_index(self, type_name: str) -> Optional[BaseIndex]: + def get_index(self, type_name: str) -> Optional[IndexBase]: index = self._store.get_index(type_name) if not index: index = self._name2index.get(type_name) @@ -52,7 +52,7 @@ def get_index(self, type_name: str) -> Optional[BaseIndex]: # ---------------------------------------------------------------------------- # -class MapStore(BaseStore): +class MapStore(StoreBase): def __init__(self, 
node_groups: List[str]): # Dict[group_name, Dict[uuid, DocNode]] self._group2docs: Dict[str, Dict[str, DocNode]] = { @@ -68,7 +68,7 @@ def update_nodes(self, nodes: List[DocNode]) -> None: self._update_indices(self._name2index, nodes) # override - def get_group_nodes(self, group_name: str, uids: List[str] = None) -> List[DocNode]: + def get_nodes(self, group_name: str, uids: List[str] = None) -> List[DocNode]: docs = self._group2docs.get(group_name) if not docs: return [] @@ -106,7 +106,7 @@ def group_names(self) -> List[str]: return self._group2docs.keys() # override - def register_index(self, type_name: str, index: BaseIndex) -> None: + def register_index(self, type_name: str, index: IndexBase) -> None: self._name2index[type_name] = index # override @@ -114,7 +114,7 @@ def remove_index(self, type_name: str) -> None: self._name2index.pop(type_name, None) # override - def get_index(self, type_name: str) -> Optional[BaseIndex]: + def get_index(self, type_name: str) -> Optional[IndexBase]: return self._name2index.get(type_name) def find_node_by_uid(self, uid: str) -> Optional[DocNode]: @@ -126,7 +126,7 @@ def find_node_by_uid(self, uid: str) -> Optional[DocNode]: # ---------------------------------------------------------------------------- # -class ChromadbStore(BaseStore): +class ChromadbStore(StoreBase): def __init__( self, node_groups: List[str], embed_dim: Dict[str, int] ) -> None: @@ -145,8 +145,8 @@ def update_nodes(self, nodes: List[DocNode]) -> None: self._save_nodes(nodes) # override - def get_group_nodes(self, group_name: str, uids: List[str] = None) -> List[DocNode]: - return self._map_store.get_group_nodes(group_name, uids) + def get_nodes(self, group_name: str, uids: List[str] = None) -> List[DocNode]: + return self._map_store.get_nodes(group_name, uids) # override def remove_group_nodes(self, group_name: str, uids: List[str]) -> None: @@ -165,15 +165,15 @@ def group_names(self) -> List[str]: return self._map_store.group_names() # override - def register_index(self, type_name: str, index: BaseIndex) -> None: + def register_index(self, type_name: str, index: IndexBase) -> None: self._map_store.register_index(type_name, index) # override - def remove_index(self, type_name: str) -> Optional[BaseIndex]: + def remove_index(self, type_name: str) -> Optional[IndexBase]: return self._map_store.remove_index(type_name) # override - def get_index(self, type_name: str) -> Optional[BaseIndex]: + def get_index(self, type_name: str) -> Optional[IndexBase]: return self._map_store.get_index(type_name) def _load_store(self) -> None: @@ -189,7 +189,7 @@ def _load_store(self) -> None: # Rebuild relationships for group_name in self._map_store.group_names(): - nodes = self._map_store.get_group_nodes(group_name) + nodes = self._map_store.get_nodes(group_name) for node in nodes: if node.parent: parent_uid = node.parent diff --git a/lazyllm/tools/rag/base_store.py b/lazyllm/tools/rag/store_base.py similarity index 85% rename from lazyllm/tools/rag/base_store.py rename to lazyllm/tools/rag/store_base.py index 384a4873..cdf9ee00 100644 --- a/lazyllm/tools/rag/base_store.py +++ b/lazyllm/tools/rag/store_base.py @@ -1,9 +1,9 @@ from abc import ABC, abstractmethod from typing import Optional, List, Dict from .doc_node import DocNode -from .base_index import BaseIndex +from .index_base import IndexBase -class BaseStore(ABC): +class StoreBase(ABC): @abstractmethod def update_nodes(self, nodes: List[DocNode]) -> None: ''' @@ -15,7 +15,7 @@ def update_nodes(self, nodes: List[DocNode]) -> None: raise 
NotImplementedError("not implemented yet.") @abstractmethod - def get_group_nodes(self, group_name: str, uids: Optional[List[str]] = None) -> List[DocNode]: + def get_nodes(self, group_name: str, uids: Optional[List[str]] = None) -> List[DocNode]: ''' Returns a list of `DocNode` specified by `uids` in the group named `group_name`. All `DocNode`s in the group `group_name` will be returned if `uids` is `None` or `[]`. @@ -65,13 +65,13 @@ def group_names(self) -> List[str]: raise NotImplementedError("not implemented yet.") @abstractmethod - def register_index(self, type_name: str, index: BaseIndex) -> None: + def register_index(self, type_name: str, index: IndexBase) -> None: ''' Registers `index` with type `type` to this store. Args: type_name (str): type of the index to be registered. - index (BaseIndex): the index to be registered. + index (IndexBase): the index to be registered. ''' raise NotImplementedError("not implemented yet.") @@ -86,7 +86,7 @@ def remove_index(self, type_name: str) -> None: raise NotImplementedError("not implemented yet.") @abstractmethod - def get_index(self, type_name: str) -> Optional[BaseIndex]: + def get_index(self, type_name: str) -> Optional[IndexBase]: ''' Returns index with the specified type `type` in this store. @@ -94,19 +94,19 @@ def get_index(self, type_name: str) -> Optional[BaseIndex]: type_name (str): type of the index to be removed. Returns: - Optional[BaseIndex]: the index of specified type, or `None`. + Optional[IndexBase]: the index of specified type, or `None`. ''' raise NotImplementedError("not implemented yet.") # ----- helper functions ----- # @staticmethod - def _update_indices(name2index: Dict[str, BaseIndex], nodes: List[DocNode]) -> None: + def _update_indices(name2index: Dict[str, IndexBase], nodes: List[DocNode]) -> None: for _, index in name2index.items(): index.update(nodes) @staticmethod - def _remove_from_indices(name2index: Dict[str, BaseIndex], uids: List[str], + def _remove_from_indices(name2index: Dict[str, IndexBase], uids: List[str], group_name: Optional[str] = None) -> None: for _, index in name2index.items(): index.remove(uids, group_name) diff --git a/tests/basic_tests/test_document.py b/tests/basic_tests/test_document.py index 47daa303..669040bb 100644 --- a/tests/basic_tests/test_document.py +++ b/tests/basic_tests/test_document.py @@ -60,16 +60,16 @@ def test_retrieve(self): def test_add_files(self): assert self.doc_impl.store is None self.doc_impl._lazy_init() - assert len(self.doc_impl.store.get_group_nodes(LAZY_ROOT_NAME)) == 1 + assert len(self.doc_impl.store.get_nodes(LAZY_ROOT_NAME)) == 1 new_doc = DocNode(text="new dummy text", group=LAZY_ROOT_NAME) new_doc.metadata = {"file_name": "new_file.txt"} self.mock_directory_reader.load_data.return_value = [new_doc] self.doc_impl._add_files(["new_file.txt"]) - assert len(self.doc_impl.store.get_group_nodes(LAZY_ROOT_NAME)) == 2 + assert len(self.doc_impl.store.get_nodes(LAZY_ROOT_NAME)) == 2 def test_delete_files(self): self.doc_impl._delete_files(["dummy_file.txt"]) - assert len(self.doc_impl.store.get_group_nodes(LAZY_ROOT_NAME)) == 0 + assert len(self.doc_impl.store.get_nodes(LAZY_ROOT_NAME)) == 0 class TestDocument(unittest.TestCase): diff --git a/tests/basic_tests/test_store.py b/tests/basic_tests/test_store.py index 40703014..ec92152c 100644 --- a/tests/basic_tests/test_store.py +++ b/tests/basic_tests/test_store.py @@ -43,7 +43,7 @@ def test_update_nodes(self): self.store.update_nodes([node1, node2]) collection = self.store._collections["group1"] 
         self.assertEqual(set(collection.peek(collection.count())["ids"]), set(["1", "2"]))
-        nodes = self.store.get_group_nodes("group1")
+        nodes = self.store.get_nodes("group1")
         self.assertEqual(nodes, [node1])
 
     def test_remove_group_nodes(self):
@@ -65,7 +65,7 @@ def test_load_store(self):
         self.store._map_store._group2docs = {group: {} for group in self.node_groups}
         self.store._load_store()
 
-        nodes = self.store.get_group_nodes("group1")
+        nodes = self.store.get_nodes("group1")
         self.assertEqual(len(nodes), 2)
         self.assertEqual(nodes[0].uid, "1")
         self.assertEqual(nodes[1].uid, "2")
@@ -109,7 +109,7 @@ def setUp(self):
 
     def test_update_nodes(self):
         self.store.update_nodes([self.node1, self.node2])
-        nodes = self.store.get_group_nodes("group1")
+        nodes = self.store.get_nodes("group1")
         self.assertEqual(len(nodes), 2)
         self.assertEqual(nodes[0].uid, "1")
         self.assertEqual(nodes[1].uid, "2")
@@ -117,27 +117,27 @@ def test_update_nodes(self):
 
     def test_get_group_nodes(self):
         self.store.update_nodes([self.node1, self.node2])
-        n1 = self.store.get_group_nodes("group1", ["1"])[0]
+        n1 = self.store.get_nodes("group1", ["1"])[0]
         self.assertEqual(n1.text, self.node1.text)
-        n2 = self.store.get_group_nodes("group1", ["2"])[0]
+        n2 = self.store.get_nodes("group1", ["2"])[0]
         self.assertEqual(n2.text, self.node2.text)
         ids = set([self.node1.uid, self.node2.uid])
-        docs = self.store.get_group_nodes("group1")
+        docs = self.store.get_nodes("group1")
         self.assertEqual(ids, set([doc.uid for doc in docs]))
 
     def test_remove_group_nodes(self):
         self.store.update_nodes([self.node1, self.node2])
 
-        n1 = self.store.get_group_nodes("group1", ["1"])[0]
+        n1 = self.store.get_nodes("group1", ["1"])[0]
         assert n1.text == self.node1.text
         self.store.remove_group_nodes("group1", ["1"])
-        n1 = self.store.get_group_nodes("group1", ["1"])
+        n1 = self.store.get_nodes("group1", ["1"])
         assert not n1
 
-        n2 = self.store.get_group_nodes("group1", ["2"])[0]
+        n2 = self.store.get_nodes("group1", ["2"])[0]
         assert n2.text == self.node2.text
         self.store.remove_group_nodes("group1", ["2"])
-        n2 = self.store.get_group_nodes("group1", ["2"])
+        n2 = self.store.get_nodes("group1", ["2"])
         assert not n2
 
     def test_group_names(self):

From 827755f20b027356cb146c8520f62c9278f518e1 Mon Sep 17 00:00:00 2001
From: ouguoyu
Date: Tue, 29 Oct 2024 14:33:03 +0800
Subject: [PATCH 04/60] rename remove_group_nodes to remove_nodes

---
 lazyllm/tools/rag/doc_impl.py   |  2 +-
 lazyllm/tools/rag/store.py      | 10 +++++-----
 lazyllm/tools/rag/store_base.py |  2 +-
 tests/basic_tests/test_index.py |  2 +-
 tests/basic_tests/test_store.py |  6 +++---
 5 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/lazyllm/tools/rag/doc_impl.py b/lazyllm/tools/rag/doc_impl.py
index 75c25995..e6fc9408 100644
--- a/lazyllm/tools/rag/doc_impl.py
+++ b/lazyllm/tools/rag/doc_impl.py
@@ -266,7 +266,7 @@ def gather_children(node: DocNode):
 
         # Delete nodes in all groups
         for group, node_uids in uids_to_delete.items():
-            self.store.remove_group_nodes(group, node_uids)
+            self.store.remove_nodes(group, node_uids)
             LOG.debug(f"Removed nodes from group {group} for node IDs: {node_uids}")
 
     def _dynamic_create_nodes(self, group_name: str, store: StoreBase) -> None:
diff --git a/lazyllm/tools/rag/store.py b/lazyllm/tools/rag/store.py
index df66a765..6fcb453e 100644
--- a/lazyllm/tools/rag/store.py
+++ b/lazyllm/tools/rag/store.py
@@ -28,8 +28,8 @@ def update_nodes(self, nodes: List[DocNode]) -> None:
     def get_nodes(self, group_name: str, uids: List[str] = None) -> List[DocNode]:
         return self._store.get_nodes(group_name, uids)
 
-    def remove_group_nodes(self,
group_name: str, uids: List[str] = None) -> None: - self._store.remove_group_nodes(group_name, uids) + def remove_nodes(self, group_name: str, uids: List[str] = None) -> None: + self._store.remove_nodes(group_name, uids) self._remove_from_indices(self._name2index, uids, group_name) def group_is_active(self, group_name: str) -> bool: @@ -84,7 +84,7 @@ def get_nodes(self, group_name: str, uids: List[str] = None) -> List[DocNode]: return ret # override - def remove_group_nodes(self, group_name: str, uids: List[str] = None) -> None: + def remove_nodes(self, group_name: str, uids: List[str] = None) -> None: if uids: docs = self._group2docs.get(group_name) if docs: @@ -149,12 +149,12 @@ def get_nodes(self, group_name: str, uids: List[str] = None) -> List[DocNode]: return self._map_store.get_nodes(group_name, uids) # override - def remove_group_nodes(self, group_name: str, uids: List[str]) -> None: + def remove_nodes(self, group_name: str, uids: List[str]) -> None: if uids: self._delete_group_nodes(group_name, uids) else: self._db_client.delete_collection(name=group_name) - return self._map_store.remove_group_nodes(group_name, uids) + return self._map_store.remove_nodes(group_name, uids) # override def group_is_active(self, group_name: str) -> bool: diff --git a/lazyllm/tools/rag/store_base.py b/lazyllm/tools/rag/store_base.py index cdf9ee00..eb19cdf5 100644 --- a/lazyllm/tools/rag/store_base.py +++ b/lazyllm/tools/rag/store_base.py @@ -30,7 +30,7 @@ def get_nodes(self, group_name: str, uids: Optional[List[str]] = None) -> List[D raise NotImplementedError("not implemented yet.") @abstractmethod - def remove_group_nodes(self, group_name: str, uids: Optional[List[str]] = None) -> None: + def remove_nodes(self, group_name: str, uids: Optional[List[str]] = None) -> None: ''' Removes sepcified `DocNode`s in the group named `group_name`. Group `group_name` will be removed if `uids` is `None` or `[]`. 
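[A minimal usage sketch of the store API as of this patch (illustrative only; it assumes a MapStore built over a single "group1" node group, mirroring the tests below — `remove_nodes` here is the rename of `remove_group_nodes`):

    from lazyllm.tools.rag.store import MapStore
    from lazyllm.tools.rag.doc_node import DocNode

    store = MapStore(node_groups=["group1"])
    store.update_nodes([DocNode(uid="1", text="text1", group="group1")])  # insert or refresh
    assert store.get_nodes("group1", ["1"])[0].uid == "1"
    store.remove_nodes("group1", ["1"])  # was: remove_group_nodes
    assert not store.get_nodes("group1", ["1"])
]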
diff --git a/tests/basic_tests/test_index.py b/tests/basic_tests/test_index.py
index b4e0804a..2538d9cc 100644
--- a/tests/basic_tests/test_index.py
+++ b/tests/basic_tests/test_index.py
@@ -159,7 +159,7 @@ def test_remove_and_query(self):
         self.assertEqual(len(ret), 1)
         self.assertEqual(ret[0].uid, self.node2.uid)
 
-        self.map_store.remove_group_nodes("group1", [self.node2.uid])
+        self.map_store.remove_nodes("group1", [self.node2.uid])
         ret = self.index.query(query='test', group_name='group1', embed_keys=['vec2'], topk=1)
         self.assertEqual(len(ret), 1)
         self.assertEqual(ret[0].uid, self.node1.uid)
diff --git a/tests/basic_tests/test_store.py b/tests/basic_tests/test_store.py
index ec92152c..fa619a7b 100644
--- a/tests/basic_tests/test_store.py
+++ b/tests/basic_tests/test_store.py
@@ -52,7 +52,7 @@ def test_remove_group_nodes(self):
         self.store.update_nodes([node1, node2])
         collection = self.store._collections["group1"]
         self.assertEqual(collection.peek(collection.count())["ids"], ["1", "2"])
-        self.store.remove_group_nodes("group1", "1")
+        self.store.remove_nodes("group1", "1")
         self.assertEqual(collection.peek(collection.count())["ids"], ["2"])
 
     def test_load_store(self):
@@ -130,13 +130,13 @@ def test_remove_group_nodes(self):
         n1 = self.store.get_nodes("group1", ["1"])[0]
         assert n1.text == self.node1.text
-        self.store.remove_group_nodes("group1", ["1"])
+        self.store.remove_nodes("group1", ["1"])
         n1 = self.store.get_nodes("group1", ["1"])
         assert not n1
 
         n2 = self.store.get_nodes("group1", ["2"])[0]
         assert n2.text == self.node2.text
-        self.store.remove_group_nodes("group1", ["2"])
+        self.store.remove_nodes("group1", ["2"])
         n2 = self.store.get_nodes("group1", ["2"])
         assert not n2
 
     def test_group_names(self):

From b8fa14b4ab07628c4fe07c2f0edfa2a936ad8624 Mon Sep 17 00:00:00 2001
From: ouguoyu
Date: Tue, 29 Oct 2024 14:37:10 +0800
Subject: [PATCH 05/60] simplify index lookup in DocImpl.retrieve

---
 lazyllm/tools/rag/doc_impl.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/lazyllm/tools/rag/doc_impl.py b/lazyllm/tools/rag/doc_impl.py
index e6fc9408..5b8daadb 100644
--- a/lazyllm/tools/rag/doc_impl.py
+++ b/lazyllm/tools/rag/doc_impl.py
@@ -292,8 +292,7 @@ def retrieve(self, query: str, group_name: str, similarity: str, similarity_cut_
                  index: str, topk: int, similarity_kws: dict, embed_keys: Optional[List[str]] = None) -> List[DocNode]:
         self._lazy_init()
 
-        index_instance = self.store.get_index(type_name=index)
-        if not index_instance:
+        if not (index_instance := self.store.get_index(type_name=index)):
             raise NotImplementedError(f"index type '{index}' is not supported currently.")
 
         self._dynamic_create_nodes(group_name, self.store)

From 71161f6908da6791349cdfc6ea86413ee21a11c0 Mon Sep 17 00:00:00 2001
From: ouguoyu
Date: Tue, 29 Oct 2024 14:45:08 +0800
Subject: [PATCH 06/60] strip docstrings from IndexBase and StoreBase abstract
 methods

---
 lazyllm/tools/rag/index_base.py | 20 ++------
 lazyllm/tools/rag/store_base.py | 78 ++++-----------------------------
 2 files changed, 11 insertions(+), 87 deletions(-)

diff --git a/lazyllm/tools/rag/index_base.py b/lazyllm/tools/rag/index_base.py
index a85f26d5..ca2fe653 100644
--- a/lazyllm/tools/rag/index_base.py
+++ b/lazyllm/tools/rag/index_base.py
@@ -5,26 +5,12 @@
 class IndexBase(ABC):
     @abstractmethod
     def update(nodes: List[DocNode]) -> None:
-        '''
-        Inserts or updates a list of `DocNode` to this index.
-
-        Args:
-            nodes (List[DocNode]): nodes to be inserted or updated.
-        '''
-        raise NotImplementedError("not implemented yet.")
+        pass
 
     @abstractmethod
     def remove(uids: List[str], group_name: Optional[str] = None) -> None:
-        '''
-        Removes `DocNode`s sepcified by `uids`.
If `group_name` is not None, - just remove uids from that group. - - Args: - uids (List[str]): a list of doc ids. - group_name (Optional[str]): name of the group. - ''' - raise NotImplementedError("not implemented yet.") + pass @abstractmethod def query(self, *args, **kwargs) -> List[DocNode]: - raise NotImplementedError("not implemented yet.") + pass diff --git a/lazyllm/tools/rag/store_base.py b/lazyllm/tools/rag/store_base.py index eb19cdf5..5a07ed95 100644 --- a/lazyllm/tools/rag/store_base.py +++ b/lazyllm/tools/rag/store_base.py @@ -6,97 +6,35 @@ class StoreBase(ABC): @abstractmethod def update_nodes(self, nodes: List[DocNode]) -> None: - ''' - Inserts or updates a list of `DocNode` to this store. - - Args: - nodes (List[DocNode]): nodes to be inserted or updated. - ''' - raise NotImplementedError("not implemented yet.") + pass @abstractmethod def get_nodes(self, group_name: str, uids: Optional[List[str]] = None) -> List[DocNode]: - ''' - Returns a list of `DocNode` specified by `uids` in the group named `group_name`. - All `DocNode`s in the group `group_name` will be returned if `uids` is `None` or `[]`. - - Args: - group_name (str): the name of group. - uids (List[str]): a list of doc ids. - - Returns: - List[DocNode]: the result. - ''' - raise NotImplementedError("not implemented yet.") + pass @abstractmethod def remove_nodes(self, group_name: str, uids: Optional[List[str]] = None) -> None: - ''' - Removes sepcified `DocNode`s in the group named `group_name`. - Group `group_name` will be removed if `uids` is `None` or `[]`. - - Args: - group_name (str): the name of group. - uids (List[str]): a list of doc ids. - ''' - raise NotImplementedError("not implemented yet.") + pass @abstractmethod def group_is_active(self, group_name: str) -> bool: - ''' - Returns `True` if a group named `group_name` exists or has at least one `DocNode`. - - Args: - group_name (str): the name of group. - - Returns: - bool: whether the group `group_name` is active. - ''' - raise NotImplementedError("not implemented yet.") + pass @abstractmethod def group_names(self) -> List[str]: - ''' - Returns group names in this store. - - Returns: - List[str]: the result. - ''' - raise NotImplementedError("not implemented yet.") + pass @abstractmethod def register_index(self, type_name: str, index: IndexBase) -> None: - ''' - Registers `index` with type `type` to this store. - - Args: - type_name (str): type of the index to be registered. - index (IndexBase): the index to be registered. - ''' - raise NotImplementedError("not implemented yet.") + pass @abstractmethod def remove_index(self, type_name: str) -> None: - ''' - Removes index with type `type` in this store. - - Args: - type_name (str): type of the index to be removed. - ''' - raise NotImplementedError("not implemented yet.") + pass @abstractmethod def get_index(self, type_name: str) -> Optional[IndexBase]: - ''' - Returns index with the specified type `type` in this store. - - Args: - type_name (str): type of the index to be removed. - - Returns: - Optional[IndexBase]: the index of specified type, or `None`. 
- ''' - raise NotImplementedError("not implemented yet.") + pass # ----- helper functions ----- # From 5c30f5f09693b971841ff0fc0ea29428bbf7b401 Mon Sep 17 00:00:00 2001 From: ouguoyu Date: Tue, 29 Oct 2024 15:43:00 +0800 Subject: [PATCH 07/60] s --- lazyllm/tools/rag/doc_impl.py | 14 +++++----- lazyllm/tools/rag/store.py | 12 ++++----- lazyllm/tools/rag/store_base.py | 46 ++++++++++++++------------------- tests/basic_tests/test_store.py | 4 +-- 4 files changed, 35 insertions(+), 41 deletions(-) diff --git a/lazyllm/tools/rag/doc_impl.py b/lazyllm/tools/rag/doc_impl.py index 5b8daadb..8428753d 100644 --- a/lazyllm/tools/rag/doc_impl.py +++ b/lazyllm/tools/rag/doc_impl.py @@ -14,7 +14,7 @@ _transmap = dict(function=FuncNodeTransform, sentencesplitter=SentenceSplitter, llm=LLMParser) -class FileNodeIndex(IndexBase): +class _FileNodeIndex(IndexBase): def __init__(self): self._file_node_map = {} @@ -77,9 +77,9 @@ def __init__(self, embed: Dict[str, Callable], dlm: Optional[DocListManager] = N self.store = None @staticmethod - def _create_file_node_index(store) -> FileNodeIndex: - index = FileNodeIndex() - for group in store.group_names(): + def _create_file_node_index(store) -> _FileNodeIndex: + index = _FileNodeIndex() + for group in store.all_groups(): index.update(store.get_nodes(group)) return index @@ -232,9 +232,9 @@ def _add_files(self, input_files: List[str]): root_nodes = self._reader.load_data(input_files) temp_store = self._create_store("map") temp_store.update_nodes(root_nodes) - group_names = self.store.group_names() - LOG.info(f"add_files: Trying to merge store with {group_names}") - for group in group_names: + all_groups = self.store.all_groups() + LOG.info(f"add_files: Trying to merge store with {all_groups}") + for group in all_groups: if not self.store.group_is_active(group): continue # Duplicate group will be discarded automatically diff --git a/lazyllm/tools/rag/store.py b/lazyllm/tools/rag/store.py index 6fcb453e..280be62a 100644 --- a/lazyllm/tools/rag/store.py +++ b/lazyllm/tools/rag/store.py @@ -35,8 +35,8 @@ def remove_nodes(self, group_name: str, uids: List[str] = None) -> None: def group_is_active(self, group_name: str) -> bool: return self._store.group_is_active(group_name) - def group_names(self) -> List[str]: - return self._store.group_names() + def all_groups(self) -> List[str]: + return self._store.all_groups() def register_index(self, type_name: str, index: IndexBase) -> None: self._name2index[type_name] = index @@ -102,7 +102,7 @@ def group_is_active(self, group_name: str) -> bool: return True if docs else False # override - def group_names(self) -> List[str]: + def all_groups(self) -> List[str]: return self._group2docs.keys() # override @@ -161,8 +161,8 @@ def group_is_active(self, group_name: str) -> bool: return self._map_store.group_is_active(group_name) # override - def group_names(self) -> List[str]: - return self._map_store.group_names() + def all_groups(self) -> List[str]: + return self._map_store.all_groups() # override def register_index(self, type_name: str, index: IndexBase) -> None: @@ -188,7 +188,7 @@ def _load_store(self) -> None: self._map_store.update_nodes(nodes) # Rebuild relationships - for group_name in self._map_store.group_names(): + for group_name in self._map_store.all_groups(): nodes = self._map_store.get_nodes(group_name) for node in nodes: if node.parent: diff --git a/lazyllm/tools/rag/store_base.py b/lazyllm/tools/rag/store_base.py index 5a07ed95..d3b44be9 100644 --- a/lazyllm/tools/rag/store_base.py +++ 
b/lazyllm/tools/rag/store_base.py @@ -4,47 +4,41 @@ from .index_base import IndexBase class StoreBase(ABC): - @abstractmethod - def update_nodes(self, nodes: List[DocNode]) -> None: - pass + def __init__(self): + self._name2index = {} - @abstractmethod - def get_nodes(self, group_name: str, uids: Optional[List[str]] = None) -> List[DocNode]: - pass + def register_index(self, type_name: str, index: IndexBase) -> None: + self._name2index[type_name] = index + + def get_index(self, type_name: str) -> Optional[IndexBase]: + return self._name2index.get(type_name) + + def update_nodes(self, nodes: List[DocNode]) -> None: + self._update_nodes(nodes) + for _, index in self._name2index.items(): + index.update(nodes) - @abstractmethod def remove_nodes(self, group_name: str, uids: Optional[List[str]] = None) -> None: - pass + self._remove_nodes(group_name, uids) + for _, index in self._name2index.items(): + index.remove(uids, group_name) @abstractmethod - def group_is_active(self, group_name: str) -> bool: + def _update_nodes(self, nodes: List[DocNode]) -> None: pass @abstractmethod - def group_names(self) -> List[str]: + def _remove_nodes(self, group_name: str, uids: Optional[List[str]] = None) -> None: pass @abstractmethod - def register_index(self, type_name: str, index: IndexBase) -> None: + def get_nodes(self, group_name: str, uids: Optional[List[str]] = None) -> List[DocNode]: pass @abstractmethod - def remove_index(self, type_name: str) -> None: + def group_is_active(self, group_name: str) -> bool: pass @abstractmethod - def get_index(self, type_name: str) -> Optional[IndexBase]: + def all_groups(self) -> List[str]: pass - - # ----- helper functions ----- # - - @staticmethod - def _update_indices(name2index: Dict[str, IndexBase], nodes: List[DocNode]) -> None: - for _, index in name2index.items(): - index.update(nodes) - - @staticmethod - def _remove_from_indices(name2index: Dict[str, IndexBase], uids: List[str], - group_name: Optional[str] = None) -> None: - for _, index in name2index.items(): - index.remove(uids, group_name) diff --git a/tests/basic_tests/test_store.py b/tests/basic_tests/test_store.py index fa619a7b..aa5da1ee 100644 --- a/tests/basic_tests/test_store.py +++ b/tests/basic_tests/test_store.py @@ -91,7 +91,7 @@ def test_insert_dict_as_sparse_embedding(self): assert node.embedding['default'] == orig_embedding_dict.get(uid) def test_group_names(self): - self.assertEqual(set(self.store.group_names()), set(self.node_groups)) + self.assertEqual(set(self.store.all_groups()), set(self.node_groups)) def test_group_others(self): node1 = DocNode(uid="1", text="text1", group="group1", parent=None) @@ -141,7 +141,7 @@ def test_remove_group_nodes(self): assert not n2 def test_group_names(self): - self.assertEqual(set(self.store.group_names()), set(self.node_groups)) + self.assertEqual(set(self.store.all_groups()), set(self.node_groups)) def test_group_others(self): self.store.update_nodes([self.node1, self.node2]) From 6e09bb3428c816d0b92b4e698718ae8bcf0ec903 Mon Sep 17 00:00:00 2001 From: ouguoyu Date: Tue, 29 Oct 2024 16:23:56 +0800 Subject: [PATCH 08/60] s --- lazyllm/tools/rag/doc_impl.py | 81 ++++++++++++++++++++++++--------- lazyllm/tools/rag/store.py | 68 +++++---------------------- lazyllm/tools/rag/store_base.py | 35 ++++++++------ tests/basic_tests/test_store.py | 8 ++-- 4 files changed, 97 insertions(+), 95 deletions(-) diff --git a/lazyllm/tools/rag/doc_impl.py b/lazyllm/tools/rag/doc_impl.py index 8428753d..cd3d002b 100644 --- a/lazyllm/tools/rag/doc_impl.py +++ 
b/lazyllm/tools/rag/doc_impl.py @@ -5,7 +5,7 @@ from lazyllm import LOG, config, once_wrapper from .transform import (NodeTransform, FuncNodeTransform, SentenceSplitter, LLMParser, AdaptiveTransform, make_transform, TransformArgs) -from .store import MapStore, DocNode, ChromadbStore, LAZY_ROOT_NAME, StoreBase, StoreWrapper +from .store import MapStore, DocNode, ChromadbStore, LAZY_ROOT_NAME, StoreBase from .data_loaders import DirectoryReader from .index import DefaultIndex, IndexBase from .utils import DocListManager @@ -14,6 +14,8 @@ _transmap = dict(function=FuncNodeTransform, sentencesplitter=SentenceSplitter, llm=LLMParser) +# ---------------------------------------------------------------------------- # + class _FileNodeIndex(IndexBase): def __init__(self): self._file_node_map = {} @@ -40,6 +42,50 @@ def query(self, files: List[str]) -> List[DocNode]: ret.append(self._file_node_map.get(file)) return ret +class _DocStore(StoreBase): + @staticmethod + def _create_file_node_index(store) -> _FileNodeIndex: + index = _FileNodeIndex() + for group in store.all_groups(): + index.update(store.get_nodes(group)) + return index + + def _create_some_indices(self): + if not self._store.get_index(type='file_node_map'): + self.register_index(type='file_node_map', index=self._create_file_node_index(self._store)) + + def __init__(self, store: StoreBase): + self._store = store + self._extra_indices = {} + self._create_some_indices() + + def update_nodes(self, nodes: List[DocNode]) -> None: + self._store.update_nodes(nodes) + self._update_indices(self._extra_indices, nodes) + + def get_nodes(self, group_name: str, uids: Optional[List[str]] = None) -> List[DocNode]: + return self._store.get_nodes(group_name, uids) + + def remove_nodes(self, group_name: str, uids: Optional[List[str]] = None) -> None: + self._store.remove_nodes(group_name, uids) + self._remove_from_indices(self._extra_indices, uids, group_name) + + def is_group_active(self, name: str) -> bool: + return self._store.is_group_active(name) + + def all_groups(self) -> List[str]: + return self._store.all_groups() + + def register_index(self, type: str, index: IndexBase) -> None: + self._extra_indices[type] = index + + def get_index(self, type: str) -> Optional[IndexBase]: + index = self._extra_indices.get(type) + if not index: + index = self._store.get_index(type) + return index + +# ---------------------------------------------------------------------------- # def embed_wrapper(func): if not func: @@ -71,18 +117,10 @@ def __init__(self, embed: Dict[str, Callable], dlm: Optional[DocListManager] = N self.embed = {k: embed_wrapper(e) for k, e in embed.items()} self._embed_dim = None if store: - self.store = StoreWrapper(store) - self._create_some_indices_for_store(self.store) + self.store = _DocStore(store) else: self.store = None - @staticmethod - def _create_file_node_index(store) -> _FileNodeIndex: - index = _FileNodeIndex() - for group in store.all_groups(): - index.update(store.get_nodes(group)) - return index - @once_wrapper(reset_on_pickle=True) def _lazy_init(self) -> None: node_groups = DocImpl._builtin_node_groups.copy() @@ -94,9 +132,8 @@ def _lazy_init(self) -> None: if not self.store: self.store = self._create_store() - self._create_some_indices_for_store(self.store) - if not self.store.group_is_active(LAZY_ROOT_NAME): + if not self.store.is_group_active(LAZY_ROOT_NAME): ids, pathes = self._list_files() root_nodes = self._reader.load_data(pathes) self.store.update_nodes(root_nodes) @@ -120,13 +157,13 @@ def _create_store(self, 
rag_store_type: str = None) -> StoreBase:
             raise NotImplementedError(
                 f"Not implemented store type for {rag_store_type}"
             )
-        return store
 
-    def _create_some_indices_for_store(self, store: StoreBase):
-        if not store.get_index(type_name='default'):
-            store.register_index(type_name='default', index=DefaultIndex(self.embed, store))
-        if not store.get_index(type_name='file_node_map'):
-            store.register_index(type_name='file_node_map', index=self._create_file_node_index(store))
+        if not store.get_index(type='default'):
+            store.register_index(type='default', index=DefaultIndex(self.embed, store))
+        if not store.get_index(type='file_node_map'):
+            store.register_index(type='file_node_map', index=_DocStore._create_file_node_index(store))
+
+        return store
 
     @staticmethod
     def _create_node_group_impl(cls, group_name, name, transform: Union[str, Callable] = None,
@@ -235,7 +272,7 @@ def _add_files(self, input_files: List[str]):
         all_groups = self.store.all_groups()
         LOG.info(f"add_files: Trying to merge store with {all_groups}")
         for group in all_groups:
-            if not self.store.group_is_active(group):
+            if not self.store.is_group_active(group):
                 continue
             # Duplicate group will be discarded automatically
             nodes = self._get_nodes(group, temp_store)
@@ -244,7 +281,7 @@ def _add_files(self, input_files: List[str]):
 
     def _delete_files(self, input_files: List[str]) -> None:
         self._lazy_init()
-        docs = self.store.get_index(type_name='file_node_map').query(input_files)
+        docs = self.store.get_index(type='file_node_map').query(input_files)
         LOG.info(f"delete_files: removing documents {input_files} and nodes {docs}")
         if len(docs) == 0:
             return
@@ -270,7 +307,7 @@ def gather_children(node: DocNode):
             LOG.debug(f"Removed nodes from group {group} for node IDs: {node_uids}")
 
     def _dynamic_create_nodes(self, group_name: str, store: StoreBase) -> None:
-        if store.group_is_active(group_name):
+        if store.is_group_active(group_name):
             return
         node_group = self.node_groups.get(group_name)
@@ -292,7 +329,7 @@ def retrieve(self, query: str, group_name: str, similarity: str, similarity_cut_
                  index: str, topk: int, similarity_kws: dict, embed_keys: Optional[List[str]] = None) -> List[DocNode]:
         self._lazy_init()
 
-        if not (index_instance := self.store.get_index(type_name=index)):
+        if not (index_instance := self.store.get_index(type=index)):
             raise NotImplementedError(f"index type '{index}' is not supported currently.")
 
         self._dynamic_create_nodes(group_name, self.store)
diff --git a/lazyllm/tools/rag/store.py b/lazyllm/tools/rag/store.py
index 280be62a..95344c63 100644
--- a/lazyllm/tools/rag/store.py
+++ b/lazyllm/tools/rag/store.py
@@ -16,42 +16,6 @@
 
 # ---------------------------------------------------------------------------- #
 
-class StoreWrapper(StoreBase):
-    def __init__(self, store: StoreBase):
-        self._store = store
-        self._name2index = {}
-
-    def update_nodes(self, nodes: List[DocNode]) -> None:
-        self._store.update_nodes(nodes)
-        self._update_indices(self._name2index, nodes)
-
-    def get_nodes(self, group_name: str, uids: List[str] = None) -> List[DocNode]:
-        return self._store.get_nodes(group_name, uids)
-
-    def remove_nodes(self, group_name: str, uids: List[str] = None) -> None:
-        self._store.remove_nodes(group_name, uids)
-        self._remove_from_indices(self._name2index, uids, group_name)
-
-    def group_is_active(self, group_name: str) -> bool:
-        return self._store.group_is_active(group_name)
-
-    def all_groups(self) -> List[str]:
-        return self._store.all_groups()
-
-    def register_index(self, type_name: str, index: IndexBase) ->
None: - self._name2index[type_name] = index - - def remove_index(self, type_name: str) -> None: - self._name2index.pop(type_name, None) - - def get_index(self, type_name: str) -> Optional[IndexBase]: - index = self._store.get_index(type_name) - if not index: - index = self._name2index.get(type_name) - return index - -# ---------------------------------------------------------------------------- # - class MapStore(StoreBase): def __init__(self, node_groups: List[str]): # Dict[group_name, Dict[uuid, DocNode]] @@ -97,8 +61,8 @@ def remove_nodes(self, group_name: str, uids: List[str] = None) -> None: self._remove_from_indices(self._name2index, [doc.uid for doc in docs]) # override - def group_is_active(self, group_name: str) -> bool: - docs = self._group2docs.get(group_name) + def is_group_active(self, name: str) -> bool: + docs = self._group2docs.get(name) return True if docs else False # override @@ -106,16 +70,12 @@ def all_groups(self) -> List[str]: return self._group2docs.keys() # override - def register_index(self, type_name: str, index: IndexBase) -> None: - self._name2index[type_name] = index - - # override - def remove_index(self, type_name: str) -> None: - self._name2index.pop(type_name, None) + def register_index(self, type: str, index: IndexBase) -> None: + self._name2index[type] = index # override - def get_index(self, type_name: str) -> Optional[IndexBase]: - return self._name2index.get(type_name) + def get_index(self, type: str) -> Optional[IndexBase]: + return self._name2index.get(type) def find_node_by_uid(self, uid: str) -> Optional[DocNode]: for docs in self._group2docs.values(): @@ -157,24 +117,20 @@ def remove_nodes(self, group_name: str, uids: List[str]) -> None: return self._map_store.remove_nodes(group_name, uids) # override - def group_is_active(self, group_name: str) -> bool: - return self._map_store.group_is_active(group_name) + def is_group_active(self, name: str) -> bool: + return self._map_store.is_group_active(name) # override def all_groups(self) -> List[str]: return self._map_store.all_groups() # override - def register_index(self, type_name: str, index: IndexBase) -> None: - self._map_store.register_index(type_name, index) - - # override - def remove_index(self, type_name: str) -> Optional[IndexBase]: - return self._map_store.remove_index(type_name) + def register_index(self, type: str, index: IndexBase) -> None: + self._map_store.register_index(type, index) # override - def get_index(self, type_name: str) -> Optional[IndexBase]: - return self._map_store.get_index(type_name) + def get_index(self, type: str) -> Optional[IndexBase]: + return self._map_store.get_index(type) def _load_store(self) -> None: if not self._collections[LAZY_ROOT_NAME].peek(1)["ids"]: diff --git a/lazyllm/tools/rag/store_base.py b/lazyllm/tools/rag/store_base.py index d3b44be9..9e40f7ff 100644 --- a/lazyllm/tools/rag/store_base.py +++ b/lazyllm/tools/rag/store_base.py @@ -7,38 +7,47 @@ class StoreBase(ABC): def __init__(self): self._name2index = {} - def register_index(self, type_name: str, index: IndexBase) -> None: - self._name2index[type_name] = index + def register_index(self, type: str, index: IndexBase) -> None: + self._name2index[type] = index - def get_index(self, type_name: str) -> Optional[IndexBase]: - return self._name2index.get(type_name) + def get_index(self, type: str) -> Optional[IndexBase]: + return self._name2index.get(type) def update_nodes(self, nodes: List[DocNode]) -> None: self._update_nodes(nodes) - for _, index in self._name2index.items(): - index.update(nodes) + 
self._update_indices(self._name2index, nodes) def remove_nodes(self, group_name: str, uids: Optional[List[str]] = None) -> None: self._remove_nodes(group_name, uids) - for _, index in self._name2index.items(): - index.remove(uids, group_name) + self._remove_from_indices(self._name2index, uids, group_name) @abstractmethod - def _update_nodes(self, nodes: List[DocNode]) -> None: + def get_nodes(self, group_name: str, uids: Optional[List[str]] = None) -> List[DocNode]: pass @abstractmethod - def _remove_nodes(self, group_name: str, uids: Optional[List[str]] = None) -> None: + def is_group_active(self, name: str) -> bool: pass @abstractmethod - def get_nodes(self, group_name: str, uids: Optional[List[str]] = None) -> List[DocNode]: + def all_groups(self) -> List[str]: pass @abstractmethod - def group_is_active(self, group_name: str) -> bool: + def _update_nodes(self, nodes: List[DocNode]) -> None: pass @abstractmethod - def all_groups(self) -> List[str]: + def _remove_nodes(self, group_name: str, uids: Optional[List[str]] = None) -> None: pass + + @staticmethod + def _update_indices(name2index: Dict[str, BaseIndex], nodes: List[DocNode]) -> None: + for _, index in name2index.items(): + index.update(nodes) + + @staticmethod + def _remove_from_indices(name2index: Dict[str, BaseIndex], uids: List[str], + group_name: Optional[str] = None) -> None: + for _, index in name2index.items(): + index.remove(uids, group_name) diff --git a/tests/basic_tests/test_store.py b/tests/basic_tests/test_store.py index aa5da1ee..d3d02ae4 100644 --- a/tests/basic_tests/test_store.py +++ b/tests/basic_tests/test_store.py @@ -97,8 +97,8 @@ def test_group_others(self): node1 = DocNode(uid="1", text="text1", group="group1", parent=None) node2 = DocNode(uid="2", text="text2", group="group1", parent=node1) self.store.update_nodes([node1, node2]) - self.assertEqual(self.store.group_is_active("group1"), True) - self.assertEqual(self.store.group_is_active("group2"), False) + self.assertEqual(self.store.is_group_active("group1"), True) + self.assertEqual(self.store.is_group_active("group2"), False) class TestMapStore(unittest.TestCase): def setUp(self): @@ -145,5 +145,5 @@ def test_group_names(self): def test_group_others(self): self.store.update_nodes([self.node1, self.node2]) - self.assertEqual(self.store.group_is_active("group1"), True) - self.assertEqual(self.store.group_is_active("group2"), False) + self.assertEqual(self.store.is_group_active("group1"), True) + self.assertEqual(self.store.is_group_active("group2"), False) From 71f2df212fd0174e7b087bea7d698ac5e3163eac Mon Sep 17 00:00:00 2001 From: ouguoyu Date: Tue, 29 Oct 2024 17:31:29 +0800 Subject: [PATCH 09/60] s --- lazyllm/common/__init__.py | 3 +- lazyllm/common/common.py | 6 +++ lazyllm/tools/rag/doc_impl.py | 10 ++-- lazyllm/tools/rag/index.py | 52 ++++++++++---------- lazyllm/tools/rag/store.py | 85 +++++++++++++-------------------- lazyllm/tools/rag/store_base.py | 4 +- tests/basic_tests/test_store.py | 2 +- 7 files changed, 77 insertions(+), 85 deletions(-) diff --git a/lazyllm/common/__init__.py b/lazyllm/common/__init__.py index 5ceff719..a7549359 100644 --- a/lazyllm/common/__init__.py +++ b/lazyllm/common/__init__.py @@ -1,5 +1,5 @@ from .registry import LazyLLMRegisterMetaClass, _get_base_cls_from_registry, Register -from .common import package, kwargs, arguments, LazyLLMCMD, timeout, final, ReadOnlyWrapper, DynamicDescriptor +from .common import package, kwargs, arguments, LazyLLMCMD, timeout, final, ReadOnlyWrapper, DynamicDescriptor, override from 
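The upshot of this refactor is that a store now keeps every registered index in sync by itself: update_nodes and remove_nodes fan out to whatever has been registered. A minimal sketch of the intended usage, assuming the import paths and the register_index(type, index) signature as of this patch (both keep shifting in later commits of this series):

from typing import List, Optional
from lazyllm.tools.rag import DocNode
from lazyllm.tools.rag.store import MapStore
from lazyllm.tools.rag.index_base import IndexBase

class KeywordIndex(IndexBase):
    # toy index: maps uid -> text and answers substring queries
    def __init__(self):
        self._uid2text = {}

    def update(self, nodes: List[DocNode]) -> None:
        for node in nodes:
            self._uid2text[node.uid] = node.text

    def remove(self, uids: List[str], group_name: Optional[str] = None) -> None:
        for uid in uids:
            self._uid2text.pop(uid, None)

    def query(self, keyword: str) -> List[str]:
        return [uid for uid, text in self._uid2text.items() if keyword in text]

store = MapStore(node_groups=['block'])
store.register_index(type='keyword', index=KeywordIndex())
store.update_nodes([DocNode(uid='1', text='hello world', group='block', parent=None)])
assert store.get_index(type='keyword').query('hello') == ['1']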
.common import FlatList, Identity, ResultCollector, ArgsDict, CaseInsensitiveDict from .common import ReprRule, make_repr, modify_repr from .common import once_flag, call_once, once_wrapper, singleton @@ -38,6 +38,7 @@ 'package', 'kwargs', 'arguments', + 'override', # option 'Option', diff --git a/lazyllm/common/common.py b/lazyllm/common/common.py index 1d2cd649..d7477420 100644 --- a/lazyllm/common/common.py +++ b/lazyllm/common/common.py @@ -14,6 +14,12 @@ _F = typing.TypeVar("_F", bound=Callable[..., Any]) def final(f: _F) -> _F: return f +try: + from typing import override +except ImportError: + def override(func: Callable): + return func + class FlatList(list): def absorb(self, item): diff --git a/lazyllm/tools/rag/doc_impl.py b/lazyllm/tools/rag/doc_impl.py index cd3d002b..268bc4e6 100644 --- a/lazyllm/tools/rag/doc_impl.py +++ b/lazyllm/tools/rag/doc_impl.py @@ -3,6 +3,7 @@ from functools import wraps from typing import Callable, Dict, List, Optional, Set, Union, Tuple from lazyllm import LOG, config, once_wrapper +from lazyllm.common import override from .transform import (NodeTransform, FuncNodeTransform, SentenceSplitter, LLMParser, AdaptiveTransform, make_transform, TransformArgs) from .store import MapStore, DocNode, ChromadbStore, LAZY_ROOT_NAME, StoreBase @@ -20,7 +21,7 @@ class _FileNodeIndex(IndexBase): def __init__(self): self._file_node_map = {} - # override + @override def update(self, nodes: List[DocNode]) -> None: for node in nodes: if node.group != LAZY_ROOT_NAME: @@ -29,13 +30,13 @@ def update(self, nodes: List[DocNode]) -> None: if file_name: self._file_node_map[file_name] = node - # override + @override def remove(self, uids: List[str], group_name: Optional[str] = None) -> None: # group_name is ignored left = {k: v for k, v in self._file_node_map.items() if v.uid not in uids} self._file_node_map = left - # override + @override def query(self, files: List[str]) -> List[DocNode]: ret = [] for file in files: @@ -329,7 +330,8 @@ def retrieve(self, query: str, group_name: str, similarity: str, similarity_cut_ index: str, topk: int, similarity_kws: dict, embed_keys: Optional[List[str]] = None) -> List[DocNode]: self._lazy_init() - if not index_instance := self.store.get_index(type=index): + index_instance = self.store.get_index(type=index) + if not index_instance: raise NotImplementedError(f"index type '{index}' is not supported currently.") self._dynamic_create_nodes(group_name, self.store) diff --git a/lazyllm/tools/rag/index.py b/lazyllm/tools/rag/index.py index 643323d1..795c7d41 100644 --- a/lazyllm/tools/rag/index.py +++ b/lazyllm/tools/rag/index.py @@ -7,7 +7,9 @@ import numpy as np from .component.bm25 import BM25 from lazyllm import LOG, config, ThreadPoolExecutor +from lazyllm.common import override import pymilvus +from pymilvus.client.abstract import AnnSearchRequest, BaseRanker # ---------------------------------------------------------------------------- # @@ -85,15 +87,15 @@ def wrapper(query, nodes, **kwargs): return decorator(func) if func else decorator - # override + @override def update(self, nodes: List[DocNode]) -> None: pass - # override + @override def remove(self, uids: List[str], group_name: Optional[str] = None) -> None: pass - # override + @override def query( self, query: str, @@ -176,8 +178,9 @@ def register_similarity( class MilvusIndex(IndexBase): class Field: - def __init__(self, name: str, data_type: pymilvus.DataType, index_type: str, - metric_type: str, index_params={}, dim: Optional[int] = None): + def __init__(self, name: str, 
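typing.override only exists on Python 3.12+ (PEP 698), so the shim above degrades to a no-op decorator on older interpreters while still letting 3.12+ type checkers flag overrides whose base method was renamed or removed. A small sketch of what it buys:

from lazyllm.common import override

class Base:
    def run(self) -> None: ...

class Child(Base):
    @override            # checked by static analysis on 3.12+; harmless no-op before
    def run(self) -> None:
        print('child')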
data_type: pymilvus.DataType, + metric_type: str, index_type: Optional[str] = None, + index_params={}, dim: Optional[int] = None): self.name = name self.data_type = data_type self.index_type = index_type @@ -185,8 +188,7 @@ def __init__(self, name: str, data_type: pymilvus.DataType, index_type: str, self.index_params = index_params self.dim = dim - def __init__(self, embed: Dict[str, Callable], - group_fields: Dict[str, List[MilvusIndex.Field]], + def __init__(self, embed: Dict[str, Callable], group_fields: Dict[str, List[Field]], uri: str, full_data_store: StoreBase): self._embed = embed self._full_data_store = full_data_store @@ -219,7 +221,7 @@ def __init__(self, embed: Dict[str, Callable], self._client.create_collection(collection_name=group_name, schema=schema, index_params=index_params) - # override + @override def update(self, nodes: List[DocNode]) -> None: parallel_do_embedding(self._embed, nodes) for node in nodes: @@ -227,7 +229,7 @@ def update(self, nodes: List[DocNode]) -> None: data[self._primary_key] = node.uid self._client.upsert(collection_name=node.group, data=data) - # override + @override def remove(self, uids: List[str], group_name: Optional[str] = None) -> None: if group_name: self._client.delete(collection_name=group_name, @@ -237,22 +239,22 @@ def remove(self, uids: List[str], group_name: Optional[str] = None) -> None: self._client.delete(collection_name=group_name, filter=f'{self._primary_key} in {uids}') - # override + @override def query(self, - query: str, group_name: str, - embed_keys: Optional[List[str]] = None, - topk: int = 10, + req: AnnSearchRequest, + ranker: BaseRanker, + limit: int = 10, + timeout: Optional[float] = None, **kwargs) -> List[DocNode]: - uids = set() - for embed_name in embed_keys: - embed_func = self._embed.get(embed_name) - query_embedding = embed_func(query) - results = self._client.search(collection_name=group_name, data=[query_embedding], - limit=topk, anns_field=embed_name) - if len(results) > 0: - # we have only one `data` for search() so there is only one result in `results` - for result in results[0]: - uids.update(result['id']) - - return self._full_data_store.get_nodes(group_name, list(uids)) + results = self._client.hybrid_search( + collection_name=group_name, reqs=[req], ranker=ranker, limit=limit, + timeout=timeout) + if len(results) != 1: + raise ValueError(f'return results size [{len(results)}] != 1') + + uids = [] + for record in results[0]: + uids.append(record['id']) + + return self._full_data_store.get_group_nodes(group_name, uids) diff --git a/lazyllm/tools/rag/store.py b/lazyllm/tools/rag/store.py index 95344c63..66a04ce1 100644 --- a/lazyllm/tools/rag/store.py +++ b/lazyllm/tools/rag/store.py @@ -1,6 +1,7 @@ from typing import Any, Dict, List, Optional import chromadb from lazyllm import LOG, config +from lazyllm.common import override from chromadb.api.models.Collection import Collection from .store_base import StoreBase from .index_base import IndexBase @@ -18,20 +19,13 @@ class MapStore(StoreBase): def __init__(self, node_groups: List[str]): + super().__init__() # Dict[group_name, Dict[uuid, DocNode]] self._group2docs: Dict[str, Dict[str, DocNode]] = { group: {} for group in node_groups } - self._name2index = {} - # override - def update_nodes(self, nodes: List[DocNode]) -> None: - for node in nodes: - self._group2docs[node.group][node.uid] = node - - self._update_indices(self._name2index, nodes) - - # override + @override def get_nodes(self, group_name: str, uids: List[str] = None) -> List[DocNode]: docs = 
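The rewritten query() above hands ranking to Milvus itself: one AnnSearchRequest per vector field, fused server-side by a reranker supplied by the caller. A standalone sketch of that call shape against pymilvus 2.4; the uri, collection and field names are made up for illustration, and RRFRanker is only one plausible choice since the patch leaves the ranker to the caller:

from pymilvus import MilvusClient, AnnSearchRequest, RRFRanker

client = MilvusClient(uri='demo.db')             # illustrative Milvus Lite file
req = AnnSearchRequest(
    data=[[0.1, 0.2, 0.3]],                      # one query embedding
    anns_field='vec1',                           # a vector field from the schema
    param={},                                    # per-field search parameters
    limit=10,
)
results = client.hybrid_search(collection_name='group1', reqs=[req],
                               ranker=RRFRanker(), limit=10)
uids = [hit['id'] for hit in results[0]]         # one hit list per query vector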
self._group2docs.get(group_name) if not docs: @@ -47,35 +41,29 @@ def get_nodes(self, group_name: str, uids: List[str] = None) -> List[DocNode]: ret.append(doc) return ret - # override - def remove_nodes(self, group_name: str, uids: List[str] = None) -> None: - if uids: - docs = self._group2docs.get(group_name) - if docs: - self._remove_from_indices(self._name2index, uids) - for uid in uids: - docs.pop(uid, None) - else: - docs = self._group2docs.pop(group_name, None) - if docs: - self._remove_from_indices(self._name2index, [doc.uid for doc in docs]) - - # override + @override def is_group_active(self, name: str) -> bool: docs = self._group2docs.get(name) return True if docs else False - # override + @override def all_groups(self) -> List[str]: return self._group2docs.keys() - # override - def register_index(self, type: str, index: IndexBase) -> None: - self._name2index[type] = index + @override + def _update_nodes(self, nodes: List[DocNode]) -> None: + for node in nodes: + self._group2docs[node.group][node.uid] = node - # override - def get_index(self, type: str) -> Optional[IndexBase]: - return self._name2index.get(type) + @override + def _remove_nodes(self, group_name: str, uids: List[str] = None) -> None: + if uids: + docs = self._group2docs.get(group_name) + if docs: + for uid in uids: + docs.pop(uid, None) + else: + self._group2docs.pop(group_name, None) def find_node_by_uid(self, uid: str) -> Optional[DocNode]: for docs in self._group2docs.values(): @@ -90,6 +78,7 @@ class ChromadbStore(StoreBase): def __init__( self, node_groups: List[str], embed_dim: Dict[str, int] ) -> None: + super().__init__() self._map_store = MapStore(node_groups) self._db_client = chromadb.PersistentClient(path=config["rag_persistent_path"]) LOG.success(f"Initialzed chromadb in path: {config['rag_persistent_path']}") @@ -99,38 +88,30 @@ def __init__( } self._embed_dim = embed_dim - # override - def update_nodes(self, nodes: List[DocNode]) -> None: - self._map_store.update_nodes(nodes) - self._save_nodes(nodes) - - # override + @override def get_nodes(self, group_name: str, uids: List[str] = None) -> List[DocNode]: return self._map_store.get_nodes(group_name, uids) - # override - def remove_nodes(self, group_name: str, uids: List[str]) -> None: - if uids: - self._delete_group_nodes(group_name, uids) - else: - self._db_client.delete_collection(name=group_name) - return self._map_store.remove_nodes(group_name, uids) - - # override + @override def is_group_active(self, name: str) -> bool: return self._map_store.is_group_active(name) - # override + @override def all_groups(self) -> List[str]: return self._map_store.all_groups() - # override - def register_index(self, type: str, index: IndexBase) -> None: - self._map_store.register_index(type, index) + @override + def _update_nodes(self, nodes: List[DocNode]) -> None: + self._map_store.update_nodes(nodes) + self._save_nodes(nodes) - # override - def get_index(self, type: str) -> Optional[IndexBase]: - return self._map_store.get_index(type) + @override + def _remove_nodes(self, group_name: str, uids: List[str]) -> None: + if uids: + self._delete_group_nodes(group_name, uids) + else: + self._db_client.delete_collection(name=group_name) + return self._map_store.remove_nodes(group_name, uids) def _load_store(self) -> None: if not self._collections[LAZY_ROOT_NAME].peek(1)["ids"]: diff --git a/lazyllm/tools/rag/store_base.py b/lazyllm/tools/rag/store_base.py index 9e40f7ff..a1cce5b9 100644 --- a/lazyllm/tools/rag/store_base.py +++ b/lazyllm/tools/rag/store_base.py @@ 
-42,12 +42,12 @@ def _remove_nodes(self, group_name: str, uids: Optional[List[str]] = None) -> No pass @staticmethod - def _update_indices(name2index: Dict[str, BaseIndex], nodes: List[DocNode]) -> None: + def _update_indices(name2index: Dict[str, IndexBase], nodes: List[DocNode]) -> None: for _, index in name2index.items(): index.update(nodes) @staticmethod - def _remove_from_indices(name2index: Dict[str, BaseIndex], uids: List[str], + def _remove_from_indices(name2index: Dict[str, IndexBase], uids: List[str], group_name: Optional[str] = None) -> None: for _, index in name2index.items(): index.remove(uids, group_name) diff --git a/tests/basic_tests/test_store.py b/tests/basic_tests/test_store.py index d3d02ae4..154e8786 100644 --- a/tests/basic_tests/test_store.py +++ b/tests/basic_tests/test_store.py @@ -78,7 +78,7 @@ def test_insert_dict_as_sparse_embedding(self): node1.uid: [0, 10, 20], node2.uid: [30, 0, 50], } - self.store.add_nodes([node1, node2]) + self.store.update_nodes([node1, node2]) results = self.store._peek_all_documents('group1') nodes = self.store._build_nodes_from_chroma(results) From 8befadcac3a5a5058064bd6333a977b4de92f22e Mon Sep 17 00:00:00 2001 From: ouguoyu Date: Wed, 30 Oct 2024 16:47:05 +0800 Subject: [PATCH 10/60] s --- lazyllm/tools/rag/index.py | 85 ----------- lazyllm/tools/rag/store.py | 255 +++++++++++++++++++++++++++++--- lazyllm/tools/rag/store_base.py | 32 +--- 3 files changed, 238 insertions(+), 134 deletions(-) diff --git a/lazyllm/tools/rag/index.py b/lazyllm/tools/rag/index.py index 795c7d41..f482c0f8 100644 --- a/lazyllm/tools/rag/index.py +++ b/lazyllm/tools/rag/index.py @@ -173,88 +173,3 @@ def register_similarity( batch: bool = False, ) -> Callable: return DefaultIndex.register_similarity(func, mode, descend, batch) - -# ---------------------------------------------------------------------------- # - -class MilvusIndex(IndexBase): - class Field: - def __init__(self, name: str, data_type: pymilvus.DataType, - metric_type: str, index_type: Optional[str] = None, - index_params={}, dim: Optional[int] = None): - self.name = name - self.data_type = data_type - self.index_type = index_type - self.metric_type = metric_type - self.index_params = index_params - self.dim = dim - - def __init__(self, embed: Dict[str, Callable], group_fields: Dict[str, List[Field]], - uri: str, full_data_store: StoreBase): - self._embed = embed - self._full_data_store = full_data_store - - self._primary_key = 'uid' - self._client = pymilvus.MilvusClient(uri=uri) - - for group_name, field_list in group_fields.items(): - if group_name in self._client.list_collections(): - continue - - schema = self._client.create_schema(auto_id=False, enable_dynamic_field=False) - schema.add_field( - field_name=self._primary_key, - datatype=pymilvus.DataType.VARCHAR, - max_length=128, - is_primary=True, - ) - for field in field_list: - schema.add_field( - field_name=field.name, - datatype=field.data_type, - dim=field.dim) - - index_params = self._client.prepare_index_params() - for field in field_list: - index_params.add_index(field_name=field.name, index_type=field.index_type, - metric_type=field.metric_type, params=field.index_params) - - self._client.create_collection(collection_name=group_name, schema=schema, - index_params=index_params) - - @override - def update(self, nodes: List[DocNode]) -> None: - parallel_do_embedding(self._embed, nodes) - for node in nodes: - data = node.embedding.copy() - data[self._primary_key] = node.uid - self._client.upsert(collection_name=node.group, data=data) 
- - @override - def remove(self, uids: List[str], group_name: Optional[str] = None) -> None: - if group_name: - self._client.delete(collection_name=group_name, - filter=f'{self._primary_key} in {uids}') - else: - for group_name in self._client.list_collections(): - self._client.delete(collection_name=group_name, - filter=f'{self._primary_key} in {uids}') - - @override - def query(self, - group_name: str, - req: AnnSearchRequest, - ranker: BaseRanker, - limit: int = 10, - timeout: Optional[float] = None, - **kwargs) -> List[DocNode]: - results = self._client.hybrid_search( - collection_name=group_name, reqs=[req], ranker=ranker, limit=limit, - timeout=timeout) - if len(results) != 1: - raise ValueError(f'return results size [{len(results)}] != 1') - - uids = [] - for record in results[0]: - uids.append(record['id']) - - return self._full_data_store.get_group_nodes(group_name, uids) diff --git a/lazyllm/tools/rag/store.py b/lazyllm/tools/rag/store.py index 66a04ce1..587886e2 100644 --- a/lazyllm/tools/rag/store.py +++ b/lazyllm/tools/rag/store.py @@ -17,13 +17,42 @@ # ---------------------------------------------------------------------------- # -class MapStore(StoreBase): +def _update_indices(name2index: Dict[str, IndexBase], nodes: List[DocNode]) -> None: + for _, index in name2index.items(): + index.update(nodes) + +def _remove_from_indices(name2index: Dict[str, IndexBase], uids: List[str], + group_name: Optional[str] = None) -> None: + for _, index in name2index.items(): + index.remove(uids, group_name) + +class MapStore(StoreBase, IndexBase): def __init__(self, node_groups: List[str]): super().__init__() # Dict[group_name, Dict[uuid, DocNode]] self._group2docs: Dict[str, Dict[str, DocNode]] = { group: {} for group in node_groups } + self._name2index = {} + + @override + def update_nodes(self, nodes: List[DocNode]) -> None: + for node in nodes: + self._group2docs[node.group][node.uid] = node + _update_indices(self._name2index, nodes) + + @override + def remove_nodes(self, group_name: str, uids: List[str] = None) -> None: + if uids: + docs = self._group2docs.get(group_name) + if docs: + _remove_from_indices(self._name2index, uids) + for uid in uids: + docs.pop(uid, None) + else: + docs = self._group2docs.pop(group_name, None) + if docs: + _remove_from_indices(self._name2index, [doc.uid for doc in docs]) @override def get_nodes(self, group_name: str, uids: List[str] = None) -> List[DocNode]: @@ -51,19 +80,32 @@ def all_groups(self) -> List[str]: return self._group2docs.keys() @override - def _update_nodes(self, nodes: List[DocNode]) -> None: - for node in nodes: - self._group2docs[node.group][node.uid] = node + def register_index(self, type: str, index: IndexBase) -> None: + self._name2index[type] = index @override - def _remove_nodes(self, group_name: str, uids: List[str] = None) -> None: - if uids: - docs = self._group2docs.get(group_name) - if docs: + def get_index(self, Optional[type]: str = None) -> Optional[IndexBase]: + if type: + return self._name2index.get(type) + return self + + @override + def update(nodes: List[DocNode]) -> None: + self.update_nodes(nodes) + + @override + def remove(uids: List[str], group_name: Optional[str] = None) -> None: + if group_name: + self.remove_nodes(group_name, uids) + else: + for _, docs in self._group2docs.items(): for uid in uids: docs.pop(uid, None) - else: - self._group2docs.pop(group_name, None) + _remove_from_indices(self._name2index, uids) + + @override + def query(self, group_name: str, uids: Optional[List[str]] = None) -> List[DocNode]: + 
return self.get_nodes(group_name, uids) def find_node_by_uid(self, uid: str) -> Optional[DocNode]: for docs in self._group2docs.values(): @@ -88,6 +130,19 @@ def __init__( } self._embed_dim = embed_dim + @override + def update_nodes(self, nodes: List[DocNode]) -> None: + self._map_store.update_nodes(nodes) + self._save_nodes(nodes) + + @override + def remove_nodes(self, group_name: str, uids: Optional[List[str]] = None) -> None: + if uids: + self._delete_group_nodes(group_name, uids) + else: + self._db_client.delete_collection(name=group_name) + return self._map_store.remove_nodes(group_name, uids) + @override def get_nodes(self, group_name: str, uids: List[str] = None) -> List[DocNode]: return self._map_store.get_nodes(group_name, uids) @@ -100,19 +155,6 @@ def is_group_active(self, name: str) -> bool: def all_groups(self) -> List[str]: return self._map_store.all_groups() - @override - def _update_nodes(self, nodes: List[DocNode]) -> None: - self._map_store.update_nodes(nodes) - self._save_nodes(nodes) - - @override - def _remove_nodes(self, group_name: str, uids: List[str]) -> None: - if uids: - self._delete_group_nodes(group_name, uids) - else: - self._db_client.delete_collection(name=group_name) - return self._map_store.remove_nodes(group_name, uids) - def _load_store(self) -> None: if not self._collections[LAZY_ROOT_NAME].peek(1)["ids"]: LOG.info("No persistent data found, skip the rebuilding phrase.") @@ -213,3 +255,170 @@ def _peek_all_documents(self, group: str) -> Dict[str, List]: assert group in self._collections, f"group {group} not found." collection = self._collections[group] return collection.peek(collection.count()) + +# ---------------------------------------------------------------------------- # + +class MilvusStore(StoreBase, IndexBase): + def __init__(self, uri: str, embed: Dict[str, Callable], + group_fields: Dict[str, List[pymilvus.FieldSchema]], + group_indices: Dict[str, pymilvus.IndexParams]): + self._primary_key = 'uid' + self._embedding_keys = list(embed.keys()) + self._metadata_keys = filter(lambda x: x not in embed.keys(), group_fields.keys()) + + self._embed = embed + self._client = pymilvus.MilvusClient(uri=uri) + + id_field = pymilvus.FieldSchema( + name=self._primary_key, dtype=pymilvus.DataType.VARCHAR, + max_length=128, is_primary=True) + + for group_name, field_list in group_fields.items(): + if group_name in self._client.list_collections(): + continue + + schema = CollectionSchema(fields=id_field+field_list) + index_params = group_indices.get(group_name) + + self._client.create_collection(collection_name=group_name, schema=schema, + index_params=index_params) + + self._map_store = MapStore(list(group_fields.keys())) + self._load_all_nodes_to(self._map_store) + + # ----- Store APIs ----- # + + @override + def update_nodes(self, nodes: List[DocNode]) -> None: + parallel_do_embedding(self._embed, nodes) + for node in nodes: + data = self._serialize_node_partial(node) + self._client.upsert(collection_name=node.group, data=data) + + self._map_store.update_nodes(nodes) + + @override + def remove_nodes(self, group_name: str, uids: Optional[List[str]] = None) -> None: + if uids: + self._client.delete(collection_name=group_name, + filter=f'{self._primary_key} in {uids}') + else: + self._client.drop_collection(collection_name=group_name) + + self._map_store.remove_nodes(group_name, uids) + + @override + def get_nodes(self, group_name: str, uids: Optional[List[str]] = None) -> List[DocNode]: + return self._map_store.get_nodes(group_name, uids) + + @override + def 
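Since get_index() now falls back to the store itself, retrieval code can treat the in-memory store and any registered index uniformly. A minimal sketch of that dispatch, assuming the corrected get_index signature that a later commit in this series settles on:

from lazyllm.tools.rag import DocNode
from lazyllm.tools.rag.store import MapStore

store = MapStore(node_groups=['group1'])
store.update_nodes([DocNode(uid='1', text='t1', group='group1', parent=None)])

index = store.get_index()        # no type given -> the store acts as its own index
assert index.query(group_name='group1', uids=['1'])[0].uid == '1'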
is_group_active(self, name: str) -> bool: + return self._map_store.is_group_active(name) + + @override + def all_groups(self) -> List[str]: + return _map_store.all_groups() + + @override + def register_index(self, type: str, index: IndexBase) -> None: + self._map_store.register_index(type, index) + + @override + def get_index(self, Optional[type]: str = None) -> Optional[IndexBase]: + if type: + return self._map_store.get_index(type) + return self + + # ----- Index APIs ----- # + + @override + def update(self, nodes: List[DocNode]) -> None: + self.update_nodes(nodes) + + @override + def remove(self, uids: List[str], group_name: Optional[str] = None) -> None: + self.remove_nodes(group_name, uids) + + @override + def query(self, + query: str, + group_name: str, + similarity_name: str, + similarity_cut_off: Union[float, Dict[str, float]], + topk: int, + embed_keys: Optional[List[str]] = None, + **kwargs) -> List[DocNode]: + reqs = [] + for key in embed_keys: + embed_func = self._embed.get(key) + query_embedding = embed_func(query) + # TODO set search params + req = AnnSearchRequest( + data=query_embedding, + anns_field=key, + limit=topk, + ) + reqs.append(req) + + results = self._client.hybrid_search(collection_name=group_name, reqs=reqs, + ranker=ranker, limit=topk) + if len(results) != 1: + raise ValueError(f'return results size [{len(results)}] != 1') + + uidset = set() + for record in results[0]: + uidset.insert(record['id']) + return self._map_store.get_nodes(group_name, list(uidset)) + + def _load_all_nodes_to(self, store: StoreBase): + results = self._client.query(collection_name=group_name, + filter=f'{self._primary_key} != ""') + for result in results: + doc = self._deserialize_node_partial(result) + store.update_nodes([doc], group) + + # construct DocNode::parent and DocNode::children + for group in all_groups(): + for node in self.get_nodes(group): + if node.parent: + parent_uid = node.parent + parent_node = self._map_store.find_node_by_uid(parent_uid) + node.parent = parent_node + parent_node.children[node.group].append(node) + + @staticmethod + def _serialize_node_partial(node: DocNode) -> Dict: + res = { + 'uid': node.uid, + 'text': node.text, + 'group': node.group, + } + + if self.parent: + res['parent'] = node.parent.uid + + for k, v in node.embedding.items(): + res['embedding_' + k] = v + for k, v in self.metadata.items(): + res['metadata_' + k] = json.dumps(v) + + return res + + @staticmethod + def _deserialize_node_partial(result: Dict) -> DocNode: + ''' + without parent and children + ''' + doc = DocNode( + uid=result.get('uid'), + text=result.get('text'), + group=result.get('group'), + parent=result.get('parent'), # this is the parent's uid + ) + + for k in self._embedding_keys: + doc.embedding[k] = result.get('embedding_' + k) + for k in self._metadata_keys: + doc._metadata[k] = result.get('metadata_' + k) + + return doc diff --git a/lazyllm/tools/rag/store_base.py b/lazyllm/tools/rag/store_base.py index a1cce5b9..6d0184e2 100644 --- a/lazyllm/tools/rag/store_base.py +++ b/lazyllm/tools/rag/store_base.py @@ -4,22 +4,13 @@ from .index_base import IndexBase class StoreBase(ABC): - def __init__(self): - self._name2index = {} - - def register_index(self, type: str, index: IndexBase) -> None: - self._name2index[type] = index - - def get_index(self, type: str) -> Optional[IndexBase]: - return self._name2index.get(type) - + @abstractmethod def update_nodes(self, nodes: List[DocNode]) -> None: - self._update_nodes(nodes) - self._update_indices(self._name2index, nodes) + pass + 
@abstractmethod def remove_nodes(self, group_name: str, uids: Optional[List[str]] = None) -> None: - self._remove_nodes(group_name, uids) - self._remove_from_indices(self._name2index, uids, group_name) + pass @abstractmethod def get_nodes(self, group_name: str, uids: Optional[List[str]] = None) -> List[DocNode]: @@ -34,20 +25,9 @@ def all_groups(self) -> List[str]: pass @abstractmethod - def _update_nodes(self, nodes: List[DocNode]) -> None: + def register_index(self, type: str, index: IndexBase) -> None: pass @abstractmethod - def _remove_nodes(self, group_name: str, uids: Optional[List[str]] = None) -> None: + def get_index(self, Optional[type]: str = None) -> Optional[IndexBase]: pass - - @staticmethod - def _update_indices(name2index: Dict[str, IndexBase], nodes: List[DocNode]) -> None: - for _, index in name2index.items(): - index.update(nodes) - - @staticmethod - def _remove_from_indices(name2index: Dict[str, IndexBase], uids: List[str], - group_name: Optional[str] = None) -> None: - for _, index in name2index.items(): - index.remove(uids, group_name) From 93e24a478178bf63d4d223e1dce3a3698fb4809a Mon Sep 17 00:00:00 2001 From: ouguoyu Date: Wed, 30 Oct 2024 16:59:28 +0800 Subject: [PATCH 11/60] s --- lazyllm/tools/rag/store.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lazyllm/tools/rag/store.py b/lazyllm/tools/rag/store.py index 587886e2..f9854316 100644 --- a/lazyllm/tools/rag/store.py +++ b/lazyllm/tools/rag/store.py @@ -344,7 +344,7 @@ def query(self, query: str, group_name: str, similarity_name: str, - similarity_cut_off: Union[float, Dict[str, float]], + similarity_cut_off: Union[float, Dict[str, float]], # ignored topk: int, embed_keys: Optional[List[str]] = None, **kwargs) -> List[DocNode]: @@ -352,7 +352,7 @@ def query(self, for key in embed_keys: embed_func = self._embed.get(key) query_embedding = embed_func(query) - # TODO set search params + # TODO set search params according to similarity_name req = AnnSearchRequest( data=query_embedding, anns_field=key, @@ -400,7 +400,7 @@ def _serialize_node_partial(node: DocNode) -> Dict: for k, v in node.embedding.items(): res['embedding_' + k] = v for k, v in self.metadata.items(): - res['metadata_' + k] = json.dumps(v) + res['metadata_' + k] = v return res From 803374a828aba91a3235d192d66f32030eecbb0c Mon Sep 17 00:00:00 2001 From: ouguoyu Date: Wed, 30 Oct 2024 17:07:26 +0800 Subject: [PATCH 12/60] s --- lazyllm/tools/rag/doc_impl.py | 13 ++++++++++++- lazyllm/tools/rag/store.py | 10 +++++----- lazyllm/tools/rag/store_base.py | 2 +- 3 files changed, 18 insertions(+), 7 deletions(-) diff --git a/lazyllm/tools/rag/doc_impl.py b/lazyllm/tools/rag/doc_impl.py index 268bc4e6..a26a7826 100644 --- a/lazyllm/tools/rag/doc_impl.py +++ b/lazyllm/tools/rag/doc_impl.py @@ -51,6 +51,17 @@ def _create_file_node_index(store) -> _FileNodeIndex: index.update(store.get_nodes(group)) return index + @staticmethod + def _update_indices(name2index: Dict[str, IndexBase], nodes: List[DocNode]) -> None: + for _, index in name2index.items(): + index.update(nodes) + + @staticmethod + def _remove_from_indices(name2index: Dict[str, IndexBase], uids: List[str], + group_name: Optional[str] = None) -> None: + for _, index in name2index.items(): + index.remove(uids, group_name) + def _create_some_indices(self): if not self._store.get_index(type='file_node_map'): self.register_index(type='file_node_map', index=self._create_file_node_index(self._store)) @@ -80,7 +91,7 @@ def all_groups(self) -> List[str]: def register_index(self, type: 
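Index resolution in DocImpl is becoming two-level: its own registry is consulted first, then the store's (see the get_index fallback just below). Condensed to its core, with names as in this patch:

from typing import Optional

class _TwoLevelIndexLookup:
    # condensed from DocImpl.register_index / DocImpl.get_index
    def __init__(self, store):
        self._store = store
        self._extra_indices = {}

    def register_index(self, type: str, index) -> None:
        self._extra_indices[type] = index

    def get_index(self, type: str = 'default') -> Optional[object]:
        index = self._extra_indices.get(type)
        if not index:
            index = self._store.get_index(type)
        return index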
str, index: IndexBase) -> None: self._extra_indices[type] = index - def get_index(self, type: str) -> Optional[IndexBase]: + def get_index(self, type: str = 'default') -> Optional[IndexBase]: index = self._extra_indices.get(type) if not index: index = self._store.get_index(type) diff --git a/lazyllm/tools/rag/store.py b/lazyllm/tools/rag/store.py index f9854316..7615b904 100644 --- a/lazyllm/tools/rag/store.py +++ b/lazyllm/tools/rag/store.py @@ -84,8 +84,8 @@ def register_index(self, type: str, index: IndexBase) -> None: self._name2index[type] = index @override - def get_index(self, Optional[type]: str = None) -> Optional[IndexBase]: - if type: + def get_index(self, Optional[type]: str = 'default') -> Optional[IndexBase]: + if type != 'default': return self._name2index.get(type) return self @@ -324,8 +324,8 @@ def register_index(self, type: str, index: IndexBase) -> None: self._map_store.register_index(type, index) @override - def get_index(self, Optional[type]: str = None) -> Optional[IndexBase]: - if type: + def get_index(self, Optional[type]: str = 'default') -> Optional[IndexBase]: + if type != 'default': return self._map_store.get_index(type) return self @@ -361,7 +361,7 @@ def query(self, reqs.append(req) results = self._client.hybrid_search(collection_name=group_name, reqs=reqs, - ranker=ranker, limit=topk) + ranker=ranker, limit=topk, **kwargs) if len(results) != 1: raise ValueError(f'return results size [{len(results)}] != 1') diff --git a/lazyllm/tools/rag/store_base.py b/lazyllm/tools/rag/store_base.py index 6d0184e2..6c1c14f2 100644 --- a/lazyllm/tools/rag/store_base.py +++ b/lazyllm/tools/rag/store_base.py @@ -29,5 +29,5 @@ def register_index(self, type: str, index: IndexBase) -> None: pass @abstractmethod - def get_index(self, Optional[type]: str = None) -> Optional[IndexBase]: + def get_index(self, Optional[type]: str = 'default') -> Optional[IndexBase]: pass From bf03c15bd4d68b290c17669369ad2eb0b5928d3f Mon Sep 17 00:00:00 2001 From: ouguoyu Date: Wed, 30 Oct 2024 23:41:33 +0800 Subject: [PATCH 13/60] s --- lazyllm/tools/rag/index.py | 2 - lazyllm/tools/rag/store.py | 146 +++++++++++++++++++++++--------- lazyllm/tools/rag/store_base.py | 4 +- tests/basic_tests/test_index.py | 51 ++++++----- 4 files changed, 131 insertions(+), 72 deletions(-) diff --git a/lazyllm/tools/rag/index.py b/lazyllm/tools/rag/index.py index f482c0f8..4c6ecdce 100644 --- a/lazyllm/tools/rag/index.py +++ b/lazyllm/tools/rag/index.py @@ -8,8 +8,6 @@ from .component.bm25 import BM25 from lazyllm import LOG, config, ThreadPoolExecutor from lazyllm.common import override -import pymilvus -from pymilvus.client.abstract import AnnSearchRequest, BaseRanker # ---------------------------------------------------------------------------- # diff --git a/lazyllm/tools/rag/store.py b/lazyllm/tools/rag/store.py index 7615b904..bdeab03b 100644 --- a/lazyllm/tools/rag/store.py +++ b/lazyllm/tools/rag/store.py @@ -1,12 +1,18 @@ -from typing import Any, Dict, List, Optional +import copy +from typing import Any, Dict, List, Optional, Callable, Union import chromadb from lazyllm import LOG, config from lazyllm.common import override from chromadb.api.models.Collection import Collection from .store_base import StoreBase from .index_base import IndexBase +from .index import parallel_do_embedding from .doc_node import DocNode import json +import pymilvus +from pymilvus import MilvusClient, FieldSchema, CollectionSchema +from pymilvus.milvus_client.index import IndexParams +from pymilvus.client.abstract import 
AnnSearchRequest, BaseRanker # ---------------------------------------------------------------------------- # @@ -84,17 +90,17 @@ def register_index(self, type: str, index: IndexBase) -> None: self._name2index[type] = index @override - def get_index(self, Optional[type]: str = 'default') -> Optional[IndexBase]: + def get_index(self, type: str = 'default') -> Optional[IndexBase]: if type != 'default': return self._name2index.get(type) return self @override - def update(nodes: List[DocNode]) -> None: + def update(self, nodes: List[DocNode]) -> None: self.update_nodes(nodes) @override - def remove(uids: List[str], group_name: Optional[str] = None) -> None: + def remove(self, uids: List[str], group_name: Optional[str] = None) -> None: if group_name: self.remove_nodes(group_name, uids) else: @@ -258,34 +264,90 @@ def _peek_all_documents(self, group: str) -> Dict[str, List]: # ---------------------------------------------------------------------------- # +class MilvusField: + DTYPE_VARCHAR = 0 + DTYPE_FLOAT_VECTOR = 1 + DTYPE_SPARSE_FLOAT_VECTOR = 2 + + def __init__(self, name: str, data_type: int, index_type: Optional[str] = None, + metric_type: Optional[str] = None, index_params: Optional[Dict] = None, + max_length: Optional[int] = None): + self.name = name + self.data_type = data_type + self.index_type = index_type + self.metric_type = metric_type + self.index_params = index_params + self.max_length = max_length + class MilvusStore(StoreBase, IndexBase): + _type2milvus = [ + pymilvus.DataType.VARCHAR, # DTYPE_VARCHAR + pymilvus.DataType.FLOAT_VECTOR, # DTYPE_FLOAT_VECTOR + pymilvus.DataType.SPARSE_FLOAT_VECTOR, # DTYPE_SPARSE_FLOAT_VECTOR + ] + def __init__(self, uri: str, embed: Dict[str, Callable], - group_fields: Dict[str, List[pymilvus.FieldSchema]], - group_indices: Dict[str, pymilvus.IndexParams]): + # a field is either an embedding key or a metadata key + group_fields: Dict[str, List[MilvusField]]): self._primary_key = 'uid' - self._embedding_keys = list(embed.keys()) - self._metadata_keys = filter(lambda x: x not in embed.keys(), group_fields.keys()) - + self._embedding_keys = embed.keys() self._embed = embed - self._client = pymilvus.MilvusClient(uri=uri) - - id_field = pymilvus.FieldSchema( - name=self._primary_key, dtype=pymilvus.DataType.VARCHAR, - max_length=128, is_primary=True) + self._client = MilvusClient(uri=uri) + + embed_dim = {k: len(e('a')) for k, e in embed.items()} + builtin_fields = [ + FieldSchema(name=self._primary_key, dtype=pymilvus.DataType.VARCHAR, + max_length=128, is_primary=True), + FieldSchema(name='text', dtype=pymilvus.DataType.VARCHAR, + max_length=65535), + FieldSchema(name='group', dtype=pymilvus.DataType.VARCHAR, + max_length=256), + FieldSchema(name='parent', dtype=pymilvus.DataType.VARCHAR, + max_length=256), + ] for group_name, field_list in group_fields.items(): if group_name in self._client.list_collections(): continue - schema = CollectionSchema(fields=id_field+field_list) - index_params = group_indices.get(group_name) - + index_params = IndexParams() + field_schema_list = copy.copy(builtin_fields) + + for field in field_list: + field_schema = None + if field.name in self._embedding_keys: + field_schema = FieldSchema( + name=self._gen_embedding_key(field.name), + dtype=self._type2milvus[field.data_type], + dim=embed_dim.get(field.name)) + else: + field_schema = FieldSchema( + name=self._gen_metadata_key(field.name), + dtype=self._type2milvus[field.data_type], + max_length=field.max_length) + field_schema_list.append(field_schema) + + if 
field_schema.index_type is not None: + index_params.add_index(field_name=field_schema.name, + index_type=field.index_type, + metric_type=field.metric_type, + params=field.index_params) + + schema = CollectionSchema(fields=field_schema_list) self._client.create_collection(collection_name=group_name, schema=schema, index_params=index_params) self._map_store = MapStore(list(group_fields.keys())) self._load_all_nodes_to(self._map_store) + @staticmethod + def _gen_embedding_key(k: str) -> str: + return 'embedding_' + k + + @staticmethod + def _gen_metadata_key(k: str) -> str: + return 'metadata_' + k + # ----- Store APIs ----- # @override @@ -317,14 +379,14 @@ def is_group_active(self, name: str) -> bool: @override def all_groups(self) -> List[str]: - return _map_store.all_groups() + return self._map_store.all_groups() @override def register_index(self, type: str, index: IndexBase) -> None: self._map_store.register_index(type, index) @override - def get_index(self, Optional[type]: str = 'default') -> Optional[IndexBase]: + def get_index(self, type: str = 'default') -> Optional[IndexBase]: if type != 'default': return self._map_store.get_index(type) return self @@ -343,9 +405,9 @@ def remove(self, uids: List[str], group_name: Optional[str] = None) -> None: def query(self, query: str, group_name: str, - similarity_name: str, - similarity_cut_off: Union[float, Dict[str, float]], # ignored - topk: int, + similarity: str = "dummy", + similarity_cut_off: Union[float, Dict[str, float]] = float("-inf"), + topk: int = 10, embed_keys: Optional[List[str]] = None, **kwargs) -> List[DocNode]: reqs = [] @@ -354,14 +416,15 @@ def query(self, query_embedding = embed_func(query) # TODO set search params according to similarity_name req = AnnSearchRequest( - data=query_embedding, - anns_field=key, + data=[query_embedding], + anns_field=self._gen_embedding_key(key), limit=topk, + param={}, ) reqs.append(req) results = self._client.hybrid_search(collection_name=group_name, reqs=reqs, - ranker=ranker, limit=topk, **kwargs) + ranker=BaseRanker(), limit=topk, **kwargs) if len(results) != 1: raise ValueError(f'return results size [{len(results)}] != 1') @@ -371,14 +434,15 @@ def query(self, return self._map_store.get_nodes(group_name, list(uidset)) def _load_all_nodes_to(self, store: StoreBase): - results = self._client.query(collection_name=group_name, - filter=f'{self._primary_key} != ""') - for result in results: - doc = self._deserialize_node_partial(result) - store.update_nodes([doc], group) + for group_name in self._client.list_collections(): + results = self._client.query(collection_name=group_name, + filter=f'{self._primary_key} != ""') + for result in results: + doc = self._deserialize_node_partial(result) + store.update_nodes([doc], group_name) # construct DocNode::parent and DocNode::children - for group in all_groups(): + for group in self.all_groups(): for node in self.get_nodes(group): if node.parent: parent_uid = node.parent @@ -386,26 +450,26 @@ def _load_all_nodes_to(self, store: StoreBase): node.parent = parent_node parent_node.children[node.group].append(node) - @staticmethod - def _serialize_node_partial(node: DocNode) -> Dict: + def _serialize_node_partial(self, node: DocNode) -> Dict: res = { 'uid': node.uid, 'text': node.text, 'group': node.group, } - if self.parent: + if node.parent: res['parent'] = node.parent.uid + else: + res['parent'] = '' for k, v in node.embedding.items(): - res['embedding_' + k] = v - for k, v in self.metadata.items(): - res['metadata_' + k] = v + 
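Stripped of the per-group bookkeeping, the constructor is doing the standard pymilvus bootstrap: declare a schema, attach per-field index params, and create each missing collection once. The same pattern standalone against pymilvus 2.4, with illustrative names (index and metric choices may not be available on every deployment, e.g. Milvus Lite):

import pymilvus
from pymilvus import MilvusClient, FieldSchema, CollectionSchema

client = MilvusClient(uri='example.db')          # illustrative Milvus Lite file
schema = CollectionSchema(fields=[
    FieldSchema(name='uid', dtype=pymilvus.DataType.VARCHAR,
                max_length=128, is_primary=True),
    FieldSchema(name='embedding_vec1', dtype=pymilvus.DataType.FLOAT_VECTOR, dim=3),
])
index_params = client.prepare_index_params()
index_params.add_index(field_name='embedding_vec1', index_type='HNSW',
                       metric_type='COSINE', params={})
if 'group1' not in client.list_collections():
    client.create_collection(collection_name='group1', schema=schema,
                             index_params=index_params)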
res[self._gen_embedding_key(k)] = v + for k, v in node.metadata.items(): + res[self._gen_metadata_key(k)] = v return res - @staticmethod - def _deserialize_node_partial(result: Dict) -> DocNode: + def _deserialize_node_partial(self, result: Dict) -> DocNode: ''' without parent and children ''' @@ -417,8 +481,8 @@ def _deserialize_node_partial(result: Dict) -> DocNode: ) for k in self._embedding_keys: - doc.embedding[k] = result.get('embedding_' + k) + doc.embedding[k] = result.get(self._gen_embedding_key(k)) for k in self._metadata_keys: - doc._metadata[k] = result.get('metadata_' + k) + doc._metadata[k] = result.get(self._gen_metadata_key(k)) return doc diff --git a/lazyllm/tools/rag/store_base.py b/lazyllm/tools/rag/store_base.py index 6c1c14f2..81275c4b 100644 --- a/lazyllm/tools/rag/store_base.py +++ b/lazyllm/tools/rag/store_base.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Optional, List, Dict +from typing import Optional, List from .doc_node import DocNode from .index_base import IndexBase @@ -29,5 +29,5 @@ def register_index(self, type: str, index: IndexBase) -> None: pass @abstractmethod - def get_index(self, Optional[type]: str = 'default') -> Optional[IndexBase]: + def get_index(self, type: str = 'default') -> Optional[IndexBase]: pass diff --git a/tests/basic_tests/test_index.py b/tests/basic_tests/test_index.py index 2538d9cc..843aac9d 100644 --- a/tests/basic_tests/test_index.py +++ b/tests/basic_tests/test_index.py @@ -2,16 +2,11 @@ import time import unittest import tempfile -import pymilvus from unittest.mock import MagicMock -from lazyllm.tools.rag.store import MapStore, LAZY_ROOT_NAME +from lazyllm.tools.rag.store import MapStore, LAZY_ROOT_NAME, MilvusStore, MilvusField from lazyllm.tools.rag.doc_node import DocNode -from lazyllm.tools.rag.index import ( - DefaultIndex, - register_similarity, - MilvusIndex, MilvusEmbeddingField, - parallel_do_embedding) - +from lazyllm.tools.rag.index import DefaultIndex, register_similarity, parallel_do_embedding +import pymilvus class TestDefaultIndex(unittest.TestCase): def setUp(self): @@ -109,15 +104,17 @@ def test_query_multi_embed_one_thresholds(self): class TestMilvusIndex(unittest.TestCase): def setUp(self): - embedding_fields = [ - MilvusEmbeddingField(name="vec1", dim=3, data_type=pymilvus.DataType.FLOAT_VECTOR, - index_type="HNSW", metric_type="IP"), - MilvusEmbeddingField(name="vec2", dim=5, data_type=pymilvus.DataType.FLOAT_VECTOR, - index_type="HNSW", metric_type="IP"), + field_list = [ + MilvusField(name="comment", data_type=MilvusField.DTYPE_VARCHAR, max_length=128, + index_type='Trie'), + MilvusField(name="vec1", data_type=MilvusField.DTYPE_FLOAT_VECTOR, + index_type='HNSW', metric_type='IP'), + MilvusField(name="vec2", data_type=MilvusField.DTYPE_FLOAT_VECTOR, + index_type='HNSW', metric_type='IP'), ] - group_embedding_fields = { - "group1": embedding_fields, - "group2": embedding_fields, + group_fields = { + "group1": field_list, + "group2": field_list, } self.mock_embed = { @@ -128,38 +125,38 @@ def setUp(self): self.node_groups = [LAZY_ROOT_NAME, "group1", "group2"] _, self.store_file = tempfile.mkstemp(suffix=".db") - self.map_store = MapStore(self.node_groups) - self.index = MilvusIndex(embed=self.mock_embed, - group_embedding_fields=group_embedding_fields, - uri=self.store_file, full_data_store=self.map_store) - self.map_store.register_index(type='milvus', index=self.index) + self.store = MilvusStore(uri=self.store_file, embed=self.mock_embed, + group_fields=group_fields) + 
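Serialization flattens each node into a single Milvus row: the builtin columns plus prefixed embedding_*/metadata_* columns, which is what lets _deserialize_node_partial split them back apart. For a node carrying two embeddings and a 'comment' metadata key, the row looks roughly like this (illustrative values):

row = {
    'uid': '2',
    'text': 'text2',
    'group': 'group1',
    'parent': '1',                               # parent uid, '' for root nodes
    'embedding_vec1': [100.0, 200.0, 300.0],
    'embedding_vec2': [400.0, 500.0, 600.0, 700.0, 800.0],
    'metadata_comment': 'comment2',
}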
self.index = self.store.get_index() self.node1 = DocNode(uid="1", text="text1", group="group1", parent=None, - embedding={"vec1": [1.0, 2.0, 3.0], "vec2": [4.0, 5.0, 6.0, 7.0, 8.0]}) + embedding={"vec1": [1.0, 2.0, 3.0], "vec2": [4.0, 5.0, 6.0, 7.0, 8.0]}, + metadata={'comment': 'comment1'}) self.node2 = DocNode(uid="2", text="text2", group="group1", parent=self.node1, - embedding={"vec1": [100.0, 200.0, 300.0], "vec2": [400.0, 500.0, 600.0, 700.0, 800.0]}) + embedding={"vec1": [100.0, 200.0, 300.0], "vec2": [400.0, 500.0, 600.0, 700.0, 800.0]}, + metadata={'comment': 'comment2'}) def tearDown(self): os.remove(self.store_file) def test_update_and_query(self): - self.map_store.update_nodes([self.node1]) + self.store.update_nodes([self.node1]) ret = self.index.query(query='text1', group_name='group1', embed_keys=['vec2'], topk=1) self.assertEqual(len(ret), 1) self.assertEqual(ret[0].uid, self.node1.uid) - self.map_store.update_nodes([self.node2]) + self.store.update_nodes([self.node2]) ret = self.index.query(query='text2', group_name='group1', embed_keys=['vec2'], topk=1) self.assertEqual(len(ret), 1) self.assertEqual(ret[0].uid, self.node2.uid) def test_remove_and_query(self): - self.map_store.update_nodes([self.node1, self.node2]) + self.store.update_nodes([self.node1, self.node2]) ret = self.index.query(query='test', group_name='group1', embed_keys=['vec2'], topk=1) self.assertEqual(len(ret), 1) self.assertEqual(ret[0].uid, self.node2.uid) - self.map_store.remove_nodes("group1", [self.node2.uid]) + self.store.remove_nodes("group1", [self.node2.uid]) ret = self.index.query(query='test', group_name='group1', embed_keys=['vec2'], topk=1) self.assertEqual(len(ret), 1) self.assertEqual(ret[0].uid, self.node1.uid) From 138c149fa152200aa9f85a4b28b8bd0df2c6caad Mon Sep 17 00:00:00 2001 From: ouguoyu Date: Thu, 31 Oct 2024 00:28:18 +0800 Subject: [PATCH 14/60] s --- lazyllm/tools/rag/store.py | 29 +++++++++-------------------- tests/basic_tests/test_index.py | 7 +++---- 2 files changed, 12 insertions(+), 24 deletions(-) diff --git a/lazyllm/tools/rag/store.py b/lazyllm/tools/rag/store.py index bdeab03b..af350553 100644 --- a/lazyllm/tools/rag/store.py +++ b/lazyllm/tools/rag/store.py @@ -12,7 +12,6 @@ import pymilvus from pymilvus import MilvusClient, FieldSchema, CollectionSchema from pymilvus.milvus_client.index import IndexParams -from pymilvus.client.abstract import AnnSearchRequest, BaseRanker # ---------------------------------------------------------------------------- # @@ -405,32 +404,22 @@ def remove(self, uids: List[str], group_name: Optional[str] = None) -> None: def query(self, query: str, group_name: str, - similarity: str = "dummy", - similarity_cut_off: Union[float, Dict[str, float]] = float("-inf"), + similarity: Optional[str] = None, + similarity_cut_off: Optional[Union[float, Dict[str, float]]] = None, topk: int = 10, embed_keys: Optional[List[str]] = None, **kwargs) -> List[DocNode]: - reqs = [] + uidset = set() for key in embed_keys: embed_func = self._embed.get(key) query_embedding = embed_func(query) - # TODO set search params according to similarity_name - req = AnnSearchRequest( - data=[query_embedding], - anns_field=self._gen_embedding_key(key), - limit=topk, - param={}, - ) - reqs.append(req) + results = self._client.search(collection_name=group_name, data=[query_embedding], + limit=topk, anns_field=self._gen_embedding_key(key)) + if len(results) > 0: + # we have only one `data` for search() so there is only one result in `results` + for result in results[0]: + 
uidset.update(result['id']) - results = self._client.hybrid_search(collection_name=group_name, reqs=reqs, - ranker=BaseRanker(), limit=topk, **kwargs) - if len(results) != 1: - raise ValueError(f'return results size [{len(results)}] != 1') - - uidset = set() - for record in results[0]: - uidset.insert(record['id']) return self._map_store.get_nodes(group_name, list(uidset)) def _load_all_nodes_to(self, store: StoreBase): diff --git a/tests/basic_tests/test_index.py b/tests/basic_tests/test_index.py index 843aac9d..aec13113 100644 --- a/tests/basic_tests/test_index.py +++ b/tests/basic_tests/test_index.py @@ -105,12 +105,11 @@ def test_query_multi_embed_one_thresholds(self): class TestMilvusIndex(unittest.TestCase): def setUp(self): field_list = [ - MilvusField(name="comment", data_type=MilvusField.DTYPE_VARCHAR, max_length=128, - index_type='Trie'), + MilvusField(name="comment", data_type=MilvusField.DTYPE_VARCHAR, max_length=128), MilvusField(name="vec1", data_type=MilvusField.DTYPE_FLOAT_VECTOR, - index_type='HNSW', metric_type='IP'), + index_type='HNSW', metric_type='COSINE'), MilvusField(name="vec2", data_type=MilvusField.DTYPE_FLOAT_VECTOR, - index_type='HNSW', metric_type='IP'), + index_type='HNSW', metric_type='COSINE'), ] group_fields = { "group1": field_list, From 9eb071dc90b7fdabf81dd7dcf3ad69265ad0bc6b Mon Sep 17 00:00:00 2001 From: ouguoyu Date: Thu, 31 Oct 2024 10:27:45 +0800 Subject: [PATCH 15/60] s --- lazyllm/tools/rag/store.py | 32 +++++++++++++++++--------------- tests/basic_tests/test_index.py | 2 +- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/lazyllm/tools/rag/store.py b/lazyllm/tools/rag/store.py index af350553..b7cb2a3b 100644 --- a/lazyllm/tools/rag/store.py +++ b/lazyllm/tools/rag/store.py @@ -11,7 +11,6 @@ import json import pymilvus from pymilvus import MilvusClient, FieldSchema, CollectionSchema -from pymilvus.milvus_client.index import IndexParams # ---------------------------------------------------------------------------- # @@ -269,7 +268,7 @@ class MilvusField: DTYPE_SPARSE_FLOAT_VECTOR = 2 def __init__(self, name: str, data_type: int, index_type: Optional[str] = None, - metric_type: Optional[str] = None, index_params: Optional[Dict] = None, + metric_type: Optional[str] = "", index_params: Dict = {}, max_length: Optional[int] = None): self.name = name self.data_type = data_type @@ -299,8 +298,6 @@ def __init__(self, uri: str, embed: Dict[str, Callable], max_length=128, is_primary=True), FieldSchema(name='text', dtype=pymilvus.DataType.VARCHAR, max_length=65535), - FieldSchema(name='group', dtype=pymilvus.DataType.VARCHAR, - max_length=256), FieldSchema(name='parent', dtype=pymilvus.DataType.VARCHAR, max_length=256), ] @@ -309,7 +306,7 @@ def __init__(self, uri: str, embed: Dict[str, Callable], if group_name in self._client.list_collections(): continue - index_params = IndexParams() + index_params = self._client.prepare_index_params() field_schema_list = copy.copy(builtin_fields) for field in field_list: @@ -326,7 +323,7 @@ def __init__(self, uri: str, embed: Dict[str, Callable], max_length=field.max_length) field_schema_list.append(field_schema) - if field_schema.index_type is not None: + if field.index_type is not None: index_params.add_index(field_name=field_schema.name, index_type=field.index_type, metric_type=field.metric_type, @@ -354,7 +351,7 @@ def update_nodes(self, nodes: List[DocNode]) -> None: parallel_do_embedding(self._embed, nodes) for node in nodes: data = self._serialize_node_partial(node) - 
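Patch 14 trades the hybrid_search call for one plain search per embedding key, unioning the hit ids client-side. A standalone sketch of that fan-out; note the per-hit add(), since set.update() on a string id would insert its characters one by one:

from typing import List, Set
from pymilvus import MilvusClient

def search_uids(client: MilvusClient, group: str, anns_field: str,
                query_vec: List[float], topk: int = 10) -> Set[str]:
    # a single query vector in `data`, so `results` holds exactly one hit list
    results = client.search(collection_name=group, data=[query_vec],
                            limit=topk, anns_field=anns_field)
    uids = set()
    for hit in results[0]:
        uids.add(hit['id'])                      # one string uid per hit
    return uids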
self._client.upsert(collection_name=node.group, data=data) + self._client.upsert(collection_name=node.group, data=[data]) self._map_store.update_nodes(nodes) @@ -415,10 +412,12 @@ def query(self, query_embedding = embed_func(query) results = self._client.search(collection_name=group_name, data=[query_embedding], limit=topk, anns_field=self._gen_embedding_key(key)) - if len(results) > 0: - # we have only one `data` for search() so there is only one result in `results` - for result in results[0]: - uidset.update(result['id']) + # we have only one `data` for search() so there is only one result in `results` + if len(results) != 1: + raise ValueError(f'number of results [{len(results)}] != expected [1]') + + for result in results[0]: + uidset.update(result['id']) return self._map_store.get_nodes(group_name, list(uidset)) @@ -428,6 +427,7 @@ def _load_all_nodes_to(self, store: StoreBase): filter=f'{self._primary_key} != ""') for result in results: doc = self._deserialize_node_partial(result) + doc.group = group_name store.update_nodes([doc], group_name) # construct DocNode::parent and DocNode::children @@ -443,7 +443,6 @@ def _serialize_node_partial(self, node: DocNode) -> Dict: res = { 'uid': node.uid, 'text': node.text, - 'group': node.group, } if node.parent: @@ -465,13 +464,16 @@ def _deserialize_node_partial(self, result: Dict) -> DocNode: doc = DocNode( uid=result.get('uid'), text=result.get('text'), - group=result.get('group'), parent=result.get('parent'), # this is the parent's uid ) for k in self._embedding_keys: - doc.embedding[k] = result.get(self._gen_embedding_key(k)) + val = result.get(self._gen_embedding_key(k)) + if val: + doc.embedding[k] = val for k in self._metadata_keys: - doc._metadata[k] = result.get(self._gen_metadata_key(k)) + val = result.get(self._gen_metadata_key(k)) + if val: + doc._metadata[k] = val return doc diff --git a/tests/basic_tests/test_index.py b/tests/basic_tests/test_index.py index aec13113..42c7c16c 100644 --- a/tests/basic_tests/test_index.py +++ b/tests/basic_tests/test_index.py @@ -129,7 +129,7 @@ def setUp(self): self.index = self.store.get_index() self.node1 = DocNode(uid="1", text="text1", group="group1", parent=None, - embedding={"vec1": [1.0, 2.0, 3.0], "vec2": [4.0, 5.0, 6.0, 7.0, 8.0]}, + embedding={"vec1": [8.0, 9.0, 10.0], "vec2": [11.0, 12.0, 13.0, 14.0, 15.0]}, metadata={'comment': 'comment1'}) self.node2 = DocNode(uid="2", text="text2", group="group1", parent=self.node1, embedding={"vec1": [100.0, 200.0, 300.0], "vec2": [400.0, 500.0, 600.0, 700.0, 800.0]}, From d92379a66b6a3f8f8035202a2a53927a203d5684 Mon Sep 17 00:00:00 2001 From: ouguoyu Date: Thu, 31 Oct 2024 10:50:25 +0800 Subject: [PATCH 16/60] s --- lazyllm/tools/rag/store.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/lazyllm/tools/rag/store.py b/lazyllm/tools/rag/store.py index b7cb2a3b..ec32bf03 100644 --- a/lazyllm/tools/rag/store.py +++ b/lazyllm/tools/rag/store.py @@ -336,14 +336,6 @@ def __init__(self, uri: str, embed: Dict[str, Callable], self._map_store = MapStore(list(group_fields.keys())) self._load_all_nodes_to(self._map_store) - @staticmethod - def _gen_embedding_key(k: str) -> str: - return 'embedding_' + k - - @staticmethod - def _gen_metadata_key(k: str) -> str: - return 'metadata_' + k - # ----- Store APIs ----- # @override @@ -421,6 +413,16 @@ def query(self, return self._map_store.get_nodes(group_name, list(uidset)) + # ----- internal helper functions ----- # + + @staticmethod + def _gen_embedding_key(k: str) -> str: + 
return 'embedding_' + k + + @staticmethod + def _gen_metadata_key(k: str) -> str: + return 'metadata_' + k + def _load_all_nodes_to(self, store: StoreBase): for group_name in self._client.list_collections(): results = self._client.query(collection_name=group_name, From 25fd4f9cd3ac15e4344c335d3881a032642d1f0c Mon Sep 17 00:00:00 2001 From: ouguoyu Date: Thu, 31 Oct 2024 15:08:19 +0800 Subject: [PATCH 17/60] s --- lazyllm/tools/rag/embed_utils.py | 37 ++++ lazyllm/tools/rag/index.py | 61 ++---- lazyllm/tools/rag/index_base.py | 4 +- lazyllm/tools/rag/map_backend.py | 144 ++++++++++++ lazyllm/tools/rag/milvus_backend.py | 270 +++++++++++++++++++++++ lazyllm/tools/rag/store.py | 327 +--------------------------- tests/basic_tests/test_index.py | 4 +- 7 files changed, 481 insertions(+), 366 deletions(-) create mode 100644 lazyllm/tools/rag/embed_utils.py create mode 100644 lazyllm/tools/rag/map_backend.py create mode 100644 lazyllm/tools/rag/milvus_backend.py diff --git a/lazyllm/tools/rag/embed_utils.py b/lazyllm/tools/rag/embed_utils.py new file mode 100644 index 00000000..a4fb03c3 --- /dev/null +++ b/lazyllm/tools/rag/embed_utils.py @@ -0,0 +1,37 @@ +import os +import concurrent +from typing import Dict, Callable, List +from lazyllm import config, ThreadPoolExecutor +from .doc_node import DocNode + +# min(32, (os.cpu_count() or 1) + 4) is the default number of workers for ThreadPoolExecutor +config.add( + "max_embedding_workers", + int, + min(32, (os.cpu_count() or 1) + 4), + "MAX_EMBEDDING_WORKERS", +) + +def parallel_do_embedding(embed: Dict[str, Callable], nodes: List[DocNode]) -> List[DocNode]: + ''' + returns a list of modified nodes + ''' + modified_nodes = [] + with ThreadPoolExecutor(config["max_embedding_workers"]) as executor: + futures = [] + for node in nodes: + miss_keys = node.has_missing_embedding(embed.keys()) + if not miss_keys: + continue + modified_nodes.append(node) + for k in miss_keys: + with node._lock: + if node.has_missing_embedding(k): + future = executor.submit(node.do_embedding, {k: embed[k]}) \ + if k not in node._embedding_state else executor.submit(node.check_embedding_state, k) + node._embedding_state.add(k) + futures.append(future) + if len(futures) > 0: + for future in concurrent.futures.as_completed(futures): + future.result() + return modified_nodes diff --git a/lazyllm/tools/rag/index.py b/lazyllm/tools/rag/index.py index 4c6ecdce..f87ca1ca 100644 --- a/lazyllm/tools/rag/index.py +++ b/lazyllm/tools/rag/index.py @@ -1,50 +1,16 @@ -import concurrent -import os from typing import List, Callable, Optional, Dict, Union, Tuple from .doc_node import DocNode from .store_base import StoreBase from .index_base import IndexBase import numpy as np from .component.bm25 import BM25 -from lazyllm import LOG, config, ThreadPoolExecutor +from lazyllm import LOG from lazyllm.common import override +from .embed_utils import parallel_do_embedding +from .milvus_backend import MilvusIndex # ---------------------------------------------------------------------------- # -# min(32, (os.cpu_count() or 1) + 4) is the default number of workers for ThreadPoolExecutor -config.add( - "max_embedding_workers", - int, - min(32, (os.cpu_count() or 1) + 4), - "MAX_EMBEDDING_WORKERS", -) - -# ---------------------------------------------------------------------------- # - -def parallel_do_embedding(embed: Dict[str, Callable], nodes: List[DocNode]) -> List[DocNode]: - ''' - returns a list of modified nodes - ''' - modified_nodes = [] - with ThreadPoolExecutor(config["max_embedding_workers"]) 
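parallel_do_embedding fills in only the embeddings a node is missing, fans the work out over a thread pool bounded by max_embedding_workers, and returns just the nodes it actually touched. A usage sketch with toy embed callables (real ones come from lazyllm's embedding modules):

from lazyllm.tools.rag import DocNode
from lazyllm.tools.rag.embed_utils import parallel_do_embedding

embed = {
    'vec1': lambda text: [float(len(text)), 0.0, 1.0],   # toy embedders
    'vec2': lambda text: [1.0, 2.0, 3.0, 4.0, 5.0],
}
nodes = [DocNode(uid='1', text='hello', group='g', parent=None)]
changed = parallel_do_embedding(embed, nodes)    # returns only the nodes it embedded
print(nodes[0].embedding)                        # now carries 'vec1' and 'vec2'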
as executor: - futures = [] - for node in nodes: - miss_keys = node.has_missing_embedding(embed.keys()) - if not miss_keys: - continue - modified_nodes.append(node) - for k in miss_keys: - with node._lock: - if node.has_missing_embedding(k): - future = executor.submit(node.do_embedding, {k: embed[k]}) \ - if k not in node._embedding_state else executor.submit(node.check_embedding_state, k) - node._embedding_state.add(k) - futures.append(future) - if len(futures) > 0: - for future in concurrent.futures.as_completed(futures): - future.result() - return modified_nodes - class DefaultIndex(IndexBase): """Default Index, registered for similarity functions""" @@ -171,3 +137,24 @@ def register_similarity( batch: bool = False, ) -> Callable: return DefaultIndex.register_similarity(func, mode, descend, batch) + +# ---------------------------------------------------------------------------- # + +class EmbeddingIndex(IndexBase): + def __init__(self, backend_type: Optional[str] = None, *args, **kwargs): + if backend_type == 'milvus': + self._index = MilvusIndex(*args, **kwargs) + else: + raise ValueError(f'unsupported IndexWrapper backend [{backend_type}]') + + @override + def update(self, nodes: List[DocNode]) -> None: + self._index.update(nodes) + + @override + def remove(self, uids: List[str], group_name: Optional[str] = None) -> None: + self._index.remove(uids, group_name) + + @override + def query(self, *args, **kwargs) -> List[DocNode]: + return self._index.query(*args, **kwargs) diff --git a/lazyllm/tools/rag/index_base.py b/lazyllm/tools/rag/index_base.py index ca2fe653..81792fe7 100644 --- a/lazyllm/tools/rag/index_base.py +++ b/lazyllm/tools/rag/index_base.py @@ -4,11 +4,11 @@ class IndexBase(ABC): @abstractmethod - def update(nodes: List[DocNode]) -> None: + def update(self, nodes: List[DocNode]) -> None: pass @abstractmethod - def remove(uids: List[str], group_name: Optional[str] = None) -> None: + def remove(self, uids: List[str], group_name: Optional[str] = None) -> None: pass @abstractmethod diff --git a/lazyllm/tools/rag/map_backend.py b/lazyllm/tools/rag/map_backend.py new file mode 100644 index 00000000..8eb039f0 --- /dev/null +++ b/lazyllm/tools/rag/map_backend.py @@ -0,0 +1,144 @@ +from typing import Dict, List, Optional +from .index_base import IndexBase +from .store_base import StoreBase +from .doc_node import DocNode +from lazyllm.common import override + +def _update_indices(name2index: Dict[str, IndexBase], nodes: List[DocNode]) -> None: + for _, index in name2index.items(): + index.update(nodes) + +def _remove_from_indices(name2index: Dict[str, IndexBase], uids: List[str], + group_name: Optional[str] = None) -> None: + for _, index in name2index.items(): + index.remove(uids, group_name) + +class MapBackend: + def __init__(self, node_groups: List[str]): + super().__init__() + # Dict[group_name, Dict[uuid, DocNode]] + self._group2docs: Dict[str, Dict[str, DocNode]] = { + group: {} for group in node_groups + } + self._name2index = {} + + def update_nodes(self, nodes: List[DocNode]) -> None: + for node in nodes: + self._group2docs[node.group][node.uid] = node + _update_indices(self._name2index, nodes) + + def remove_nodes(self, group_name: str, uids: List[str] = None) -> None: + if uids: + docs = self._group2docs.get(group_name) + if docs: + _remove_from_indices(self._name2index, uids) + for uid in uids: + docs.pop(uid, None) + else: + docs = self._group2docs.pop(group_name, None) + if docs: + _remove_from_indices(self._name2index, [doc.uid for doc in docs]) + + def 
get_nodes(self, group_name: str, uids: List[str] = None) -> List[DocNode]: + docs = self._group2docs.get(group_name) + if not docs: + return [] + + if not uids: + return list(docs.values()) + + ret = [] + for uid in uids: + doc = docs.get(uid) + if doc: + ret.append(doc) + return ret + + def is_group_active(self, name: str) -> bool: + docs = self._group2docs.get(name) + return True if docs else False + + def all_groups(self) -> List[str]: + return self._group2docs.keys() + + def register_index(self, type: str, index: IndexBase) -> None: + self._name2index[type] = index + + def get_index(self, type: str = 'default') -> Optional[IndexBase]: + if type != 'default': + return self._name2index.get(type) + return self + + def update(self, nodes: List[DocNode]) -> None: + self.update_nodes(nodes) + + def remove(self, uids: List[str], group_name: Optional[str] = None) -> None: + if group_name: + self.remove_nodes(group_name, uids) + else: + for _, docs in self._group2docs.items(): + for uid in uids: + docs.pop(uid, None) + _remove_from_indices(self._name2index, uids) + + def query(self, group_name: str, uids: Optional[List[str]] = None) -> List[DocNode]: + return self.get_nodes(group_name, uids) + + def find_node_by_uid(self, uid: str) -> Optional[DocNode]: + for docs in self._group2docs.values(): + doc = docs.get(uid) + if doc: + return doc + return None + + +class MapIndex(IndexBase): + def __init__(self, node_groups: List[str]): + self._backend = MapBackend(node_groups) + + @override + def update(self, nodes: List[DocNode]) -> None: + self._backend.update(nodes) + + @override + def remove(self, uids: List[str], group_name: Optional[str] = None) -> None: + self._backend.remove(uids, group_name) + + @override + def query(self, *args, **kwargs) -> List[DocNode]: + return self._backend.query(*args, **kwargs) + + +class MapStore(StoreBase): + def __init__(self, node_groups: List[str]): + self._backend = MapBackend(node_groups) + + @override + def update_nodes(self, nodes: List[DocNode]) -> None: + self._backend.update_nodes(nodes) + + @override + def remove_nodes(self, group_name: str, uids: Optional[List[str]] = None) -> None: + self._backend.remove_nodes(group_name, uids) + + @override + def get_nodes(self, group_name: str, uids: Optional[List[str]] = None) -> List[DocNode]: + return self._backend.get_nodes(group_name, uids) + + @override + def is_group_active(self, name: str) -> bool: + return self._backend.is_group_active(name) + + @override + def all_groups(self) -> List[str]: + return self._backend.all_groups() + + @override + def register_index(self, type: str, index: IndexBase) -> None: + self._backend.register_index(type, index) + + @override + def get_index(self, type: str = 'default') -> Optional[IndexBase]: + if type == 'default': + return + return self._backend.get_index(type) diff --git a/lazyllm/tools/rag/milvus_backend.py b/lazyllm/tools/rag/milvus_backend.py new file mode 100644 index 00000000..d2658065 --- /dev/null +++ b/lazyllm/tools/rag/milvus_backend.py @@ -0,0 +1,270 @@ +import copy +from typing import Dict, List, Optional, Union, Callable +import pymilvus +from pymilvus import MilvusClient, FieldSchema, CollectionSchema +from .doc_node import DocNode +from .map_backend import MapStore +from .embed_utils import parallel_do_embedding +from .index_base import IndexBase +from .store_base import StoreBase +from lazyllm.common import override + +class MilvusField: + DTYPE_VARCHAR = 0 + DTYPE_FLOAT_VECTOR = 1 + DTYPE_SPARSE_FLOAT_VECTOR = 2 + + def __init__(self, name: str, 
data_type: int, index_type: Optional[str] = None,
+                 metric_type: Optional[str] = "", index_params: Dict = {},
+                 max_length: Optional[int] = None):
+        self.name = name
+        self.data_type = data_type
+        self.index_type = index_type
+        self.metric_type = metric_type
+        self.index_params = index_params
+        self.max_length = max_length
+
+class MilvusBackend:
+    _type2milvus = [
+        pymilvus.DataType.VARCHAR,  # DTYPE_VARCHAR
+        pymilvus.DataType.FLOAT_VECTOR,  # DTYPE_FLOAT_VECTOR
+        pymilvus.DataType.SPARSE_FLOAT_VECTOR,  # DTYPE_SPARSE_FLOAT_VECTOR
+    ]
+
+    def __init__(self, uri: str, embed: Dict[str, Callable],
+                 # a field is either an embedding key or a metadata key
+                 group_fields: Dict[str, List[MilvusField]]):
+        self._primary_key = 'uid'
+        self._embedding_keys = embed.keys()
+        self._metadata_keys = {field.name for fields in group_fields.values()
+                               for field in fields if field.name not in embed}
+        self._embed = embed
+        self._client = MilvusClient(uri=uri)
+
+        embed_dim = {k: len(e('a')) for k, e in embed.items()}
+        builtin_fields = [
+            FieldSchema(name=self._primary_key, dtype=pymilvus.DataType.VARCHAR,
+                        max_length=128, is_primary=True),
+            FieldSchema(name='text', dtype=pymilvus.DataType.VARCHAR,
+                        max_length=65535),
+            FieldSchema(name='parent', dtype=pymilvus.DataType.VARCHAR,
+                        max_length=256),
+        ]
+
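+        # A minimal, illustrative sketch of the `group_fields` layout the loop
+        # below consumes (names are hypothetical; a field whose name matches a
+        # key in `embed` becomes a vector field, anything else is metadata):
+        #
+        #     group_fields = {
+        #         'group1': [
+        #             MilvusField(name='comment', data_type=MilvusField.DTYPE_VARCHAR,
+        #                         max_length=128),
+        #             MilvusField(name='vec1', data_type=MilvusField.DTYPE_FLOAT_VECTOR,
+        #                         index_type='HNSW', metric_type='COSINE'),
+        #         ],
+        #     }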
+        for group_name, field_list in group_fields.items():
+            if group_name in self._client.list_collections():
+                continue
+
+            index_params = self._client.prepare_index_params()
+            field_schema_list = copy.copy(builtin_fields)
+
+            for field in field_list:
+                field_schema = None
+                if field.name in self._embedding_keys:
+                    field_schema = FieldSchema(
+                        name=self._gen_embedding_key(field.name),
+                        dtype=self._type2milvus[field.data_type],
+                        dim=embed_dim.get(field.name))
+                else:
+                    field_schema = FieldSchema(
+                        name=self._gen_metadata_key(field.name),
+                        dtype=self._type2milvus[field.data_type],
+                        max_length=field.max_length)
+                field_schema_list.append(field_schema)
+
+                if field.index_type is not None:
+                    index_params.add_index(field_name=field_schema.name,
+                                           index_type=field.index_type,
+                                           metric_type=field.metric_type,
+                                           params=field.index_params)
+
+            schema = CollectionSchema(fields=field_schema_list)
+            self._client.create_collection(collection_name=group_name, schema=schema,
+                                           index_params=index_params)
+
+        self._map_backend = MapStore(list(group_fields.keys()))
+        self._load_all_nodes_to(self._map_backend)
+
+    # ----- APIs for Store ----- #
+
+    def update_nodes(self, nodes: List[DocNode]) -> None:
+        parallel_do_embedding(self._embed, nodes)
+        for node in nodes:
+            data = self._serialize_node_partial(node)
+            self._client.upsert(collection_name=node.group, data=[data])
+
+        self._map_backend.update_nodes(nodes)
+
+    def remove_nodes(self, group_name: str, uids: Optional[List[str]] = None) -> None:
+        if uids:
+            self._client.delete(collection_name=group_name,
+                                filter=f'{self._primary_key} in {uids}')
+        else:
+            self._client.drop_collection(collection_name=group_name)
+
+        self._map_backend.remove_nodes(group_name, uids)
+
+    def get_nodes(self, group_name: str, uids: Optional[List[str]] = None) -> List[DocNode]:
+        return self._map_backend.get_nodes(group_name, uids)
+
+    def is_group_active(self, name: str) -> bool:
+        return self._map_backend.is_group_active(name)
+
+    def all_groups(self) -> List[str]:
+        return self._map_backend.all_groups()
+
+    def register_index(self, type: str, index: IndexBase) -> None:
+        self._map_backend.register_index(type, index)
+
+    def get_index(self, type: str = 'default') -> Optional[IndexBase]:
+        if type != 'default':
+            return self._map_backend.get_index(type)
+        return self
+
+    # ----- APIs for Index ----- #
+
+    def update(self, nodes: List[DocNode]) -> None:
+        self.update_nodes(nodes)
+
+    def remove(self, uids: List[str], group_name: Optional[str] = None) -> None:
+        self.remove_nodes(group_name, uids)
+
+    def query(self,
+              query: str,
+              group_name: str,
+              similarity: Optional[str] = None,
+              similarity_cut_off: Optional[Union[float, Dict[str, float]]] = None,
+              topk: int = 10,
+              embed_keys: Optional[List[str]] = None,
+              **kwargs) -> List[DocNode]:
+        uidset = set()
+        for key in embed_keys:
+            embed_func = self._embed.get(key)
+            query_embedding = embed_func(query)
+            results = self._client.search(collection_name=group_name, data=[query_embedding],
+                                          limit=topk, anns_field=self._gen_embedding_key(key))
+            # we have only one `data` for search() so there is only one result in `results`
+            if len(results) != 1:
+                raise ValueError(f'number of results [{len(results)}] != expected [1]')
+
+            for result in results[0]:
+                uidset.add(result['id'])
+
+        return self._map_backend.get_nodes(group_name, list(uidset))
+
+    # ----- internal helper functions ----- #
+
+    @staticmethod
+    def _gen_embedding_key(k: str) -> str:
+        return 'embedding_' + k
+
+    @staticmethod
+    def _gen_metadata_key(k: str) -> str:
+        return 'metadata_' + k
+
+    def _load_all_nodes_to(self, store: MapStore):
+        for group_name in self._client.list_collections():
+            results = self._client.query(collection_name=group_name,
+                                         filter=f'{self._primary_key} != ""')
+            for result in results:
+                doc = self._deserialize_node_partial(result)
+                doc.group = group_name
+                store.update_nodes([doc])
+
+        # construct DocNode::parent and DocNode::children
+        for group in self.all_groups():
+            for node in self.get_nodes(group):
+                if node.parent:
+                    parent_uid = node.parent
+                    parent_node = self._map_backend.find_node_by_uid(parent_uid)
+                    node.parent = parent_node
+                    parent_node.children[node.group].append(node)
+
+    def _serialize_node_partial(self, node: DocNode) -> Dict:
+        res = {
+            'uid': node.uid,
+            'text': node.text,
+        }
+
+        if node.parent:
+            res['parent'] = node.parent.uid
+        else:
+            res['parent'] = ''
+
+        for k, v in node.embedding.items():
+            res[self._gen_embedding_key(k)] = v
+        for k, v in node.metadata.items():
+            res[self._gen_metadata_key(k)] = v
+
+        return res
+
+    def _deserialize_node_partial(self, result: Dict) -> DocNode:
+        '''
+        without parent and children
+        '''
+        doc = DocNode(
+            uid=result.get('uid'),
+            text=result.get('text'),
+            parent=result.get('parent'),  # this is the parent's uid
+        )
+
+        for k in self._embedding_keys:
+            val = result.get(self._gen_embedding_key(k))
+            if val:
+                doc.embedding[k] = val
+        for k in self._metadata_keys:
+            val = result.get(self._gen_metadata_key(k))
+            if val:
+                doc._metadata[k] = val
+
+        return doc
+
+
+class MilvusStore(StoreBase):
+    def __init__(self, uri: str, embed: Dict[str, Callable],
+                 group_fields: Dict[str, List[MilvusField]]):
+        self._backend = MilvusBackend(uri, embed, group_fields)
+
+    @override
+    def update_nodes(self, nodes: List[DocNode]) -> None:
+        self._backend.update_nodes(nodes)
+
+    @override
+    def remove_nodes(self, group_name: str, uids: Optional[List[str]] = None) -> None:
+        self._backend.remove_nodes(group_name, uids)
+
+    @override
+    def get_nodes(self, group_name: str, uids: Optional[List[str]] = None) -> List[DocNode]:
+        return self._backend.get_nodes(group_name, uids)
+
+    @override
+    def is_group_active(self, name: str) -> bool:
+        return self._backend.is_group_active(name)
+
+    @override
+    def all_groups(self) -> List[str]:
+        return self._backend.all_groups()
+
+    @override
+    def register_index(self, 
type: str, index: IndexBase) -> None: + self._backend.register_index(type, index) + + @override + def get_index(self, type: str = 'default') -> Optional[IndexBase]: + return self._backend.get_index(type) + + +class MilvusIndex(IndexBase): + def __init__(self, uri: str, embed: Dict[str, Callable], + group_fields: Dict[str, List[MilvusField]]): + self._backend = MilvusBackend(uri, embed, group_fields) + + @override + def update(self, nodes: List[DocNode]) -> None: + self._backend.update(nodes) + + @override + def remove(self, uids: List[str], group_name: Optional[str] = None) -> None: + self._backend.remove(uids, group_name) + + @override + def query(self, *args, **kwargs) -> List[DocNode]: + return self._backend.query(*args, **kwargs) diff --git a/lazyllm/tools/rag/store.py b/lazyllm/tools/rag/store.py index ec32bf03..dd2bfc41 100644 --- a/lazyllm/tools/rag/store.py +++ b/lazyllm/tools/rag/store.py @@ -1,16 +1,12 @@ -import copy -from typing import Any, Dict, List, Optional, Callable, Union +from typing import Any, Dict, List, Optional import chromadb from lazyllm import LOG, config from lazyllm.common import override from chromadb.api.models.Collection import Collection from .store_base import StoreBase -from .index_base import IndexBase -from .index import parallel_do_embedding from .doc_node import DocNode import json -import pymilvus -from pymilvus import MilvusClient, FieldSchema, CollectionSchema +from .map_backend import MapStore # ---------------------------------------------------------------------------- # @@ -21,105 +17,6 @@ # ---------------------------------------------------------------------------- # -def _update_indices(name2index: Dict[str, IndexBase], nodes: List[DocNode]) -> None: - for _, index in name2index.items(): - index.update(nodes) - -def _remove_from_indices(name2index: Dict[str, IndexBase], uids: List[str], - group_name: Optional[str] = None) -> None: - for _, index in name2index.items(): - index.remove(uids, group_name) - -class MapStore(StoreBase, IndexBase): - def __init__(self, node_groups: List[str]): - super().__init__() - # Dict[group_name, Dict[uuid, DocNode]] - self._group2docs: Dict[str, Dict[str, DocNode]] = { - group: {} for group in node_groups - } - self._name2index = {} - - @override - def update_nodes(self, nodes: List[DocNode]) -> None: - for node in nodes: - self._group2docs[node.group][node.uid] = node - _update_indices(self._name2index, nodes) - - @override - def remove_nodes(self, group_name: str, uids: List[str] = None) -> None: - if uids: - docs = self._group2docs.get(group_name) - if docs: - _remove_from_indices(self._name2index, uids) - for uid in uids: - docs.pop(uid, None) - else: - docs = self._group2docs.pop(group_name, None) - if docs: - _remove_from_indices(self._name2index, [doc.uid for doc in docs]) - - @override - def get_nodes(self, group_name: str, uids: List[str] = None) -> List[DocNode]: - docs = self._group2docs.get(group_name) - if not docs: - return [] - - if not uids: - return list(docs.values()) - - ret = [] - for uid in uids: - doc = docs.get(uid) - if doc: - ret.append(doc) - return ret - - @override - def is_group_active(self, name: str) -> bool: - docs = self._group2docs.get(name) - return True if docs else False - - @override - def all_groups(self) -> List[str]: - return self._group2docs.keys() - - @override - def register_index(self, type: str, index: IndexBase) -> None: - self._name2index[type] = index - - @override - def get_index(self, type: str = 'default') -> Optional[IndexBase]: - if type != 'default': - 
return self._name2index.get(type) - return self - - @override - def update(self, nodes: List[DocNode]) -> None: - self.update_nodes(nodes) - - @override - def remove(self, uids: List[str], group_name: Optional[str] = None) -> None: - if group_name: - self.remove_nodes(group_name, uids) - else: - for _, docs in self._group2docs.items(): - for uid in uids: - docs.pop(uid, None) - _remove_from_indices(self._name2index, uids) - - @override - def query(self, group_name: str, uids: Optional[List[str]] = None) -> List[DocNode]: - return self.get_nodes(group_name, uids) - - def find_node_by_uid(self, uid: str) -> Optional[DocNode]: - for docs in self._group2docs.values(): - doc = docs.get(uid) - if doc: - return doc - return None - -# ---------------------------------------------------------------------------- # - class ChromadbStore(StoreBase): def __init__( self, node_groups: List[str], embed_dim: Dict[str, int] @@ -259,223 +156,3 @@ def _peek_all_documents(self, group: str) -> Dict[str, List]: assert group in self._collections, f"group {group} not found." collection = self._collections[group] return collection.peek(collection.count()) - -# ---------------------------------------------------------------------------- # - -class MilvusField: - DTYPE_VARCHAR = 0 - DTYPE_FLOAT_VECTOR = 1 - DTYPE_SPARSE_FLOAT_VECTOR = 2 - - def __init__(self, name: str, data_type: int, index_type: Optional[str] = None, - metric_type: Optional[str] = "", index_params: Dict = {}, - max_length: Optional[int] = None): - self.name = name - self.data_type = data_type - self.index_type = index_type - self.metric_type = metric_type - self.index_params = index_params - self.max_length = max_length - -class MilvusStore(StoreBase, IndexBase): - _type2milvus = [ - pymilvus.DataType.VARCHAR, # DTYPE_VARCHAR - pymilvus.DataType.FLOAT_VECTOR, # DTYPE_FLOAT_VECTOR - pymilvus.DataType.SPARSE_FLOAT_VECTOR, # DTYPE_SPARSE_FLOAT_VECTOR - ] - - def __init__(self, uri: str, embed: Dict[str, Callable], - # a field is either an embedding key or a metadata key - group_fields: Dict[str, List[MilvusField]]): - self._primary_key = 'uid' - self._embedding_keys = embed.keys() - self._embed = embed - self._client = MilvusClient(uri=uri) - - embed_dim = {k: len(e('a')) for k, e in embed.items()} - builtin_fields = [ - FieldSchema(name=self._primary_key, dtype=pymilvus.DataType.VARCHAR, - max_length=128, is_primary=True), - FieldSchema(name='text', dtype=pymilvus.DataType.VARCHAR, - max_length=65535), - FieldSchema(name='parent', dtype=pymilvus.DataType.VARCHAR, - max_length=256), - ] - - for group_name, field_list in group_fields.items(): - if group_name in self._client.list_collections(): - continue - - index_params = self._client.prepare_index_params() - field_schema_list = copy.copy(builtin_fields) - - for field in field_list: - field_schema = None - if field.name in self._embedding_keys: - field_schema = FieldSchema( - name=self._gen_embedding_key(field.name), - dtype=self._type2milvus[field.data_type], - dim=embed_dim.get(field.name)) - else: - field_schema = FieldSchema( - name=self._gen_metadata_key(field.name), - dtype=self._type2milvus[field.data_type], - max_length=field.max_length) - field_schema_list.append(field_schema) - - if field.index_type is not None: - index_params.add_index(field_name=field_schema.name, - index_type=field.index_type, - metric_type=field.metric_type, - params=field.index_params) - - schema = CollectionSchema(fields=field_schema_list) - self._client.create_collection(collection_name=group_name, schema=schema, - 
index_params=index_params) - - self._map_store = MapStore(list(group_fields.keys())) - self._load_all_nodes_to(self._map_store) - - # ----- Store APIs ----- # - - @override - def update_nodes(self, nodes: List[DocNode]) -> None: - parallel_do_embedding(self._embed, nodes) - for node in nodes: - data = self._serialize_node_partial(node) - self._client.upsert(collection_name=node.group, data=[data]) - - self._map_store.update_nodes(nodes) - - @override - def remove_nodes(self, group_name: str, uids: Optional[List[str]] = None) -> None: - if uids: - self._client.delete(collection_name=group_name, - filter=f'{self._primary_key} in {uids}') - else: - self._client.drop_collection(collection_name=group_name) - - self._map_store.remove_nodes(group_name, uids) - - @override - def get_nodes(self, group_name: str, uids: Optional[List[str]] = None) -> List[DocNode]: - return self._map_store.get_nodes(group_name, uids) - - @override - def is_group_active(self, name: str) -> bool: - return self._map_store.is_group_active(name) - - @override - def all_groups(self) -> List[str]: - return self._map_store.all_groups() - - @override - def register_index(self, type: str, index: IndexBase) -> None: - self._map_store.register_index(type, index) - - @override - def get_index(self, type: str = 'default') -> Optional[IndexBase]: - if type != 'default': - return self._map_store.get_index(type) - return self - - # ----- Index APIs ----- # - - @override - def update(self, nodes: List[DocNode]) -> None: - self.update_nodes(nodes) - - @override - def remove(self, uids: List[str], group_name: Optional[str] = None) -> None: - self.remove_nodes(group_name, uids) - - @override - def query(self, - query: str, - group_name: str, - similarity: Optional[str] = None, - similarity_cut_off: Optional[Union[float, Dict[str, float]]] = None, - topk: int = 10, - embed_keys: Optional[List[str]] = None, - **kwargs) -> List[DocNode]: - uidset = set() - for key in embed_keys: - embed_func = self._embed.get(key) - query_embedding = embed_func(query) - results = self._client.search(collection_name=group_name, data=[query_embedding], - limit=topk, anns_field=self._gen_embedding_key(key)) - # we have only one `data` for search() so there is only one result in `results` - if len(results) != 1: - raise ValueError(f'number of results [{len(results)}] != expected [1]') - - for result in results[0]: - uidset.update(result['id']) - - return self._map_store.get_nodes(group_name, list(uidset)) - - # ----- internal helper functions ----- # - - @staticmethod - def _gen_embedding_key(k: str) -> str: - return 'embedding_' + k - - @staticmethod - def _gen_metadata_key(k: str) -> str: - return 'metadata_' + k - - def _load_all_nodes_to(self, store: StoreBase): - for group_name in self._client.list_collections(): - results = self._client.query(collection_name=group_name, - filter=f'{self._primary_key} != ""') - for result in results: - doc = self._deserialize_node_partial(result) - doc.group = group_name - store.update_nodes([doc], group_name) - - # construct DocNode::parent and DocNode::children - for group in self.all_groups(): - for node in self.get_nodes(group): - if node.parent: - parent_uid = node.parent - parent_node = self._map_store.find_node_by_uid(parent_uid) - node.parent = parent_node - parent_node.children[node.group].append(node) - - def _serialize_node_partial(self, node: DocNode) -> Dict: - res = { - 'uid': node.uid, - 'text': node.text, - } - - if node.parent: - res['parent'] = node.parent.uid - else: - res['parent'] = '' - - for k, v in 
node.embedding.items(): - res[self._gen_embedding_key(k)] = v - for k, v in node.metadata.items(): - res[self._gen_metadata_key(k)] = v - - return res - - def _deserialize_node_partial(self, result: Dict) -> DocNode: - ''' - without parent and children - ''' - doc = DocNode( - uid=result.get('uid'), - text=result.get('text'), - parent=result.get('parent'), # this is the parent's uid - ) - - for k in self._embedding_keys: - val = result.get(self._gen_embedding_key(k)) - if val: - doc.embedding[k] = val - for k in self._metadata_keys: - val = result.get(self._gen_metadata_key(k)) - if val: - doc._metadata[k] = val - - return doc diff --git a/tests/basic_tests/test_index.py b/tests/basic_tests/test_index.py index 42c7c16c..8858d0ec 100644 --- a/tests/basic_tests/test_index.py +++ b/tests/basic_tests/test_index.py @@ -5,8 +5,8 @@ from unittest.mock import MagicMock from lazyllm.tools.rag.store import MapStore, LAZY_ROOT_NAME, MilvusStore, MilvusField from lazyllm.tools.rag.doc_node import DocNode -from lazyllm.tools.rag.index import DefaultIndex, register_similarity, parallel_do_embedding -import pymilvus +from lazyllm.tools.rag.index import DefaultIndex, register_similarity +from lazyllm.tools.rag.embed_utils import parallel_do_embedding class TestDefaultIndex(unittest.TestCase): def setUp(self): From 2ad88e4b60fad65b900b9a1cfe49cb612b28ec2e Mon Sep 17 00:00:00 2001 From: ouguoyu Date: Thu, 31 Oct 2024 15:29:21 +0800 Subject: [PATCH 18/60] s --- lazyllm/tools/rag/index.py | 5 +- lazyllm/tools/rag/map_backend.py | 8 +-- lazyllm/tools/rag/milvus_backend.py | 38 ++++++------- tests/basic_tests/test_index.py | 62 +-------------------- tests/basic_tests/test_milvus_backend.py | 68 ++++++++++++++++++++++++ 5 files changed, 96 insertions(+), 85 deletions(-) create mode 100644 tests/basic_tests/test_milvus_backend.py diff --git a/lazyllm/tools/rag/index.py b/lazyllm/tools/rag/index.py index f87ca1ca..52e6283a 100644 --- a/lazyllm/tools/rag/index.py +++ b/lazyllm/tools/rag/index.py @@ -7,7 +7,7 @@ from lazyllm import LOG from lazyllm.common import override from .embed_utils import parallel_do_embedding -from .milvus_backend import MilvusIndex +from .milvus_backend import MilvusBackend, _MilvusIndex # ---------------------------------------------------------------------------- # @@ -143,7 +143,8 @@ def register_similarity( class EmbeddingIndex(IndexBase): def __init__(self, backend_type: Optional[str] = None, *args, **kwargs): if backend_type == 'milvus': - self._index = MilvusIndex(*args, **kwargs) + backend = MilvusBackend(*args, **kwargs) + self._index = _MilvusIndex(backend) else: raise ValueError(f'unsupported IndexWrapper backend [{backend_type}]') diff --git a/lazyllm/tools/rag/map_backend.py b/lazyllm/tools/rag/map_backend.py index 8eb039f0..4d2c7c22 100644 --- a/lazyllm/tools/rag/map_backend.py +++ b/lazyllm/tools/rag/map_backend.py @@ -92,9 +92,9 @@ def find_node_by_uid(self, uid: str) -> Optional[DocNode]: return None -class MapIndex(IndexBase): - def __init__(self, node_groups: List[str]): - self._backend = MapBackend(node_groups) +class _MapIndex(IndexBase): + def __init__(self, backend: MapBackend): + self._backend = backend @override def update(self, nodes: List[DocNode]) -> None: @@ -140,5 +140,5 @@ def register_index(self, type: str, index: IndexBase) -> None: @override def get_index(self, type: str = 'default') -> Optional[IndexBase]: if type == 'default': - return + return _MapIndex(self._backend) return self._backend.get_index(type) diff --git a/lazyllm/tools/rag/milvus_backend.py 
b/lazyllm/tools/rag/milvus_backend.py index d2658065..2bb703c1 100644 --- a/lazyllm/tools/rag/milvus_backend.py +++ b/lazyllm/tools/rag/milvus_backend.py @@ -24,6 +24,7 @@ def __init__(self, name: str, data_type: int, index_type: Optional[str] = None, self.index_params = index_params self.max_length = max_length + class MilvusBackend: _type2milvus = [ pymilvus.DataType.VARCHAR, # DTYPE_VARCHAR @@ -218,6 +219,23 @@ def _deserialize_node_partial(self, result: Dict) -> DocNode: return doc +class _MilvusIndex(IndexBase): + def __init__(self, backend: MilvusBackend): + self._backend = backend + + @override + def update(self, nodes: List[DocNode]) -> None: + self._backend.update(nodes) + + @override + def remove(self, uids: List[str], group_name: Optional[str] = None) -> None: + self._backend.remove(uids, group_name) + + @override + def query(self, *args, **kwargs) -> List[DocNode]: + return self._backend.query(*args, **kwargs) + + class MilvusStore(StoreBase): def __init__(self, uri: str, embed: Dict[str, Callable], group_fields: Dict[str, List[MilvusField]]): @@ -249,22 +267,6 @@ def register_index(self, type: str, index: IndexBase) -> None: @override def get_index(self, type: str = 'default') -> Optional[IndexBase]: + if type == 'default': + return _MilvusIndex(self._backend) return self._backend.get_index(type) - - -class MilvusIndex(IndexBase): - def __init__(self, uri: str, embed: Dict[str, Callable], - group_fields: Dict[str, List[MilvusField]]): - self._backend = MilvusBackend(uri, embed, group_fields) - - @override - def update(self, nodes: List[DocNode]) -> None: - self._backend.update(nodes) - - @override - def remove(self, uids: List[str], group_name: Optional[str] = None) -> None: - self._backend.remove(uids, group_name) - - @override - def query(self, *args, **kwargs) -> List[DocNode]: - return self._backend.query(*args, **kwargs) diff --git a/tests/basic_tests/test_index.py b/tests/basic_tests/test_index.py index 8858d0ec..4b306744 100644 --- a/tests/basic_tests/test_index.py +++ b/tests/basic_tests/test_index.py @@ -1,9 +1,7 @@ -import os import time import unittest -import tempfile from unittest.mock import MagicMock -from lazyllm.tools.rag.store import MapStore, LAZY_ROOT_NAME, MilvusStore, MilvusField +from lazyllm.tools.rag.store import MapStore from lazyllm.tools.rag.doc_node import DocNode from lazyllm.tools.rag.index import DefaultIndex, register_similarity from lazyllm.tools.rag.embed_utils import parallel_do_embedding @@ -102,63 +100,5 @@ def test_query_multi_embed_one_thresholds(self): self.assertEqual(len(results), 1) self.assertIn(self.doc_node_2, results) -class TestMilvusIndex(unittest.TestCase): - def setUp(self): - field_list = [ - MilvusField(name="comment", data_type=MilvusField.DTYPE_VARCHAR, max_length=128), - MilvusField(name="vec1", data_type=MilvusField.DTYPE_FLOAT_VECTOR, - index_type='HNSW', metric_type='COSINE'), - MilvusField(name="vec2", data_type=MilvusField.DTYPE_FLOAT_VECTOR, - index_type='HNSW', metric_type='COSINE'), - ] - group_fields = { - "group1": field_list, - "group2": field_list, - } - - self.mock_embed = { - 'vec1': MagicMock(return_value=[1.0, 2.0, 3.0]), - 'vec2': MagicMock(return_value=[400.0, 500.0, 600.0, 700.0, 800.0]), - } - - self.node_groups = [LAZY_ROOT_NAME, "group1", "group2"] - _, self.store_file = tempfile.mkstemp(suffix=".db") - - self.store = MilvusStore(uri=self.store_file, embed=self.mock_embed, - group_fields=group_fields) - self.index = self.store.get_index() - - self.node1 = DocNode(uid="1", text="text1", 
group="group1", parent=None, - embedding={"vec1": [8.0, 9.0, 10.0], "vec2": [11.0, 12.0, 13.0, 14.0, 15.0]}, - metadata={'comment': 'comment1'}) - self.node2 = DocNode(uid="2", text="text2", group="group1", parent=self.node1, - embedding={"vec1": [100.0, 200.0, 300.0], "vec2": [400.0, 500.0, 600.0, 700.0, 800.0]}, - metadata={'comment': 'comment2'}) - - def tearDown(self): - os.remove(self.store_file) - - def test_update_and_query(self): - self.store.update_nodes([self.node1]) - ret = self.index.query(query='text1', group_name='group1', embed_keys=['vec2'], topk=1) - self.assertEqual(len(ret), 1) - self.assertEqual(ret[0].uid, self.node1.uid) - - self.store.update_nodes([self.node2]) - ret = self.index.query(query='text2', group_name='group1', embed_keys=['vec2'], topk=1) - self.assertEqual(len(ret), 1) - self.assertEqual(ret[0].uid, self.node2.uid) - - def test_remove_and_query(self): - self.store.update_nodes([self.node1, self.node2]) - ret = self.index.query(query='test', group_name='group1', embed_keys=['vec2'], topk=1) - self.assertEqual(len(ret), 1) - self.assertEqual(ret[0].uid, self.node2.uid) - - self.store.remove_nodes("group1", [self.node2.uid]) - ret = self.index.query(query='test', group_name='group1', embed_keys=['vec2'], topk=1) - self.assertEqual(len(ret), 1) - self.assertEqual(ret[0].uid, self.node1.uid) - if __name__ == "__main__": unittest.main() diff --git a/tests/basic_tests/test_milvus_backend.py b/tests/basic_tests/test_milvus_backend.py new file mode 100644 index 00000000..8e79a0ae --- /dev/null +++ b/tests/basic_tests/test_milvus_backend.py @@ -0,0 +1,68 @@ +import os +import unittest +import tempfile +from unittest.mock import MagicMock +from lazyllm.tools.rag.doc_node import DocNode +from lazyllm.tools.rag.store import LAZY_ROOT_NAME, MilvusStore, MilvusField + +class TestMilvusBackend(unittest.TestCase): + def setUp(self): + field_list = [ + MilvusField(name="comment", data_type=MilvusField.DTYPE_VARCHAR, max_length=128), + MilvusField(name="vec1", data_type=MilvusField.DTYPE_FLOAT_VECTOR, + index_type='HNSW', metric_type='COSINE'), + MilvusField(name="vec2", data_type=MilvusField.DTYPE_FLOAT_VECTOR, + index_type='HNSW', metric_type='COSINE'), + ] + group_fields = { + "group1": field_list, + "group2": field_list, + } + + self.mock_embed = { + 'vec1': MagicMock(return_value=[1.0, 2.0, 3.0]), + 'vec2': MagicMock(return_value=[400.0, 500.0, 600.0, 700.0, 800.0]), + } + + self.node_groups = [LAZY_ROOT_NAME, "group1", "group2"] + _, self.store_file = tempfile.mkstemp(suffix=".db") + + self.store = MilvusStore(uri=self.store_file, embed=self.mock_embed, + group_fields=group_fields) + self.index = self.store.get_index() + + self.node1 = DocNode(uid="1", text="text1", group="group1", parent=None, + embedding={"vec1": [8.0, 9.0, 10.0], "vec2": [11.0, 12.0, 13.0, 14.0, 15.0]}, + metadata={'comment': 'comment1'}) + self.node2 = DocNode(uid="2", text="text2", group="group1", parent=self.node1, + embedding={"vec1": [100.0, 200.0, 300.0], "vec2": [400.0, 500.0, 600.0, 700.0, 800.0]}, + metadata={'comment': 'comment2'}) + + def tearDown(self): + os.remove(self.store_file) + + def test_update_and_query(self): + self.store.update_nodes([self.node1]) + ret = self.index.query(query='text1', group_name='group1', embed_keys=['vec2'], topk=1) + self.assertEqual(len(ret), 1) + self.assertEqual(ret[0].uid, self.node1.uid) + + self.store.update_nodes([self.node2]) + ret = self.index.query(query='text2', group_name='group1', embed_keys=['vec2'], topk=1) + self.assertEqual(len(ret), 1) + 
self.assertEqual(ret[0].uid, self.node2.uid) + + def test_remove_and_query(self): + self.store.update_nodes([self.node1, self.node2]) + ret = self.index.query(query='test', group_name='group1', embed_keys=['vec2'], topk=1) + self.assertEqual(len(ret), 1) + self.assertEqual(ret[0].uid, self.node2.uid) + + self.store.remove_nodes("group1", [self.node2.uid]) + ret = self.index.query(query='test', group_name='group1', embed_keys=['vec2'], topk=1) + self.assertEqual(len(ret), 1) + self.assertEqual(ret[0].uid, self.node1.uid) + + +if __name__ == "__main__": + unittest.main() From cd2fa3b820cba890d20dd966e3eac39cdad8a4a1 Mon Sep 17 00:00:00 2001 From: ouguoyu Date: Thu, 31 Oct 2024 15:30:48 +0800 Subject: [PATCH 19/60] s --- tests/basic_tests/test_milvus_backend.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/basic_tests/test_milvus_backend.py b/tests/basic_tests/test_milvus_backend.py index 8e79a0ae..ce10a401 100644 --- a/tests/basic_tests/test_milvus_backend.py +++ b/tests/basic_tests/test_milvus_backend.py @@ -3,7 +3,8 @@ import tempfile from unittest.mock import MagicMock from lazyllm.tools.rag.doc_node import DocNode -from lazyllm.tools.rag.store import LAZY_ROOT_NAME, MilvusStore, MilvusField +from lazyllm.tools.rag.store import LAZY_ROOT_NAME +from lazyllm.tools.rag.milvus_backend import MilvusStore, MilvusField class TestMilvusBackend(unittest.TestCase): def setUp(self): From 6ba7f51d70d7a61bd7e7fbd13cbf98e74d3d8c82 Mon Sep 17 00:00:00 2001 From: ouguoyu Date: Thu, 31 Oct 2024 15:58:00 +0800 Subject: [PATCH 20/60] s --- lazyllm/tools/rag/map_backend.py | 74 +++++++---------------------- lazyllm/tools/rag/milvus_backend.py | 71 ++++++--------------------- 2 files changed, 32 insertions(+), 113 deletions(-) diff --git a/lazyllm/tools/rag/map_backend.py b/lazyllm/tools/rag/map_backend.py index 4d2c7c22..4c30ed0f 100644 --- a/lazyllm/tools/rag/map_backend.py +++ b/lazyllm/tools/rag/map_backend.py @@ -13,7 +13,7 @@ def _remove_from_indices(name2index: Dict[str, IndexBase], uids: List[str], for _, index in name2index.items(): index.remove(uids, group_name) -class MapBackend: +class MapBackend(StoreBase, IndexBase): def __init__(self, node_groups: List[str]): super().__init__() # Dict[group_name, Dict[uuid, DocNode]] @@ -22,11 +22,15 @@ def __init__(self, node_groups: List[str]): } self._name2index = {} + # ----- APIs for StoreBase ----- # + + @override def update_nodes(self, nodes: List[DocNode]) -> None: for node in nodes: self._group2docs[node.group][node.uid] = node _update_indices(self._name2index, nodes) + @override def remove_nodes(self, group_name: str, uids: List[str] = None) -> None: if uids: docs = self._group2docs.get(group_name) @@ -39,6 +43,7 @@ def remove_nodes(self, group_name: str, uids: List[str] = None) -> None: if docs: _remove_from_indices(self._name2index, [doc.uid for doc in docs]) + @override def get_nodes(self, group_name: str, uids: List[str] = None) -> List[DocNode]: docs = self._group2docs.get(group_name) if not docs: @@ -54,24 +59,32 @@ def get_nodes(self, group_name: str, uids: List[str] = None) -> List[DocNode]: ret.append(doc) return ret + @override def is_group_active(self, name: str) -> bool: docs = self._group2docs.get(name) return True if docs else False + @override def all_groups(self) -> List[str]: return self._group2docs.keys() + @override def register_index(self, type: str, index: IndexBase) -> None: self._name2index[type] = index + @override def get_index(self, type: str = 'default') -> Optional[IndexBase]: - if type != 
'default': - return self._name2index.get(type) - return self + if type == 'default': + return self + return self._name2index.get(type) + # ----- APIs for IndexBase ----- # + + @override def update(self, nodes: List[DocNode]) -> None: self.update_nodes(nodes) + @override def remove(self, uids: List[str], group_name: Optional[str] = None) -> None: if group_name: self.remove_nodes(group_name, uids) @@ -81,6 +94,7 @@ def remove(self, uids: List[str], group_name: Optional[str] = None) -> None: docs.pop(uid, None) _remove_from_indices(self._name2index, uids) + @override def query(self, group_name: str, uids: Optional[List[str]] = None) -> List[DocNode]: return self.get_nodes(group_name, uids) @@ -90,55 +104,3 @@ def find_node_by_uid(self, uid: str) -> Optional[DocNode]: if doc: return doc return None - - -class _MapIndex(IndexBase): - def __init__(self, backend: MapBackend): - self._backend = backend - - @override - def update(self, nodes: List[DocNode]) -> None: - self._backend.update(nodes) - - @override - def remove(self, uids: List[str], group_name: Optional[str] = None) -> None: - self._backend.remove(uids, group_name) - - @override - def query(self, *args, **kwargs) -> List[DocNode]: - return self._backend.query(*args, **kwargs) - - -class MapStore(StoreBase): - def __init__(self, node_groups: List[str]): - self._backend = MapBackend(node_groups) - - @override - def update_nodes(self, nodes: List[DocNode]) -> None: - self._backend.update_nodes(nodes) - - @override - def remove_nodes(self, group_name: str, uids: Optional[List[str]] = None) -> None: - self._backend.remove_nodes(group_name, uids) - - @override - def get_nodes(self, group_name: str, uids: Optional[List[str]] = None) -> List[DocNode]: - return self._backend.get_nodes(group_name, uids) - - @override - def is_group_active(self, name: str) -> bool: - return self._backend.is_group_active(name) - - @override - def all_groups(self) -> List[str]: - return self._backend.all_groups() - - @override - def register_index(self, type: str, index: IndexBase) -> None: - self._backend.register_index(type, index) - - @override - def get_index(self, type: str = 'default') -> Optional[IndexBase]: - if type == 'default': - return _MapIndex(self._backend) - return self._backend.get_index(type) diff --git a/lazyllm/tools/rag/milvus_backend.py b/lazyllm/tools/rag/milvus_backend.py index 2bb703c1..8773ada0 100644 --- a/lazyllm/tools/rag/milvus_backend.py +++ b/lazyllm/tools/rag/milvus_backend.py @@ -25,7 +25,7 @@ def __init__(self, name: str, data_type: int, index_type: Optional[str] = None, self.max_length = max_length -class MilvusBackend: +class MilvusBackend(StoreBase, IndexBase): _type2milvus = [ pymilvus.DataType.VARCHAR, # DTYPE_VARCHAR pymilvus.DataType.FLOAT_VECTOR, # DTYPE_FLOAT_VECTOR @@ -86,6 +86,7 @@ def __init__(self, uri: str, embed: Dict[str, Callable], # ----- APIs for Store ----- # + @override def update_nodes(self, nodes: List[DocNode]) -> None: parallel_do_embedding(self._embed, nodes) for node in nodes: @@ -94,6 +95,7 @@ def update_nodes(self, nodes: List[DocNode]) -> None: self._map_backend.update_nodes(nodes) + @override def remove_nodes(self, group_name: str, uids: Optional[List[str]] = None) -> None: if uids: self._client.delete(collection_name=group_name, @@ -103,31 +105,39 @@ def remove_nodes(self, group_name: str, uids: Optional[List[str]] = None) -> Non self._map_backend.remove_nodes(group_name, uids) + @override def get_nodes(self, group_name: str, uids: Optional[List[str]] = None) -> List[DocNode]: return 
self._map_backend.get_nodes(group_name, uids) + @override def is_group_active(self, name: str) -> bool: return self._map_backend.is_group_active(name) + @override def all_groups(self) -> List[str]: return self._map_backend.all_groups() + @override def register_index(self, type: str, index: IndexBase) -> None: self._map_backend.register_index(type, index) + @override def get_index(self, type: str = 'default') -> Optional[IndexBase]: - if type != 'default': - return self._map_backend.get_index(type) - return self + if type == 'default': + return self + return self._map_backend.get_index(type) # ----- APIs for Index ----- # + @override def update(self, nodes: List[DocNode]) -> None: self.update_nodes(nodes) + @override def remove(self, uids: List[str], group_name: Optional[str] = None) -> None: self.remove_nodes(group_name, uids) + @override def query(self, query: str, group_name: str, @@ -217,56 +227,3 @@ def _deserialize_node_partial(self, result: Dict) -> DocNode: doc._metadata[k] = val return doc - - -class _MilvusIndex(IndexBase): - def __init__(self, backend: MilvusBackend): - self._backend = backend - - @override - def update(self, nodes: List[DocNode]) -> None: - self._backend.update(nodes) - - @override - def remove(self, uids: List[str], group_name: Optional[str] = None) -> None: - self._backend.remove(uids, group_name) - - @override - def query(self, *args, **kwargs) -> List[DocNode]: - return self._backend.query(*args, **kwargs) - - -class MilvusStore(StoreBase): - def __init__(self, uri: str, embed: Dict[str, Callable], - group_fields: Dict[str, List[MilvusField]]): - self._backend = MilvusBackend(uri, embed, group_fields) - - @override - def update_nodes(self, nodes: List[DocNode]) -> None: - self._backend.update_nodes(nodes) - - @override - def remove_nodes(self, group_name: str, uids: Optional[List[str]] = None) -> None: - self._backend.remove_nodes(group_name, uids) - - @override - def get_nodes(self, group_name: str, uids: Optional[List[str]] = None) -> List[DocNode]: - return self._backend.get_nodes(group_name, uids) - - @override - def is_group_active(self, name: str) -> bool: - return self._backend.is_group_active(name) - - @override - def all_groups(self) -> List[str]: - return self._backend.all_groups() - - @override - def register_index(self, type: str, index: IndexBase) -> None: - self._backend.register_index(type, index) - - @override - def get_index(self, type: str = 'default') -> Optional[IndexBase]: - if type == 'default': - return _MilvusIndex(self._backend) - return self._backend.get_index(type) From a05996da1667e6d0b28de487bdcda03e05151c3f Mon Sep 17 00:00:00 2001 From: ouguoyu Date: Thu, 31 Oct 2024 16:06:11 +0800 Subject: [PATCH 21/60] s --- lazyllm/tools/rag/index.py | 5 ++--- lazyllm/tools/rag/milvus_backend.py | 6 +++--- lazyllm/tools/rag/store.py | 4 ++-- tests/basic_tests/test_milvus_backend.py | 6 +++--- 4 files changed, 10 insertions(+), 11 deletions(-) diff --git a/lazyllm/tools/rag/index.py b/lazyllm/tools/rag/index.py index 52e6283a..b08b7dc5 100644 --- a/lazyllm/tools/rag/index.py +++ b/lazyllm/tools/rag/index.py @@ -7,7 +7,7 @@ from lazyllm import LOG from lazyllm.common import override from .embed_utils import parallel_do_embedding -from .milvus_backend import MilvusBackend, _MilvusIndex +from .milvus_backend import MilvusBackend # ---------------------------------------------------------------------------- # @@ -143,8 +143,7 @@ def register_similarity( class EmbeddingIndex(IndexBase): def __init__(self, backend_type: Optional[str] = None, *args, 
**kwargs): if backend_type == 'milvus': - backend = MilvusBackend(*args, **kwargs) - self._index = _MilvusIndex(backend) + self._index = MilvusBackend(*args, **kwargs) else: raise ValueError(f'unsupported IndexWrapper backend [{backend_type}]') diff --git a/lazyllm/tools/rag/milvus_backend.py b/lazyllm/tools/rag/milvus_backend.py index 8773ada0..339051f9 100644 --- a/lazyllm/tools/rag/milvus_backend.py +++ b/lazyllm/tools/rag/milvus_backend.py @@ -3,7 +3,7 @@ import pymilvus from pymilvus import MilvusClient, FieldSchema, CollectionSchema from .doc_node import DocNode -from .map_backend import MapStore +from .map_backend import MapBackend from .embed_utils import parallel_do_embedding from .index_base import IndexBase from .store_base import StoreBase @@ -81,7 +81,7 @@ def __init__(self, uri: str, embed: Dict[str, Callable], self._client.create_collection(collection_name=group_name, schema=schema, index_params=index_params) - self._map_backend = MapStore(list(group_fields.keys())) + self._map_backend = MapBackend(list(group_fields.keys())) self._load_all_nodes_to(self._map_backend) # ----- APIs for Store ----- # @@ -171,7 +171,7 @@ def _gen_embedding_key(k: str) -> str: def _gen_metadata_key(k: str) -> str: return 'metadata_' + k - def _load_all_nodes_to(self, store: MapStore): + def _load_all_nodes_to(self, store: StoreBase): for group_name in self._client.list_collections(): results = self._client.query(collection_name=group_name, filter=f'{self._primary_key} != ""') diff --git a/lazyllm/tools/rag/store.py b/lazyllm/tools/rag/store.py index dd2bfc41..b1788b67 100644 --- a/lazyllm/tools/rag/store.py +++ b/lazyllm/tools/rag/store.py @@ -6,7 +6,7 @@ from .store_base import StoreBase from .doc_node import DocNode import json -from .map_backend import MapStore +from .map_backend import MapBackend # ---------------------------------------------------------------------------- # @@ -22,7 +22,7 @@ def __init__( self, node_groups: List[str], embed_dim: Dict[str, int] ) -> None: super().__init__() - self._map_store = MapStore(node_groups) + self._map_store = MapBackend(node_groups) self._db_client = chromadb.PersistentClient(path=config["rag_persistent_path"]) LOG.success(f"Initialzed chromadb in path: {config['rag_persistent_path']}") self._collections: Dict[str, Collection] = { diff --git a/tests/basic_tests/test_milvus_backend.py b/tests/basic_tests/test_milvus_backend.py index ce10a401..61585586 100644 --- a/tests/basic_tests/test_milvus_backend.py +++ b/tests/basic_tests/test_milvus_backend.py @@ -4,7 +4,7 @@ from unittest.mock import MagicMock from lazyllm.tools.rag.doc_node import DocNode from lazyllm.tools.rag.store import LAZY_ROOT_NAME -from lazyllm.tools.rag.milvus_backend import MilvusStore, MilvusField +from lazyllm.tools.rag.milvus_backend import MilvusBackend, MilvusField class TestMilvusBackend(unittest.TestCase): def setUp(self): @@ -28,8 +28,8 @@ def setUp(self): self.node_groups = [LAZY_ROOT_NAME, "group1", "group2"] _, self.store_file = tempfile.mkstemp(suffix=".db") - self.store = MilvusStore(uri=self.store_file, embed=self.mock_embed, - group_fields=group_fields) + self.store = MilvusBackend(uri=self.store_file, embed=self.mock_embed, + group_fields=group_fields) self.index = self.store.get_index() self.node1 = DocNode(uid="1", text="text1", group="group1", parent=None, From 1ec44ecb22942d7ea4546193d3a37070a1aca0bb Mon Sep 17 00:00:00 2001 From: ouguoyu Date: Thu, 31 Oct 2024 20:40:53 +0800 Subject: [PATCH 22/60] s --- lazyllm/tools/rag/index.py | 10 ++++++---- 
.../rag/{map_backend.py => map_store.py} | 20 +------------------ .../{milvus_backend.py => milvus_store.py} | 14 +------------ lazyllm/tools/rag/store_base.py | 4 ++++ 4 files changed, 12 insertions(+), 36 deletions(-) rename lazyllm/tools/rag/{map_backend.py => map_store.py} (82%) rename lazyllm/tools/rag/{milvus_backend.py => milvus_store.py} (95%) diff --git a/lazyllm/tools/rag/index.py b/lazyllm/tools/rag/index.py index b08b7dc5..d1b35c96 100644 --- a/lazyllm/tools/rag/index.py +++ b/lazyllm/tools/rag/index.py @@ -143,18 +143,20 @@ def register_similarity( class EmbeddingIndex(IndexBase): def __init__(self, backend_type: Optional[str] = None, *args, **kwargs): if backend_type == 'milvus': - self._index = MilvusBackend(*args, **kwargs) + self._store = MilvusStore(*args, **kwargs) + elif backend_type == 'map': + self._store = MapStore(*args, **kwargs) else: raise ValueError(f'unsupported IndexWrapper backend [{backend_type}]') @override def update(self, nodes: List[DocNode]) -> None: - self._index.update(nodes) + self._store.update_nodes(nodes) @override def remove(self, uids: List[str], group_name: Optional[str] = None) -> None: - self._index.remove(uids, group_name) + self._index.remove_nodes(group_name, uids) @override def query(self, *args, **kwargs) -> List[DocNode]: - return self._index.query(*args, **kwargs) + return self._store.query(*args, **kwargs) diff --git a/lazyllm/tools/rag/map_backend.py b/lazyllm/tools/rag/map_store.py similarity index 82% rename from lazyllm/tools/rag/map_backend.py rename to lazyllm/tools/rag/map_store.py index 4c30ed0f..119306ca 100644 --- a/lazyllm/tools/rag/map_backend.py +++ b/lazyllm/tools/rag/map_store.py @@ -13,7 +13,7 @@ def _remove_from_indices(name2index: Dict[str, IndexBase], uids: List[str], for _, index in name2index.items(): index.remove(uids, group_name) -class MapBackend(StoreBase, IndexBase): +class MapStore(StoreBase, IndexBase): def __init__(self, node_groups: List[str]): super().__init__() # Dict[group_name, Dict[uuid, DocNode]] @@ -22,8 +22,6 @@ def __init__(self, node_groups: List[str]): } self._name2index = {} - # ----- APIs for StoreBase ----- # - @override def update_nodes(self, nodes: List[DocNode]) -> None: for node in nodes: @@ -78,22 +76,6 @@ def get_index(self, type: str = 'default') -> Optional[IndexBase]: return self return self._name2index.get(type) - # ----- APIs for IndexBase ----- # - - @override - def update(self, nodes: List[DocNode]) -> None: - self.update_nodes(nodes) - - @override - def remove(self, uids: List[str], group_name: Optional[str] = None) -> None: - if group_name: - self.remove_nodes(group_name, uids) - else: - for _, docs in self._group2docs.items(): - for uid in uids: - docs.pop(uid, None) - _remove_from_indices(self._name2index, uids) - @override def query(self, group_name: str, uids: Optional[List[str]] = None) -> List[DocNode]: return self.get_nodes(group_name, uids) diff --git a/lazyllm/tools/rag/milvus_backend.py b/lazyllm/tools/rag/milvus_store.py similarity index 95% rename from lazyllm/tools/rag/milvus_backend.py rename to lazyllm/tools/rag/milvus_store.py index 339051f9..04f2ddc4 100644 --- a/lazyllm/tools/rag/milvus_backend.py +++ b/lazyllm/tools/rag/milvus_store.py @@ -25,7 +25,7 @@ def __init__(self, name: str, data_type: int, index_type: Optional[str] = None, self.max_length = max_length -class MilvusBackend(StoreBase, IndexBase): +class MilvusStore(StoreBase): _type2milvus = [ pymilvus.DataType.VARCHAR, # DTYPE_VARCHAR pymilvus.DataType.FLOAT_VECTOR, # DTYPE_FLOAT_VECTOR @@ -84,8 
+84,6 @@ def __init__(self, uri: str, embed: Dict[str, Callable], self._map_backend = MapBackend(list(group_fields.keys())) self._load_all_nodes_to(self._map_backend) - # ----- APIs for Store ----- # - @override def update_nodes(self, nodes: List[DocNode]) -> None: parallel_do_embedding(self._embed, nodes) @@ -127,16 +125,6 @@ def get_index(self, type: str = 'default') -> Optional[IndexBase]: return self return self._map_backend.get_index(type) - # ----- APIs for Index ----- # - - @override - def update(self, nodes: List[DocNode]) -> None: - self.update_nodes(nodes) - - @override - def remove(self, uids: List[str], group_name: Optional[str] = None) -> None: - self.remove_nodes(group_name, uids) - @override def query(self, query: str, diff --git a/lazyllm/tools/rag/store_base.py b/lazyllm/tools/rag/store_base.py index 81275c4b..41d1fdf5 100644 --- a/lazyllm/tools/rag/store_base.py +++ b/lazyllm/tools/rag/store_base.py @@ -24,6 +24,10 @@ def is_group_active(self, name: str) -> bool: def all_groups(self) -> List[str]: pass + @abstractmethod + def query(self, *args, **kwargs) -> List[DocNode]: + pass + @abstractmethod def register_index(self, type: str, index: IndexBase) -> None: pass From f7c5327d9d33aa68914e528e686bb7624e8f1443 Mon Sep 17 00:00:00 2001 From: ouguoyu Date: Thu, 31 Oct 2024 20:42:55 +0800 Subject: [PATCH 23/60] s --- lazyllm/tools/rag/document.py | 3 ++- lazyllm/tools/rag/index.py | 3 ++- lazyllm/tools/rag/store.py | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/lazyllm/tools/rag/document.py b/lazyllm/tools/rag/document.py index 53ac28ec..8bf48490 100644 --- a/lazyllm/tools/rag/document.py +++ b/lazyllm/tools/rag/document.py @@ -35,7 +35,8 @@ def __init__(self, dataset_path: str, embed: Optional[Union[Callable, Dict[str, if isinstance(embed, ModuleBase): self._submodules.append(embed) self._dlm = DocListManager(dataset_path, name).init_tables() - self._kbs = CallableDict({DocListManager.DEDAULT_GROUP_NAME: DocImpl(embed=self._embed, dlm=self._dlm, store=store)}) + self._kbs = CallableDict({DocListManager.DEDAULT_GROUP_NAME: + DocImpl(embed=self._embed, dlm=self._dlm, store=store)}) if manager: self._manager = ServerModule(DocManager(self._dlm)) if server: self._kbs = ServerModule(self._kbs) diff --git a/lazyllm/tools/rag/index.py b/lazyllm/tools/rag/index.py index d1b35c96..6cfc127c 100644 --- a/lazyllm/tools/rag/index.py +++ b/lazyllm/tools/rag/index.py @@ -7,7 +7,8 @@ from lazyllm import LOG from lazyllm.common import override from .embed_utils import parallel_do_embedding -from .milvus_backend import MilvusBackend +from .milvus_store import MilvusStore +from .map_store import MapStore # ---------------------------------------------------------------------------- # diff --git a/lazyllm/tools/rag/store.py b/lazyllm/tools/rag/store.py index 3dc7c5e3..b1788b67 100644 --- a/lazyllm/tools/rag/store.py +++ b/lazyllm/tools/rag/store.py @@ -1,6 +1,6 @@ from typing import Any, Dict, List, Optional import chromadb -from lazyllm import LOG, config, reset_on_pickle +from lazyllm import LOG, config from lazyllm.common import override from chromadb.api.models.Collection import Collection from .store_base import StoreBase From 8b4066e21971df8fd47cc13ac2d8cdb9f2c07265 Mon Sep 17 00:00:00 2001 From: ouguoyu Date: Thu, 31 Oct 2024 20:59:02 +0800 Subject: [PATCH 24/60] s --- lazyllm/tools/rag/index.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lazyllm/tools/rag/index.py b/lazyllm/tools/rag/index.py index 6cfc127c..ff5c35a6 100644 --- 
a/lazyllm/tools/rag/index.py
+++ b/lazyllm/tools/rag/index.py
@@ -156,7 +156,7 @@ def update(self, nodes: List[DocNode]) -> None:
 
     @override
     def remove(self, uids: List[str], group_name: Optional[str] = None) -> None:
-        self._index.remove_nodes(group_name, uids)
+        self._store.remove_nodes(group_name, uids)
 
     @override
     def query(self, *args, **kwargs) -> List[DocNode]:
From 97dc0f05cace043b2416ec5065fb50636c64f3b6 Mon Sep 17 00:00:00 2001
From: ouguoyu
Date: Fri, 1 Nov 2024 10:34:02 +0800
Subject: [PATCH 25/60] s

---
 lazyllm/tools/rag/chroma_store.py        | 152 ++++++++++++++++++++++
 lazyllm/tools/rag/index.py               |  17 +++
 lazyllm/tools/rag/map_store.py           |   4 +-
 lazyllm/tools/rag/milvus_store.py        |   4 +-
 lazyllm/tools/rag/store.py               | 154 +----------------------
 lazyllm/tools/rag/store_base.py          |   2 +-
 tests/basic_tests/test_milvus_backend.py |  69 ----------
 tests/basic_tests/test_store.py          |  63 +++++++++-
 8 files changed, 235 insertions(+), 230 deletions(-)
 create mode 100644 lazyllm/tools/rag/chroma_store.py
 delete mode 100644 tests/basic_tests/test_milvus_backend.py

diff --git a/lazyllm/tools/rag/chroma_store.py b/lazyllm/tools/rag/chroma_store.py
new file mode 100644
index 00000000..246fc449
--- /dev/null
+++ b/lazyllm/tools/rag/chroma_store.py
@@ -0,0 +1,152 @@
+from typing import Any, Dict, List, Optional
+import chromadb
+from lazyllm import LOG, config
+from lazyllm.common import override
+from chromadb.api.models.Collection import Collection
+from .store_base import StoreBase
+from .doc_node import DocNode
+from .store import LAZY_ROOT_NAME
+import json
+from .map_backend import MapBackend
+
+# ---------------------------------------------------------------------------- #
+
+class ChromadbStore(StoreBase):
+    def __init__(
+        self, node_groups: List[str], embed_dim: Dict[str, int]
+    ) -> None:
+        super().__init__()
+        self._map_store = MapBackend(node_groups)
+        self._db_client = chromadb.PersistentClient(path=config["rag_persistent_path"])
+        LOG.success(f"Initialized chromadb in path: {config['rag_persistent_path']}")
+        self._collections: Dict[str, Collection] = {
+            group: self._db_client.get_or_create_collection(group)
+            for group in node_groups
+        }
+        self._embed_dim = embed_dim
+
+    @override
+    def update_nodes(self, nodes: List[DocNode]) -> None:
+        self._map_store.update_nodes(nodes)
+        self._save_nodes(nodes)
+
+    @override
+    def remove_nodes(self, group_name: str, uids: Optional[List[str]] = None) -> None:
+        if uids:
+            self._delete_group_nodes(group_name, uids)
+        else:
+            self._db_client.delete_collection(name=group_name)
+        return self._map_store.remove_nodes(group_name, uids)
+
+    @override
+    def get_nodes(self, group_name: str, uids: List[str] = None) -> List[DocNode]:
+        return self._map_store.get_nodes(group_name, uids)
+
+    @override
+    def is_group_active(self, name: str) -> bool:
+        return self._map_store.is_group_active(name)
+
+    @override
+    def all_groups(self) -> List[str]:
+        return self._map_store.all_groups()
+
+    def _load_store(self) -> None:
+        if not self._collections[LAZY_ROOT_NAME].peek(1)["ids"]:
+            LOG.info("No persistent data found, skip the rebuilding phase.")
+            return
+
+        # Restore all nodes
+        for group in self._collections.keys():
+            results = self._peek_all_documents(group)
+            nodes = self._build_nodes_from_chroma(results)
+            self._map_store.update_nodes(nodes)
+
+        # Rebuild relationships
+        for group_name in self._map_store.all_groups():
+            nodes = self._map_store.get_nodes(group_name)
+            for node in nodes:
+                if node.parent:
+                    parent_uid = node.parent
+                    parent_node = self._map_store.find_node_by_uid(parent_uid)
+                    node.parent = parent_node
+                    parent_node.children[node.group].append(node)
+            LOG.debug(f"build {group_name} nodes from chromadb: {nodes}")
+        LOG.success("Successfully built nodes from chromadb.")
+
+    def _save_nodes(self, nodes: List[DocNode]) -> None:
+        if not nodes:
+            return
+        # Note: It's caller's duty to make sure this batch of nodes has the same group.
+        group = nodes[0].group
+        ids, embeddings, metadatas, documents = [], [], [], []
+        collection = self._collections.get(group)
+        assert (
+            collection
+        ), f"Group {group} is not found in collections {self._collections}"
+        for node in nodes:
+            if node.is_saved:
+                continue
+            metadata = self._make_chroma_metadata(node)
+            metadata["embedding"] = json.dumps(node.embedding)
+            ids.append(node.uid)
+            embeddings.append([0])  # we don't use chroma for retrieving
+            metadatas.append(metadata)
+            documents.append(node.get_text())
+            node.is_saved = True
+        if ids:
+            collection.upsert(
+                embeddings=embeddings,
+                ids=ids,
+                metadatas=metadatas,
+                documents=documents,
+            )
+            LOG.debug(f"Saved {group} nodes {ids} to chromadb.")
+
+    def _delete_group_nodes(self, group_name: str, uids: List[str]) -> None:
+        collection = self._collections.get(group_name)
+        if collection:
+            collection.delete(ids=uids)
+
+    def _build_nodes_from_chroma(self, results: Dict[str, List]) -> List[DocNode]:
+        nodes: List[DocNode] = []
+        for i, uid in enumerate(results['ids']):
+            chroma_metadata = results['metadatas'][i]
+            node = DocNode(
+                uid=uid,
+                text=results["documents"][i],
+                group=chroma_metadata["group"],
+                embedding=json.loads(chroma_metadata['embedding']),
+                parent=chroma_metadata["parent"],
+            )
+
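+            # Embeddings were serialized to JSON by _save_nodes(); sparse
+            # vectors come back as {index: value} dicts and are expanded into
+            # dense lists below, e.g. {"1": 0.5, "3": 2.0} with dim 5 becomes
+            # [0, 0.5, 0, 2.0, 0].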
+                    node.parent = parent_node
+                    parent_node.children[node.group].append(node)
+            LOG.debug(f"build {group_name} nodes from chromadb: {nodes}")
+        LOG.success("Successfully built nodes from chromadb.")
+
+    def _save_nodes(self, nodes: List[DocNode]) -> None:
+        if not nodes:
+            return
+        # Note: It's the caller's duty to make sure this batch of nodes has the same group.
+        group = nodes[0].group
+        ids, embeddings, metadatas, documents = [], [], [], []
+        collection = self._collections.get(group)
+        assert (
+            collection
+        ), f"Group {group} is not found in collections {self._collections}"
+        for node in nodes:
+            if node.is_saved:
+                continue
+            metadata = self._make_chroma_metadata(node)
+            metadata["embedding"] = json.dumps(node.embedding)
+            ids.append(node.uid)
+            embeddings.append([0])  # we don't use chroma for retrieving
+            metadatas.append(metadata)
+            documents.append(node.get_text())
+            node.is_saved = True
+        if ids:
+            collection.upsert(
+                embeddings=embeddings,
+                ids=ids,
+                metadatas=metadatas,
+                documents=documents,
+            )
+            LOG.debug(f"Saved {group} nodes {ids} to chromadb.")
+
+    def _delete_group_nodes(self, group_name: str, uids: List[str]) -> None:
+        collection = self._collections.get(group_name)
+        if collection:
+            collection.delete(ids=uids)
+
+    def _build_nodes_from_chroma(self, results: Dict[str, List]) -> List[DocNode]:
+        nodes: List[DocNode] = []
+        for i, uid in enumerate(results['ids']):
+            chroma_metadata = results['metadatas'][i]
+            node = DocNode(
+                uid=uid,
+                text=results["documents"][i],
+                group=chroma_metadata["group"],
+                embedding=json.loads(chroma_metadata['embedding']),
+                parent=chroma_metadata["parent"],
+            )
+
+            if node.embedding:
+                # convert sparse embedding to List[float]
+                new_embedding_dict = {}
+                for key, embedding in node.embedding.items():
+                    if isinstance(embedding, dict):
+                        dim = self._embed_dim.get(key)
+                        if not dim:
+                            raise ValueError(f'dim of embed [{key}] is not determined.')
+                        new_embedding = [0] * dim
+                        for idx, val in embedding.items():
+                            new_embedding[int(idx)] = val
+                        new_embedding_dict[key] = new_embedding
+                    else:
+                        new_embedding_dict[key] = embedding
+                node.embedding = new_embedding_dict
+
+            node.is_saved = True
+            nodes.append(node)
+        return nodes
+
+    def _make_chroma_metadata(self, node: DocNode) -> Dict[str, Any]:
+        metadata = {
+            "group": node.group,
+            "parent": node.parent.uid if node.parent else "",
+        }
+        return metadata
+
+    def _peek_all_documents(self, group: str) -> Dict[str, List]:
+        assert group in self._collections, f"group {group} not found."
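
The subtlest part of the new `ChromadbStore` is the embedding round trip in `_save_nodes` and `_build_nodes_from_chroma` above: embeddings are JSON-encoded into chroma metadata on save, and sparse embeddings (dicts mapping index to value) are expanded back into dense lists of length `embed_dim[key]` on load. A minimal standalone sketch of that conversion — the `embed_dim` mapping and the sample values here are hypothetical, not taken from the library:

```python
import json

embed_dim = {"vec1": 3}  # hypothetical: dense dimension per embedding key

# As persisted by _save_nodes: the node's embeddings, JSON-encoded into metadata.
stored = json.dumps({"vec1": {"0": 0.5, "2": 1.5}})  # sparse form: {index: value}

restored = {}
for key, embedding in json.loads(stored).items():
    if isinstance(embedding, dict):  # sparse: scatter values into a zeroed list
        dense = [0] * embed_dim[key]
        for idx, val in embedding.items():
            dense[int(idx)] = val
        restored[key] = dense
    else:                            # already dense: keep as-is
        restored[key] = embedding

assert restored == {"vec1": [0.5, 0, 1.5]}
```
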
+ collection = self._collections[group] + return collection.peek(collection.count()) diff --git a/lazyllm/tools/rag/index.py b/lazyllm/tools/rag/index.py index ff5c35a6..8d1c2273 100644 --- a/lazyllm/tools/rag/index.py +++ b/lazyllm/tools/rag/index.py @@ -161,3 +161,20 @@ def remove(self, uids: List[str], group_name: Optional[str] = None) -> None: @override def query(self, *args, **kwargs) -> List[DocNode]: return self._store.query(*args, **kwargs) + + +class WrapStoreToIndex(IndexBase): + def __init__(self, store: StoreBase): + self._store = store + + @override + def update(self, nodes: List[DocNode]) -> None: + self._store.update_nodes(nodes) + + @override + def remove(self, uids: List[str], group_name: Optional[str] = None) -> None: + self._store.remove_nodes(group_name, uids) + + @override + def query(self, *args, **kwargs) -> List[DocNode]: + return self._store.query(*args, **kwargs) diff --git a/lazyllm/tools/rag/map_store.py b/lazyllm/tools/rag/map_store.py index 119306ca..ee7e2842 100644 --- a/lazyllm/tools/rag/map_store.py +++ b/lazyllm/tools/rag/map_store.py @@ -71,9 +71,7 @@ def register_index(self, type: str, index: IndexBase) -> None: self._name2index[type] = index @override - def get_index(self, type: str = 'default') -> Optional[IndexBase]: - if type == 'default': - return self + def get_index(self, type: str) -> Optional[IndexBase]: return self._name2index.get(type) @override diff --git a/lazyllm/tools/rag/milvus_store.py b/lazyllm/tools/rag/milvus_store.py index 04f2ddc4..9537489c 100644 --- a/lazyllm/tools/rag/milvus_store.py +++ b/lazyllm/tools/rag/milvus_store.py @@ -120,9 +120,7 @@ def register_index(self, type: str, index: IndexBase) -> None: self._map_backend.register_index(type, index) @override - def get_index(self, type: str = 'default') -> Optional[IndexBase]: - if type == 'default': - return self + def get_index(self, type: str) -> Optional[IndexBase]: return self._map_backend.get_index(type) @override diff --git a/lazyllm/tools/rag/store.py b/lazyllm/tools/rag/store.py index b1788b67..8db832c1 100644 --- a/lazyllm/tools/rag/store.py +++ b/lazyllm/tools/rag/store.py @@ -1,158 +1,6 @@ -from typing import Any, Dict, List, Optional -import chromadb -from lazyllm import LOG, config -from lazyllm.common import override -from chromadb.api.models.Collection import Collection -from .store_base import StoreBase -from .doc_node import DocNode -import json -from .map_backend import MapBackend - -# ---------------------------------------------------------------------------- # +from lazyllm import config LAZY_ROOT_NAME = "lazyllm_root" EMBED_DEFAULT_KEY = '__default__' config.add("rag_store_type", str, "map", "RAG_STORE_TYPE") # "map", "chroma" config.add("rag_persistent_path", str, "./lazyllm_chroma", "RAG_PERSISTENT_PATH") - -# ---------------------------------------------------------------------------- # - -class ChromadbStore(StoreBase): - def __init__( - self, node_groups: List[str], embed_dim: Dict[str, int] - ) -> None: - super().__init__() - self._map_store = MapBackend(node_groups) - self._db_client = chromadb.PersistentClient(path=config["rag_persistent_path"]) - LOG.success(f"Initialzed chromadb in path: {config['rag_persistent_path']}") - self._collections: Dict[str, Collection] = { - group: self._db_client.get_or_create_collection(group) - for group in node_groups - } - self._embed_dim = embed_dim - - @override - def update_nodes(self, nodes: List[DocNode]) -> None: - self._map_store.update_nodes(nodes) - self._save_nodes(nodes) - - @override - def 
remove_nodes(self, group_name: str, uids: Optional[List[str]] = None) -> None: - if uids: - self._delete_group_nodes(group_name, uids) - else: - self._db_client.delete_collection(name=group_name) - return self._map_store.remove_nodes(group_name, uids) - - @override - def get_nodes(self, group_name: str, uids: List[str] = None) -> List[DocNode]: - return self._map_store.get_nodes(group_name, uids) - - @override - def is_group_active(self, name: str) -> bool: - return self._map_store.is_group_active(name) - - @override - def all_groups(self) -> List[str]: - return self._map_store.all_groups() - - def _load_store(self) -> None: - if not self._collections[LAZY_ROOT_NAME].peek(1)["ids"]: - LOG.info("No persistent data found, skip the rebuilding phrase.") - return - - # Restore all nodes - for group in self._collections.keys(): - results = self._peek_all_documents(group) - nodes = self._build_nodes_from_chroma(results) - self._map_store.update_nodes(nodes) - - # Rebuild relationships - for group_name in self._map_store.all_groups(): - nodes = self._map_store.get_nodes(group_name) - for node in nodes: - if node.parent: - parent_uid = node.parent - parent_node = self._map_store.find_node_by_uid(parent_uid) - node.parent = parent_node - parent_node.children[node.group].append(node) - LOG.debug(f"build {group} nodes from chromadb: {nodes}") - LOG.success("Successfully Built nodes from chromadb.") - - def _save_nodes(self, nodes: List[DocNode]) -> None: - if not nodes: - return - # Note: It's caller's duty to make sure this batch of nodes has the same group. - group = nodes[0].group - ids, embeddings, metadatas, documents = [], [], [], [] - collection = self._collections.get(group) - assert ( - collection - ), f"Group {group} is not found in collections {self._collections}" - for node in nodes: - if node.is_saved: - continue - metadata = self._make_chroma_metadata(node) - metadata["embedding"] = json.dumps(node.embedding) - ids.append(node.uid) - embeddings.append([0]) # we don't use chroma for retrieving - metadatas.append(metadata) - documents.append(node.get_text()) - node.is_saved = True - if ids: - collection.upsert( - embeddings=embeddings, - ids=ids, - metadatas=metadatas, - documents=documents, - ) - LOG.debug(f"Saved {group} nodes {ids} to chromadb.") - - def _delete_group_nodes(self, group_name: str, uids: List[str]) -> None: - collection = self._collections.get(group_name) - if collection: - collection.delete(ids=uids) - - def _build_nodes_from_chroma(self, results: Dict[str, List]) -> List[DocNode]: - nodes: List[DocNode] = [] - for i, uid in enumerate(results['ids']): - chroma_metadata = results['metadatas'][i] - node = DocNode( - uid=uid, - text=results["documents"][i], - group=chroma_metadata["group"], - embedding=json.loads(chroma_metadata['embedding']), - parent=chroma_metadata["parent"], - ) - - if node.embedding: - # convert sparse embedding to List[float] - new_embedding_dict = {} - for key, embedding in node.embedding.items(): - if isinstance(embedding, dict): - dim = self._embed_dim.get(key) - if not dim: - raise ValueError(f'dim of embed [{key}] is not determined.') - new_embedding = [0] * dim - for idx, val in embedding.items(): - new_embedding[int(idx)] = val - new_embedding_dict[key] = new_embedding - else: - new_embedding_dict[key] = embedding - node.embedding = new_embedding_dict - - node.is_saved = True - nodes.append(node) - return nodes - - def _make_chroma_metadata(self, node: DocNode) -> Dict[str, Any]: - metadata = { - "group": node.group, - "parent": node.parent.uid 
if node.parent else "", - } - return metadata - - def _peek_all_documents(self, group: str) -> Dict[str, List]: - assert group in self._collections, f"group {group} not found." - collection = self._collections[group] - return collection.peek(collection.count()) diff --git a/lazyllm/tools/rag/store_base.py b/lazyllm/tools/rag/store_base.py index 41d1fdf5..515aad80 100644 --- a/lazyllm/tools/rag/store_base.py +++ b/lazyllm/tools/rag/store_base.py @@ -33,5 +33,5 @@ def register_index(self, type: str, index: IndexBase) -> None: pass @abstractmethod - def get_index(self, type: str = 'default') -> Optional[IndexBase]: + def get_index(self, type: str) -> Optional[IndexBase]: pass diff --git a/tests/basic_tests/test_milvus_backend.py b/tests/basic_tests/test_milvus_backend.py deleted file mode 100644 index 61585586..00000000 --- a/tests/basic_tests/test_milvus_backend.py +++ /dev/null @@ -1,69 +0,0 @@ -import os -import unittest -import tempfile -from unittest.mock import MagicMock -from lazyllm.tools.rag.doc_node import DocNode -from lazyllm.tools.rag.store import LAZY_ROOT_NAME -from lazyllm.tools.rag.milvus_backend import MilvusBackend, MilvusField - -class TestMilvusBackend(unittest.TestCase): - def setUp(self): - field_list = [ - MilvusField(name="comment", data_type=MilvusField.DTYPE_VARCHAR, max_length=128), - MilvusField(name="vec1", data_type=MilvusField.DTYPE_FLOAT_VECTOR, - index_type='HNSW', metric_type='COSINE'), - MilvusField(name="vec2", data_type=MilvusField.DTYPE_FLOAT_VECTOR, - index_type='HNSW', metric_type='COSINE'), - ] - group_fields = { - "group1": field_list, - "group2": field_list, - } - - self.mock_embed = { - 'vec1': MagicMock(return_value=[1.0, 2.0, 3.0]), - 'vec2': MagicMock(return_value=[400.0, 500.0, 600.0, 700.0, 800.0]), - } - - self.node_groups = [LAZY_ROOT_NAME, "group1", "group2"] - _, self.store_file = tempfile.mkstemp(suffix=".db") - - self.store = MilvusBackend(uri=self.store_file, embed=self.mock_embed, - group_fields=group_fields) - self.index = self.store.get_index() - - self.node1 = DocNode(uid="1", text="text1", group="group1", parent=None, - embedding={"vec1": [8.0, 9.0, 10.0], "vec2": [11.0, 12.0, 13.0, 14.0, 15.0]}, - metadata={'comment': 'comment1'}) - self.node2 = DocNode(uid="2", text="text2", group="group1", parent=self.node1, - embedding={"vec1": [100.0, 200.0, 300.0], "vec2": [400.0, 500.0, 600.0, 700.0, 800.0]}, - metadata={'comment': 'comment2'}) - - def tearDown(self): - os.remove(self.store_file) - - def test_update_and_query(self): - self.store.update_nodes([self.node1]) - ret = self.index.query(query='text1', group_name='group1', embed_keys=['vec2'], topk=1) - self.assertEqual(len(ret), 1) - self.assertEqual(ret[0].uid, self.node1.uid) - - self.store.update_nodes([self.node2]) - ret = self.index.query(query='text2', group_name='group1', embed_keys=['vec2'], topk=1) - self.assertEqual(len(ret), 1) - self.assertEqual(ret[0].uid, self.node2.uid) - - def test_remove_and_query(self): - self.store.update_nodes([self.node1, self.node2]) - ret = self.index.query(query='test', group_name='group1', embed_keys=['vec2'], topk=1) - self.assertEqual(len(ret), 1) - self.assertEqual(ret[0].uid, self.node2.uid) - - self.store.remove_nodes("group1", [self.node2.uid]) - ret = self.index.query(query='test', group_name='group1', embed_keys=['vec2'], topk=1) - self.assertEqual(len(ret), 1) - self.assertEqual(ret[0].uid, self.node1.uid) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/basic_tests/test_store.py 
b/tests/basic_tests/test_store.py index 154e8786..a0baad99 100644 --- a/tests/basic_tests/test_store.py +++ b/tests/basic_tests/test_store.py @@ -2,7 +2,9 @@ import shutil import unittest import lazyllm -from lazyllm.tools.rag.store import MapStore, ChromadbStore, LAZY_ROOT_NAME +from lazyllm.tools.rag.store import LAZY_ROOT_NAME +from lazyllm.tools.rag.map_store import MapStore +from lazyllm.tools.rag.chroma_store import ChromadbStore from lazyllm.tools.rag.doc_node import DocNode @@ -147,3 +149,62 @@ def test_group_others(self): self.store.update_nodes([self.node1, self.node2]) self.assertEqual(self.store.is_group_active("group1"), True) self.assertEqual(self.store.is_group_active("group2"), False) + + +class TestMilvusStore(unittest.TestCase): + def setUp(self): + field_list = [ + MilvusField(name="comment", data_type=MilvusField.DTYPE_VARCHAR, max_length=128), + MilvusField(name="vec1", data_type=MilvusField.DTYPE_FLOAT_VECTOR, + index_type='HNSW', metric_type='COSINE'), + MilvusField(name="vec2", data_type=MilvusField.DTYPE_FLOAT_VECTOR, + index_type='HNSW', metric_type='COSINE'), + ] + group_fields = { + "group1": field_list, + "group2": field_list, + } + + self.mock_embed = { + 'vec1': MagicMock(return_value=[1.0, 2.0, 3.0]), + 'vec2': MagicMock(return_value=[400.0, 500.0, 600.0, 700.0, 800.0]), + } + + self.node_groups = [LAZY_ROOT_NAME, "group1", "group2"] + _, self.store_file = tempfile.mkstemp(suffix=".db") + + self.store = MilvusStore(uri=self.store_file, embed=self.mock_embed, + group_fields=group_fields) + self.index = self.store.get_index() + + self.node1 = DocNode(uid="1", text="text1", group="group1", parent=None, + embedding={"vec1": [8.0, 9.0, 10.0], "vec2": [11.0, 12.0, 13.0, 14.0, 15.0]}, + metadata={'comment': 'comment1'}) + self.node2 = DocNode(uid="2", text="text2", group="group1", parent=self.node1, + embedding={"vec1": [100.0, 200.0, 300.0], "vec2": [400.0, 500.0, 600.0, 700.0, 800.0]}, + metadata={'comment': 'comment2'}) + + def tearDown(self): + os.remove(self.store_file) + + def test_update_and_query(self): + self.store.update_nodes([self.node1]) + ret = self.index.query(query='text1', group_name='group1', embed_keys=['vec2'], topk=1) + self.assertEqual(len(ret), 1) + self.assertEqual(ret[0].uid, self.node1.uid) + + self.store.update_nodes([self.node2]) + ret = self.index.query(query='text2', group_name='group1', embed_keys=['vec2'], topk=1) + self.assertEqual(len(ret), 1) + self.assertEqual(ret[0].uid, self.node2.uid) + + def test_remove_and_query(self): + self.store.update_nodes([self.node1, self.node2]) + ret = self.index.query(query='test', group_name='group1', embed_keys=['vec2'], topk=1) + self.assertEqual(len(ret), 1) + self.assertEqual(ret[0].uid, self.node2.uid) + + self.store.remove_nodes("group1", [self.node2.uid]) + ret = self.index.query(query='test', group_name='group1', embed_keys=['vec2'], topk=1) + self.assertEqual(len(ret), 1) + self.assertEqual(ret[0].uid, self.node1.uid) From 1ca6ca86d677a078a941e96e8ce5b6fa8b1d2110 Mon Sep 17 00:00:00 2001 From: ouguoyu Date: Fri, 1 Nov 2024 10:46:51 +0800 Subject: [PATCH 26/60] s --- lazyllm/tools/rag/default_index.py | 140 +++++++++++++++++++++++++++++ lazyllm/tools/rag/index.py | 137 +--------------------------- tests/basic_tests/test_store.py | 4 +- 3 files changed, 144 insertions(+), 137 deletions(-) create mode 100644 lazyllm/tools/rag/default_index.py diff --git a/lazyllm/tools/rag/default_index.py b/lazyllm/tools/rag/default_index.py new file mode 100644 index 00000000..61bbda59 --- /dev/null +++ 
b/lazyllm/tools/rag/default_index.py @@ -0,0 +1,140 @@ +from typing import List, Callable, Optional, Dict, Union, Tuple +from .doc_node import DocNode +from .store_base import StoreBase +from .index_base import IndexBase +import numpy as np +from .component.bm25 import BM25 +from lazyllm import LOG +from lazyllm.common import override +from .embed_utils import parallel_do_embedding +from .milvus_store import MilvusStore +from .map_store import MapStore + +# ---------------------------------------------------------------------------- # + +class DefaultIndex(IndexBase): + """Default Index, registered for similarity functions""" + + registered_similarity = dict() + + def __init__(self, embed: Dict[str, Callable], store: StoreBase, **kwargs): + self.embed = embed + self.store = store + + @classmethod + def register_similarity( + cls: "DefaultIndex", + func: Optional[Callable] = None, + mode: str = "", + descend: bool = True, + batch: bool = False, + ) -> Callable: + def decorator(f): + def wrapper(query, nodes, **kwargs): + if mode != "embedding": + if batch: + return f(query, nodes, **kwargs) + else: + return [(node, f(query, node, **kwargs)) for node in nodes] + else: + assert isinstance(query, dict), "query must be of dict type, used for similarity calculation." + similarity = {} + if batch: + for key, val in query.items(): + nodes_embed = [node.embedding[key] for node in nodes] + similarity[key] = f(val, nodes_embed, **kwargs) + else: + for key, val in query.items(): + similarity[key] = [(node, f(val, node.embedding[key], **kwargs)) for node in nodes] + return similarity + cls.registered_similarity[f.__name__] = (wrapper, mode, descend) + return wrapper + + return decorator(func) if func else decorator + + @override + def update(self, nodes: List[DocNode]) -> None: + pass + + @override + def remove(self, uids: List[str], group_name: Optional[str] = None) -> None: + pass + + @override + def query( + self, + query: str, + group_name: str, + similarity_name: str, + similarity_cut_off: Union[float, Dict[str, float]], + topk: int, + embed_keys: Optional[List[str]] = None, + **kwargs, + ) -> List[DocNode]: + if similarity_name not in self.registered_similarity: + raise ValueError( + f"{similarity_name} not registered, please check your input." + f"Available options now: {self.registered_similarity.keys()}" + ) + similarity_func, mode, descend = self.registered_similarity[similarity_name] + + nodes = self.store.get_nodes(group_name) + if mode == "embedding": + assert self.embed, "Chosen similarity needs embed model." + assert len(query) > 0, "Query should not be empty." 
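
For orientation, the `register_similarity` classmethod above stores each wrapped function in `registered_similarity` under its `__name__`, and that name is what `query` resolves through `similarity_name`. A minimal sketch of a user-defined metric — `dot_product` is illustrative and not part of the library:

```python
from typing import List
from lazyllm.tools.rag.default_index import DefaultIndex

# "embedding" mode, non-batch: the wrapped function is invoked once per node
# with (query_vector, node_vector) and must return a float score; descend=True
# means higher scores rank first.
@DefaultIndex.register_similarity(mode="embedding", descend=True)
def dot_product(query: List[float], node: List[float], **kwargs) -> float:
    return sum(q * n for q, n in zip(query, node))
```

A retriever that passes `similarity="dot_product"` is then routed through the wrapper built above, exactly like the built-in `cosine` registered near the bottom of this file.
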
+ query_embedding = {k: self.embed[k](query) for k in (embed_keys or self.embed.keys())} + modified_nodes = parallel_do_embedding(self.embed, nodes) + self.store.update_nodes(modified_nodes) + similarities = similarity_func(query_embedding, nodes, topk=topk, **kwargs) + elif mode == "text": + similarities = similarity_func(query, nodes, topk=topk, **kwargs) + else: + raise NotImplementedError(f"Mode {mode} is not supported.") + + if not isinstance(similarities, dict): + results = self._filter_nodes_by_score(similarities, topk, similarity_cut_off, descend) + else: + results = [] + for key in (embed_keys or similarities.keys()): + sims = similarities[key] + sim_cut_off = similarity_cut_off if isinstance(similarity_cut_off, float) else similarity_cut_off[key] + results.extend(self._filter_nodes_by_score(sims, topk, sim_cut_off, descend)) + results = list(set(results)) + LOG.debug(f"Retrieving query `{query}` and get results: {results}") + return results + + def _filter_nodes_by_score(self, similarities: List[Tuple[DocNode, float]], topk: int, + similarity_cut_off: float, descend) -> List[DocNode]: + similarities.sort(key=lambda x: x[1], reverse=descend) + if topk is not None: + similarities = similarities[:topk] + + return [node for node, score in similarities if score > similarity_cut_off] + +@DefaultIndex.register_similarity(mode="text", batch=True) +def bm25(query: str, nodes: List[DocNode], **kwargs) -> List: + bm25_retriever = BM25(nodes, language="en", **kwargs) + return bm25_retriever.retrieve(query) + + +@DefaultIndex.register_similarity(mode="text", batch=True) +def bm25_chinese(query: str, nodes: List[DocNode], **kwargs) -> List: + bm25_retriever = BM25(nodes, language="zh", **kwargs) + return bm25_retriever.retrieve(query) + + +@DefaultIndex.register_similarity(mode="embedding") +def cosine(query: List[float], node: List[float], **kwargs) -> float: + product = np.dot(query, node) + norm = np.linalg.norm(query) * np.linalg.norm(node) + return product / norm + + +# User-defined similarity decorator +def register_similarity( + func: Optional[Callable] = None, + mode: str = "", + descend: bool = True, + batch: bool = False, +) -> Callable: + return DefaultIndex.register_similarity(func, mode, descend, batch) diff --git a/lazyllm/tools/rag/index.py b/lazyllm/tools/rag/index.py index 8d1c2273..5f254efa 100644 --- a/lazyllm/tools/rag/index.py +++ b/lazyllm/tools/rag/index.py @@ -1,146 +1,11 @@ -from typing import List, Callable, Optional, Dict, Union, Tuple +from typing import List, Optional from .doc_node import DocNode from .store_base import StoreBase from .index_base import IndexBase -import numpy as np -from .component.bm25 import BM25 -from lazyllm import LOG from lazyllm.common import override -from .embed_utils import parallel_do_embedding from .milvus_store import MilvusStore from .map_store import MapStore -# ---------------------------------------------------------------------------- # - -class DefaultIndex(IndexBase): - """Default Index, registered for similarity functions""" - - registered_similarity = dict() - - def __init__(self, embed: Dict[str, Callable], store: StoreBase, **kwargs): - self.embed = embed - self.store = store - - @classmethod - def register_similarity( - cls: "DefaultIndex", - func: Optional[Callable] = None, - mode: str = "", - descend: bool = True, - batch: bool = False, - ) -> Callable: - def decorator(f): - def wrapper(query, nodes, **kwargs): - if mode != "embedding": - if batch: - return f(query, nodes, **kwargs) - else: - return [(node, f(query, 
node, **kwargs)) for node in nodes] - else: - assert isinstance(query, dict), "query must be of dict type, used for similarity calculation." - similarity = {} - if batch: - for key, val in query.items(): - nodes_embed = [node.embedding[key] for node in nodes] - similarity[key] = f(val, nodes_embed, **kwargs) - else: - for key, val in query.items(): - similarity[key] = [(node, f(val, node.embedding[key], **kwargs)) for node in nodes] - return similarity - cls.registered_similarity[f.__name__] = (wrapper, mode, descend) - return wrapper - - return decorator(func) if func else decorator - - @override - def update(self, nodes: List[DocNode]) -> None: - pass - - @override - def remove(self, uids: List[str], group_name: Optional[str] = None) -> None: - pass - - @override - def query( - self, - query: str, - group_name: str, - similarity_name: str, - similarity_cut_off: Union[float, Dict[str, float]], - topk: int, - embed_keys: Optional[List[str]] = None, - **kwargs, - ) -> List[DocNode]: - if similarity_name not in self.registered_similarity: - raise ValueError( - f"{similarity_name} not registered, please check your input." - f"Available options now: {self.registered_similarity.keys()}" - ) - similarity_func, mode, descend = self.registered_similarity[similarity_name] - - nodes = self.store.get_nodes(group_name) - if mode == "embedding": - assert self.embed, "Chosen similarity needs embed model." - assert len(query) > 0, "Query should not be empty." - query_embedding = {k: self.embed[k](query) for k in (embed_keys or self.embed.keys())} - modified_nodes = parallel_do_embedding(self.embed, nodes) - self.store.update_nodes(modified_nodes) - similarities = similarity_func(query_embedding, nodes, topk=topk, **kwargs) - elif mode == "text": - similarities = similarity_func(query, nodes, topk=topk, **kwargs) - else: - raise NotImplementedError(f"Mode {mode} is not supported.") - - if not isinstance(similarities, dict): - results = self._filter_nodes_by_score(similarities, topk, similarity_cut_off, descend) - else: - results = [] - for key in (embed_keys or similarities.keys()): - sims = similarities[key] - sim_cut_off = similarity_cut_off if isinstance(similarity_cut_off, float) else similarity_cut_off[key] - results.extend(self._filter_nodes_by_score(sims, topk, sim_cut_off, descend)) - results = list(set(results)) - LOG.debug(f"Retrieving query `{query}` and get results: {results}") - return results - - def _filter_nodes_by_score(self, similarities: List[Tuple[DocNode, float]], topk: int, - similarity_cut_off: float, descend) -> List[DocNode]: - similarities.sort(key=lambda x: x[1], reverse=descend) - if topk is not None: - similarities = similarities[:topk] - - return [node for node, score in similarities if score > similarity_cut_off] - -@DefaultIndex.register_similarity(mode="text", batch=True) -def bm25(query: str, nodes: List[DocNode], **kwargs) -> List: - bm25_retriever = BM25(nodes, language="en", **kwargs) - return bm25_retriever.retrieve(query) - - -@DefaultIndex.register_similarity(mode="text", batch=True) -def bm25_chinese(query: str, nodes: List[DocNode], **kwargs) -> List: - bm25_retriever = BM25(nodes, language="zh", **kwargs) - return bm25_retriever.retrieve(query) - - -@DefaultIndex.register_similarity(mode="embedding") -def cosine(query: List[float], node: List[float], **kwargs) -> float: - product = np.dot(query, node) - norm = np.linalg.norm(query) * np.linalg.norm(node) - return product / norm - - -# User-defined similarity decorator -def register_similarity( - func: 
Optional[Callable] = None, - mode: str = "", - descend: bool = True, - batch: bool = False, -) -> Callable: - return DefaultIndex.register_similarity(func, mode, descend, batch) - -# ---------------------------------------------------------------------------- # - class EmbeddingIndex(IndexBase): def __init__(self, backend_type: Optional[str] = None, *args, **kwargs): if backend_type == 'milvus': diff --git a/tests/basic_tests/test_store.py b/tests/basic_tests/test_store.py index a0baad99..20879c2c 100644 --- a/tests/basic_tests/test_store.py +++ b/tests/basic_tests/test_store.py @@ -1,10 +1,13 @@ import os import shutil +import tempfile import unittest +from unittest.mock import MagicMock import lazyllm from lazyllm.tools.rag.store import LAZY_ROOT_NAME from lazyllm.tools.rag.map_store import MapStore from lazyllm.tools.rag.chroma_store import ChromadbStore +from lazyllm.tools.rag.milvus_store import MilvusStore, MilvusField from lazyllm.tools.rag.doc_node import DocNode @@ -150,7 +153,6 @@ def test_group_others(self): self.assertEqual(self.store.is_group_active("group1"), True) self.assertEqual(self.store.is_group_active("group2"), False) - class TestMilvusStore(unittest.TestCase): def setUp(self): field_list = [ From 1ad9af0e9313ed4b4f781ff6c7afbf79c74a61af Mon Sep 17 00:00:00 2001 From: ouguoyu Date: Fri, 1 Nov 2024 11:03:45 +0800 Subject: [PATCH 27/60] s --- lazyllm/tools/rag/__init__.py | 2 +- lazyllm/tools/rag/chroma_store.py | 4 ++-- lazyllm/tools/rag/doc_impl.py | 9 +++++-- lazyllm/tools/rag/document.py | 4 +++- lazyllm/tools/rag/map_store.py | 2 +- lazyllm/tools/rag/milvus_store.py | 24 +++++++++---------- lazyllm/tools/rag/readers/docxReader.py | 2 +- lazyllm/tools/rag/readers/epubReader.py | 2 +- lazyllm/tools/rag/readers/hwpReader.py | 2 +- lazyllm/tools/rag/readers/imageReader.py | 2 +- lazyllm/tools/rag/readers/ipynbReader.py | 2 +- lazyllm/tools/rag/readers/markdownReader.py | 2 +- lazyllm/tools/rag/readers/mboxreader.py | 2 +- lazyllm/tools/rag/readers/pandasReader.py | 2 +- lazyllm/tools/rag/readers/pdfReader.py | 2 +- lazyllm/tools/rag/readers/pptxReader.py | 2 +- lazyllm/tools/rag/readers/readerBase.py | 2 +- lazyllm/tools/rag/readers/videoAudioReader.py | 2 +- tests/basic_tests/test_index.py | 4 ++-- 19 files changed, 40 insertions(+), 33 deletions(-) diff --git a/lazyllm/tools/rag/__init__.py b/lazyllm/tools/rag/__init__.py index 1b438fb4..17b25b16 100644 --- a/lazyllm/tools/rag/__init__.py +++ b/lazyllm/tools/rag/__init__.py @@ -2,7 +2,7 @@ from .retriever import Retriever from .rerank import Reranker, register_reranker from .transform import SentenceSplitter, LLMParser, NodeTransform, TransformArgs, AdaptiveTransform -from .index import register_similarity +from .default_index import register_similarity from .doc_node import DocNode from .readers import (PDFReader, DocxReader, HWPReader, PPTXReader, ImageReader, IPYNBReader, EpubReader, MarkdownReader, MboxReader, PandasCSVReader, PandasExcelReader, VideoAudioReader) diff --git a/lazyllm/tools/rag/chroma_store.py b/lazyllm/tools/rag/chroma_store.py index 246fc449..25a9316a 100644 --- a/lazyllm/tools/rag/chroma_store.py +++ b/lazyllm/tools/rag/chroma_store.py @@ -7,7 +7,7 @@ from .doc_node import DocNode from .store import LAZY_ROOT_NAME import json -from .map_backend import MapBackend +from .map_store import MapStore # ---------------------------------------------------------------------------- # @@ -16,7 +16,7 @@ def __init__( self, node_groups: List[str], embed_dim: Dict[str, int] ) -> None: super().__init__() - 
self._map_store = MapBackend(node_groups) + self._map_store = MapStore(node_groups) self._db_client = chromadb.PersistentClient(path=config["rag_persistent_path"]) LOG.success(f"Initialzed chromadb in path: {config['rag_persistent_path']}") self._collections: Dict[str, Collection] = { diff --git a/lazyllm/tools/rag/doc_impl.py b/lazyllm/tools/rag/doc_impl.py index 18153a26..fb81d156 100644 --- a/lazyllm/tools/rag/doc_impl.py +++ b/lazyllm/tools/rag/doc_impl.py @@ -6,9 +6,14 @@ from lazyllm.common import override from .transform import (NodeTransform, FuncNodeTransform, SentenceSplitter, LLMParser, AdaptiveTransform, make_transform, TransformArgs) -from .store import MapStore, DocNode, ChromadbStore, LAZY_ROOT_NAME, StoreBase +from .store import LAZY_ROOT_NAME +from .store_base import StoreBase +from .map_store import MapStore +from .chroma_store import ChromadbStore +from .doc_node import DocNode from .data_loaders import DirectoryReader -from .index import DefaultIndex, IndexBase +from .index_base import IndexBase +from .default_index import DefaultIndex from .utils import DocListManager import threading import time diff --git a/lazyllm/tools/rag/document.py b/lazyllm/tools/rag/document.py index 8bf48490..db8620b5 100644 --- a/lazyllm/tools/rag/document.py +++ b/lazyllm/tools/rag/document.py @@ -7,7 +7,9 @@ from .doc_manager import DocManager from .doc_impl import DocImpl -from .store import LAZY_ROOT_NAME, EMBED_DEFAULT_KEY, DocNode, StoreBase +from .doc_node import DocNode +from .store import LAZY_ROOT_NAME, EMBED_DEFAULT_KEY +from .store_base import StoreBase from .utils import DocListManager import copy import functools diff --git a/lazyllm/tools/rag/map_store.py b/lazyllm/tools/rag/map_store.py index ee7e2842..29cd874a 100644 --- a/lazyllm/tools/rag/map_store.py +++ b/lazyllm/tools/rag/map_store.py @@ -13,7 +13,7 @@ def _remove_from_indices(name2index: Dict[str, IndexBase], uids: List[str], for _, index in name2index.items(): index.remove(uids, group_name) -class MapStore(StoreBase, IndexBase): +class MapStore(StoreBase): def __init__(self, node_groups: List[str]): super().__init__() # Dict[group_name, Dict[uuid, DocNode]] diff --git a/lazyllm/tools/rag/milvus_store.py b/lazyllm/tools/rag/milvus_store.py index 9537489c..6d7f9b7f 100644 --- a/lazyllm/tools/rag/milvus_store.py +++ b/lazyllm/tools/rag/milvus_store.py @@ -3,7 +3,7 @@ import pymilvus from pymilvus import MilvusClient, FieldSchema, CollectionSchema from .doc_node import DocNode -from .map_backend import MapBackend +from .map_store import MapStore from .embed_utils import parallel_do_embedding from .index_base import IndexBase from .store_base import StoreBase @@ -81,8 +81,8 @@ def __init__(self, uri: str, embed: Dict[str, Callable], self._client.create_collection(collection_name=group_name, schema=schema, index_params=index_params) - self._map_backend = MapBackend(list(group_fields.keys())) - self._load_all_nodes_to(self._map_backend) + self._map_store = MapStore(list(group_fields.keys())) + self._load_all_nodes_to(self._map_store) @override def update_nodes(self, nodes: List[DocNode]) -> None: @@ -91,7 +91,7 @@ def update_nodes(self, nodes: List[DocNode]) -> None: data = self._serialize_node_partial(node) self._client.upsert(collection_name=node.group, data=[data]) - self._map_backend.update_nodes(nodes) + self._map_store.update_nodes(nodes) @override def remove_nodes(self, group_name: str, uids: Optional[List[str]] = None) -> None: @@ -101,27 +101,27 @@ def remove_nodes(self, group_name: str, uids: Optional[List[str]] = 
None) -> Non else: self._client.drop_collection(collection_name=group_name) - self._map_backend.remove_nodes(group_name, uids) + self._map_store.remove_nodes(group_name, uids) @override def get_nodes(self, group_name: str, uids: Optional[List[str]] = None) -> List[DocNode]: - return self._map_backend.get_nodes(group_name, uids) + return self._map_store.get_nodes(group_name, uids) @override def is_group_active(self, name: str) -> bool: - return self._map_backend.is_group_active(name) + return self._map_store.is_group_active(name) @override def all_groups(self) -> List[str]: - return self._map_backend.all_groups() + return self._map_store.all_groups() @override def register_index(self, type: str, index: IndexBase) -> None: - self._map_backend.register_index(type, index) + self._map_store.register_index(type, index) @override def get_index(self, type: str) -> Optional[IndexBase]: - return self._map_backend.get_index(type) + return self._map_store.get_index(type) @override def query(self, @@ -145,7 +145,7 @@ def query(self, for result in results[0]: uidset.update(result['id']) - return self._map_backend.get_nodes(group_name, list(uidset)) + return self._map_store.get_nodes(group_name, list(uidset)) # ----- internal helper functions ----- # @@ -171,7 +171,7 @@ def _load_all_nodes_to(self, store: StoreBase): for node in self.get_nodes(group): if node.parent: parent_uid = node.parent - parent_node = self._map_backend.find_node_by_uid(parent_uid) + parent_node = self._map_store.find_node_by_uid(parent_uid) node.parent = parent_node parent_node.children[node.group].append(node) diff --git a/lazyllm/tools/rag/readers/docxReader.py b/lazyllm/tools/rag/readers/docxReader.py index ff472013..7e0ab3b6 100644 --- a/lazyllm/tools/rag/readers/docxReader.py +++ b/lazyllm/tools/rag/readers/docxReader.py @@ -3,7 +3,7 @@ from typing import Dict, Optional, List from .readerBase import LazyLLMReaderBase -from ..store import DocNode +from ..doc_node import DocNode class DocxReader(LazyLLMReaderBase): def _load_data(self, file: Path, extra_info: Optional[Dict] = None, diff --git a/lazyllm/tools/rag/readers/epubReader.py b/lazyllm/tools/rag/readers/epubReader.py index 0e208dbf..747c3402 100644 --- a/lazyllm/tools/rag/readers/epubReader.py +++ b/lazyllm/tools/rag/readers/epubReader.py @@ -3,7 +3,7 @@ from fsspec import AbstractFileSystem from .readerBase import LazyLLMReaderBase -from ..store import DocNode +from ..doc_node import DocNode from lazyllm import LOG class EpubReader(LazyLLMReaderBase): diff --git a/lazyllm/tools/rag/readers/hwpReader.py b/lazyllm/tools/rag/readers/hwpReader.py index 9678336e..35f33b9c 100644 --- a/lazyllm/tools/rag/readers/hwpReader.py +++ b/lazyllm/tools/rag/readers/hwpReader.py @@ -5,7 +5,7 @@ import zlib from .readerBase import LazyLLMReaderBase -from ..store import DocNode +from ..doc_node import DocNode from lazyllm import LOG class HWPReader(LazyLLMReaderBase): diff --git a/lazyllm/tools/rag/readers/imageReader.py b/lazyllm/tools/rag/readers/imageReader.py index ee610bbc..fe05f57f 100644 --- a/lazyllm/tools/rag/readers/imageReader.py +++ b/lazyllm/tools/rag/readers/imageReader.py @@ -7,7 +7,7 @@ from PIL import Image from .readerBase import LazyLLMReaderBase, infer_torch_device -from ..store import DocNode +from ..doc_node import DocNode def img_2_b64(image: Image, format: str = "JPEG") -> str: buff = BytesIO() diff --git a/lazyllm/tools/rag/readers/ipynbReader.py b/lazyllm/tools/rag/readers/ipynbReader.py index 66c0e192..90e0cc5c 100644 --- 
a/lazyllm/tools/rag/readers/ipynbReader.py +++ b/lazyllm/tools/rag/readers/ipynbReader.py @@ -4,7 +4,7 @@ from fsspec import AbstractFileSystem from .readerBase import LazyLLMReaderBase -from ..store import DocNode +from ..doc_node import DocNode class IPYNBReader(LazyLLMReaderBase): def __init__(self, parser_config: Optional[Dict] = None, concatenate: bool = False, return_trace: bool = True): diff --git a/lazyllm/tools/rag/readers/markdownReader.py b/lazyllm/tools/rag/readers/markdownReader.py index c1748f55..0184576b 100644 --- a/lazyllm/tools/rag/readers/markdownReader.py +++ b/lazyllm/tools/rag/readers/markdownReader.py @@ -5,7 +5,7 @@ from typing import Dict, List, Optional, Tuple from .readerBase import LazyLLMReaderBase -from ..store import DocNode +from ..doc_node import DocNode class MarkdownReader(LazyLLMReaderBase): def __init__(self, remove_hyperlinks: bool = True, remove_images: bool = True, return_trace: bool = True) -> None: diff --git a/lazyllm/tools/rag/readers/mboxreader.py b/lazyllm/tools/rag/readers/mboxreader.py index 567854c8..3fab832c 100644 --- a/lazyllm/tools/rag/readers/mboxreader.py +++ b/lazyllm/tools/rag/readers/mboxreader.py @@ -3,7 +3,7 @@ from fsspec import AbstractFileSystem from .readerBase import LazyLLMReaderBase -from ..store import DocNode +from ..doc_node import DocNode from lazyllm import LOG class MboxReader(LazyLLMReaderBase): diff --git a/lazyllm/tools/rag/readers/pandasReader.py b/lazyllm/tools/rag/readers/pandasReader.py index bbe2cb60..e3ad327a 100644 --- a/lazyllm/tools/rag/readers/pandasReader.py +++ b/lazyllm/tools/rag/readers/pandasReader.py @@ -5,7 +5,7 @@ import pandas as pd from .readerBase import LazyLLMReaderBase -from ..store import DocNode +from ..doc_node import DocNode class PandasCSVReader(LazyLLMReaderBase): def __init__(self, concat_rows: bool = True, col_joiner: str = ", ", row_joiner: str = "\n", diff --git a/lazyllm/tools/rag/readers/pdfReader.py b/lazyllm/tools/rag/readers/pdfReader.py index 8982c424..a0a4043f 100644 --- a/lazyllm/tools/rag/readers/pdfReader.py +++ b/lazyllm/tools/rag/readers/pdfReader.py @@ -5,7 +5,7 @@ from fsspec import AbstractFileSystem from .readerBase import LazyLLMReaderBase, get_default_fs, is_default_fs -from ..store import DocNode +from ..doc_node import DocNode RETRY_TIMES = 3 diff --git a/lazyllm/tools/rag/readers/pptxReader.py b/lazyllm/tools/rag/readers/pptxReader.py index 8085844d..3ae216e8 100644 --- a/lazyllm/tools/rag/readers/pptxReader.py +++ b/lazyllm/tools/rag/readers/pptxReader.py @@ -5,7 +5,7 @@ from typing import Optional, Dict, List from .readerBase import LazyLLMReaderBase, infer_torch_device -from ..store import DocNode +from ..doc_node import DocNode class PPTXReader(LazyLLMReaderBase): def __init__(self, return_trace: bool = True) -> None: diff --git a/lazyllm/tools/rag/readers/readerBase.py b/lazyllm/tools/rag/readers/readerBase.py index 70515e52..3e9d0355 100644 --- a/lazyllm/tools/rag/readers/readerBase.py +++ b/lazyllm/tools/rag/readers/readerBase.py @@ -3,7 +3,7 @@ from typing import Iterable, List from ....common import LazyLLMRegisterMetaClass -from ..store import DocNode +from ..doc_node import DocNode from lazyllm.module import ModuleBase class LazyLLMReaderBase(ModuleBase, metaclass=LazyLLMRegisterMetaClass): diff --git a/lazyllm/tools/rag/readers/videoAudioReader.py b/lazyllm/tools/rag/readers/videoAudioReader.py index 02236e75..bdd41e1d 100644 --- a/lazyllm/tools/rag/readers/videoAudioReader.py +++ b/lazyllm/tools/rag/readers/videoAudioReader.py @@ -3,7 +3,7 @@ from 
fsspec import AbstractFileSystem from .readerBase import LazyLLMReaderBase -from ..store import DocNode +from ..doc_node import DocNode class VideoAudioReader(LazyLLMReaderBase): def __init__(self, model_version: str = "base", return_trace: bool = True) -> None: diff --git a/tests/basic_tests/test_index.py b/tests/basic_tests/test_index.py index 4b306744..3cd29447 100644 --- a/tests/basic_tests/test_index.py +++ b/tests/basic_tests/test_index.py @@ -1,9 +1,9 @@ import time import unittest from unittest.mock import MagicMock -from lazyllm.tools.rag.store import MapStore +from lazyllm.tools.rag.map_store import MapStore from lazyllm.tools.rag.doc_node import DocNode -from lazyllm.tools.rag.index import DefaultIndex, register_similarity +from lazyllm.tools.rag.default_index import DefaultIndex, register_similarity from lazyllm.tools.rag.embed_utils import parallel_do_embedding class TestDefaultIndex(unittest.TestCase): From c2316f4585495cf4ef0fadba41957cc3583b7c9e Mon Sep 17 00:00:00 2001 From: ouguoyu Date: Fri, 1 Nov 2024 11:19:35 +0800 Subject: [PATCH 28/60] s --- lazyllm/tools/rag/chroma_store.py | 17 ++++++++++++++++- lazyllm/tools/rag/default_index.py | 2 -- lazyllm/tools/rag/index.py | 24 ------------------------ lazyllm/tools/rag/map_store.py | 6 ++++-- lazyllm/tools/rag/milvus_store.py | 5 ++++- lazyllm/tools/rag/store_base.py | 2 +- 6 files changed, 25 insertions(+), 31 deletions(-) diff --git a/lazyllm/tools/rag/chroma_store.py b/lazyllm/tools/rag/chroma_store.py index 25a9316a..a231dad3 100644 --- a/lazyllm/tools/rag/chroma_store.py +++ b/lazyllm/tools/rag/chroma_store.py @@ -6,6 +6,7 @@ from .store_base import StoreBase from .doc_node import DocNode from .store import LAZY_ROOT_NAME +from .index_base import IndexBase import json from .map_store import MapStore @@ -15,7 +16,6 @@ class ChromadbStore(StoreBase): def __init__( self, node_groups: List[str], embed_dim: Dict[str, int] ) -> None: - super().__init__() self._map_store = MapStore(node_groups) self._db_client = chromadb.PersistentClient(path=config["rag_persistent_path"]) LOG.success(f"Initialzed chromadb in path: {config['rag_persistent_path']}") @@ -24,6 +24,7 @@ def __init__( for group in node_groups } self._embed_dim = embed_dim + self._name2index = {} @override def update_nodes(self, nodes: List[DocNode]) -> None: @@ -50,6 +51,20 @@ def is_group_active(self, name: str) -> bool: def all_groups(self) -> List[str]: return self._map_store.all_groups() + @override + def query(self, *args, **kwargs) -> List[DocNode]: + raise NotImplementedError('not implemented yet.') + + @override + def register_index(self, type: str, index: IndexBase) -> None: + self._name2index[type] = index + + @override + def get_index(self, type: Optional[str] = None) -> Optional[IndexBase]: + if type is None: + type = 'default' + return self._name2index.get(type) + def _load_store(self) -> None: if not self._collections[LAZY_ROOT_NAME].peek(1)["ids"]: LOG.info("No persistent data found, skip the rebuilding phrase.") diff --git a/lazyllm/tools/rag/default_index.py b/lazyllm/tools/rag/default_index.py index 61bbda59..225983ea 100644 --- a/lazyllm/tools/rag/default_index.py +++ b/lazyllm/tools/rag/default_index.py @@ -7,8 +7,6 @@ from lazyllm import LOG from lazyllm.common import override from .embed_utils import parallel_do_embedding -from .milvus_store import MilvusStore -from .map_store import MapStore # ---------------------------------------------------------------------------- # diff --git a/lazyllm/tools/rag/index.py b/lazyllm/tools/rag/index.py 
index 5f254efa..dcae2b50 100644 --- a/lazyllm/tools/rag/index.py +++ b/lazyllm/tools/rag/index.py @@ -3,30 +3,6 @@ from .store_base import StoreBase from .index_base import IndexBase from lazyllm.common import override -from .milvus_store import MilvusStore -from .map_store import MapStore - -class EmbeddingIndex(IndexBase): - def __init__(self, backend_type: Optional[str] = None, *args, **kwargs): - if backend_type == 'milvus': - self._store = MilvusStore(*args, **kwargs) - elif backend_type == 'map': - self._store = MapStore(*args, **kwargs) - else: - raise ValueError(f'unsupported IndexWrapper backend [{backend_type}]') - - @override - def update(self, nodes: List[DocNode]) -> None: - self._store.update_nodes(nodes) - - @override - def remove(self, uids: List[str], group_name: Optional[str] = None) -> None: - self._store.remove_nodes(group_name, uids) - - @override - def query(self, *args, **kwargs) -> List[DocNode]: - return self._store.query(*args, **kwargs) - class WrapStoreToIndex(IndexBase): def __init__(self, store: StoreBase): diff --git a/lazyllm/tools/rag/map_store.py b/lazyllm/tools/rag/map_store.py index 29cd874a..c99c1a1a 100644 --- a/lazyllm/tools/rag/map_store.py +++ b/lazyllm/tools/rag/map_store.py @@ -2,6 +2,7 @@ from .index_base import IndexBase from .store_base import StoreBase from .doc_node import DocNode +from .index import WrapStoreToIndex from lazyllm.common import override def _update_indices(name2index: Dict[str, IndexBase], nodes: List[DocNode]) -> None: @@ -15,7 +16,6 @@ def _remove_from_indices(name2index: Dict[str, IndexBase], uids: List[str], class MapStore(StoreBase): def __init__(self, node_groups: List[str]): - super().__init__() # Dict[group_name, Dict[uuid, DocNode]] self._group2docs: Dict[str, Dict[str, DocNode]] = { group: {} for group in node_groups @@ -71,7 +71,9 @@ def register_index(self, type: str, index: IndexBase) -> None: self._name2index[type] = index @override - def get_index(self, type: str) -> Optional[IndexBase]: + def get_index(self, type: Optional[str] = None) -> Optional[IndexBase]: + if type is None or type == 'default': + return WrapStoreToIndex(self) return self._name2index.get(type) @override diff --git a/lazyllm/tools/rag/milvus_store.py b/lazyllm/tools/rag/milvus_store.py index 6d7f9b7f..080d8cfa 100644 --- a/lazyllm/tools/rag/milvus_store.py +++ b/lazyllm/tools/rag/milvus_store.py @@ -6,6 +6,7 @@ from .map_store import MapStore from .embed_utils import parallel_do_embedding from .index_base import IndexBase +from .index import WrapStoreToIndex from .store_base import StoreBase from lazyllm.common import override @@ -120,7 +121,9 @@ def register_index(self, type: str, index: IndexBase) -> None: self._map_store.register_index(type, index) @override - def get_index(self, type: str) -> Optional[IndexBase]: + def get_index(self, type: Optional[str] = None) -> Optional[IndexBase]: + if type is None or type == 'default': + return WrapStoreToIndex(self) return self._map_store.get_index(type) @override diff --git a/lazyllm/tools/rag/store_base.py b/lazyllm/tools/rag/store_base.py index 515aad80..36b23b98 100644 --- a/lazyllm/tools/rag/store_base.py +++ b/lazyllm/tools/rag/store_base.py @@ -33,5 +33,5 @@ def register_index(self, type: str, index: IndexBase) -> None: pass @abstractmethod - def get_index(self, type: str) -> Optional[IndexBase]: + def get_index(self, type: Optional[str] = None) -> Optional[IndexBase]: pass From efcc5ab20404e227d12856149c8e3b6fc2d9958b Mon Sep 17 00:00:00 2001 From: ouguoyu Date: Fri, 1 Nov 2024 11:28:34 
+0800 Subject: [PATCH 29/60] s --- lazyllm/tools/rag/map_store.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lazyllm/tools/rag/map_store.py b/lazyllm/tools/rag/map_store.py index c99c1a1a..202ec039 100644 --- a/lazyllm/tools/rag/map_store.py +++ b/lazyllm/tools/rag/map_store.py @@ -72,13 +72,13 @@ def register_index(self, type: str, index: IndexBase) -> None: @override def get_index(self, type: Optional[str] = None) -> Optional[IndexBase]: - if type is None or type == 'default': - return WrapStoreToIndex(self) + if type is None: + type = 'default' return self._name2index.get(type) @override - def query(self, group_name: str, uids: Optional[List[str]] = None) -> List[DocNode]: - return self.get_nodes(group_name, uids) + def query(self, *args, **kwargs) -> List[DocNode]: + raise NotImplementedError('not implemented yet.') def find_node_by_uid(self, uid: str) -> Optional[DocNode]: for docs in self._group2docs.values(): From 7d6641a673010d7584dfcf8c315a3df1977f5399 Mon Sep 17 00:00:00 2001 From: ouguoyu Date: Fri, 1 Nov 2024 11:41:57 +0800 Subject: [PATCH 30/60] s --- tests/basic_tests/test_document.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/basic_tests/test_document.py b/tests/basic_tests/test_document.py index 669040bb..9ce0d95f 100644 --- a/tests/basic_tests/test_document.py +++ b/tests/basic_tests/test_document.py @@ -1,5 +1,5 @@ import lazyllm -from lazyllm.tools.rag.doc_impl import DocImpl, FileNodeIndex +from lazyllm.tools.rag.doc_impl import DocImpl, _FileNodeIndex from lazyllm.tools.rag.transform import SentenceSplitter from lazyllm.tools.rag.store import LAZY_ROOT_NAME from lazyllm.tools.rag.doc_node import DocNode @@ -155,7 +155,7 @@ def test_multi_embedding_with_document(self): class TestFileNodeIndex(unittest.TestCase): def setUp(self): - self.index = FileNodeIndex() + self.index = _FileNodeIndex() self.node1 = DocNode(uid='1', group=LAZY_ROOT_NAME, metadata={"file_name": "d1"}) self.node2 = DocNode(uid='2', group=LAZY_ROOT_NAME, metadata={"file_name": "d2"}) self.files = [self.node1.metadata['file_name'], self.node1.metadata['file_name']] From a3155c96cc55dfcdbd4bfb0e2680a0a9b6c6204a Mon Sep 17 00:00:00 2001 From: ouguoyu Date: Fri, 1 Nov 2024 16:00:05 +0800 Subject: [PATCH 31/60] s --- lazyllm/tools/rag/chroma_store.py | 12 ++++++-- lazyllm/tools/rag/doc_impl.py | 51 ++++++++----------------------- lazyllm/tools/rag/embed_utils.py | 4 +-- lazyllm/tools/rag/index.py | 13 ++++++-- lazyllm/tools/rag/map_store.py | 23 ++++++++------ lazyllm/tools/rag/milvus_store.py | 4 +-- lazyllm/tools/rag/store.py | 6 ---- lazyllm/tools/rag/store_base.py | 10 ++++++ lazyllm/tools/rag/utils.py | 27 ++++++++++++++++ 9 files changed, 85 insertions(+), 65 deletions(-) delete mode 100644 lazyllm/tools/rag/store.py diff --git a/lazyllm/tools/rag/chroma_store.py b/lazyllm/tools/rag/chroma_store.py index a231dad3..aa27e9ee 100644 --- a/lazyllm/tools/rag/chroma_store.py +++ b/lazyllm/tools/rag/chroma_store.py @@ -7,6 +7,8 @@ from .doc_node import DocNode from .store import LAZY_ROOT_NAME from .index_base import IndexBase +from .utils import _FileNodeIndex +from .default_index import DefaultIndex import json from .map_store import MapStore @@ -14,7 +16,7 @@ class ChromadbStore(StoreBase): def __init__( - self, node_groups: List[str], embed_dim: Dict[str, int] + self, node_groups: List[str], embed: Dict[str, Callable], embed_dim: Dict[str, int] ) -> None: self._map_store = MapStore(node_groups) self._db_client = 
chromadb.PersistentClient(path=config["rag_persistent_path"]) @@ -24,7 +26,11 @@ def __init__( for group in node_groups } self._embed_dim = embed_dim - self._name2index = {} + + self._name2index = { + 'default': DefaultIndex(embed, self._map_store), + 'file_node_map': _FileNodeIndex(), + } @override def update_nodes(self, nodes: List[DocNode]) -> None: @@ -53,7 +59,7 @@ def all_groups(self) -> List[str]: @override def query(self, *args, **kwargs) -> List[DocNode]: - raise NotImplementedError('not implemented yet.') + return get_index('default').query(*args, **kwargs) @override def register_index(self, type: str, index: IndexBase) -> None: diff --git a/lazyllm/tools/rag/doc_impl.py b/lazyllm/tools/rag/doc_impl.py index fb81d156..4aec4466 100644 --- a/lazyllm/tools/rag/doc_impl.py +++ b/lazyllm/tools/rag/doc_impl.py @@ -22,32 +22,6 @@ # ---------------------------------------------------------------------------- # -class _FileNodeIndex(IndexBase): - def __init__(self): - self._file_node_map = {} - - @override - def update(self, nodes: List[DocNode]) -> None: - for node in nodes: - if node.group != LAZY_ROOT_NAME: - continue - file_name = node.metadata.get("file_name") - if file_name: - self._file_node_map[file_name] = node - - @override - def remove(self, uids: List[str], group_name: Optional[str] = None) -> None: - # group_name is ignored - left = {k: v for k, v in self._file_node_map.items() if v.uid not in uids} - self._file_node_map = left - - @override - def query(self, files: List[str]) -> List[DocNode]: - ret = [] - for file in files: - ret.append(self._file_node_map.get(file)) - return ret - class _DocStore(StoreBase): @staticmethod def _create_file_node_index(store) -> _FileNodeIndex: @@ -58,23 +32,21 @@ def _create_file_node_index(store) -> _FileNodeIndex: @staticmethod def _update_indices(name2index: Dict[str, IndexBase], nodes: List[DocNode]) -> None: - for _, index in name2index.items(): + for index in name2index.values(): index.update(nodes) @staticmethod def _remove_from_indices(name2index: Dict[str, IndexBase], uids: List[str], group_name: Optional[str] = None) -> None: - for _, index in name2index.items(): + for index in name2index.values(): index.remove(uids, group_name) - def _create_some_indices(self): - if not self._store.get_index(type='file_node_map'): - self.register_index(type='file_node_map', index=self._create_file_node_index(self._store)) - def __init__(self, store: StoreBase): self._store = store self._extra_indices = {} - self._create_some_indices() + + if not self._store.get_index(type='file_node_map'): + self.register_index(type='file_node_map', index=self._create_file_node_index(self._store)) def update_nodes(self, nodes: List[DocNode]) -> None: self._store.update_nodes(nodes) @@ -169,17 +141,13 @@ def _create_store(self, rag_store_type: str = None) -> StoreBase: if rag_store_type == "map": store = MapStore(node_groups=self.node_groups.keys()) elif rag_store_type == "chroma": - store = ChromadbStore(node_groups=self.node_groups.keys(), embed_dim=self._embed_dim) + store = ChromadbStore(node_groups=self.node_groups.keys(), + embed=self.embed, embed_dim=self._embed_dim) else: raise NotImplementedError( f"Not implemented store type for {rag_store_type}" ) - if not store.get_index(type='default'): - store.register_index(type='default', index=DefaultIndex(self.embed, store)) - if not store.get_index(type='file_node_map'): - store.register_index(type='file_node_map', index=self._create_file_node_index(store)) - return store @staticmethod @@ -346,6 +314,11 @@ def 
retrieve(self, query: str, group_name: str, similarity: str, similarity_cut_
                 index: str, topk: int, similarity_kws: dict,
                 embed_keys: Optional[List[str]] = None) -> List[DocNode]:
         self._lazy_init()
+        if index is None or index == 'default':
+            return self.store.query(query=query, group_name=group_name, similarity_name=similarity,
+                                    similarity_cut_off=similarity_cut_off, topk=topk,
+                                    embed_keys=embed_keys, **similarity_kws)
+
         index_instance = self.store.get_index(type=index)
         if not index_instance:
             raise NotImplementedError(f"index type '{index}' is not supported currently.")
diff --git a/lazyllm/tools/rag/embed_utils.py b/lazyllm/tools/rag/embed_utils.py
index a4fb03c3..008accbb 100644
--- a/lazyllm/tools/rag/embed_utils.py
+++ b/lazyllm/tools/rag/embed_utils.py
@@ -12,10 +12,8 @@
     "MAX_EMBEDDING_WORKERS",
 )
 
+# returns a list of modified nodes
 def parallel_do_embedding(embed: Dict[str, Callable], nodes: List[DocNode]) -> List[DocNode]:
-    '''
-    returns a list of modified nodes
-    '''
     modified_nodes = []
     with ThreadPoolExecutor(config["max_embedding_workers"]) as executor:
         futures = []
diff --git a/lazyllm/tools/rag/index.py b/lazyllm/tools/rag/index.py
index dcae2b50..72f37445 100644
--- a/lazyllm/tools/rag/index.py
+++ b/lazyllm/tools/rag/index.py
@@ -3,10 +3,17 @@
 from .store_base import StoreBase
 from .index_base import IndexBase
 from lazyllm.common import override
+from .map_store import MapStore
+from .milvus_store import MilvusStore, MilvusField
 
-class WrapStoreToIndex(IndexBase):
-    def __init__(self, store: StoreBase):
-        self._store = store
+class SmartEmbeddingIndex(IndexBase):
+    def __init__(self, backend_type: str, fields: List[str], *args, **kwargs):
+        if backend_type == 'milvus':
+            self._store = MilvusStore(*args, **kwargs)
+        elif backend_type == 'map':
+            self._store = MapStore(*args, **kwargs)
+        else:
+            raise ValueError(f'unsupported backend [{backend_type}]')
 
     @override
     def update(self, nodes: List[DocNode]) -> None:
diff --git a/lazyllm/tools/rag/map_store.py b/lazyllm/tools/rag/map_store.py
index 202ec039..e00dd71a 100644
--- a/lazyllm/tools/rag/map_store.py
+++ b/lazyllm/tools/rag/map_store.py
@@ -2,6 +2,7 @@
 from .index_base import IndexBase
 from .store_base import StoreBase
 from .doc_node import DocNode
-from .index import WrapStoreToIndex
+from .utils import _FileNodeIndex
+from .default_index import DefaultIndex
 from lazyllm.common import override
 
 def _update_indices(name2index: Dict[str, IndexBase], nodes: List[DocNode]) -> None:
-    for _, index in name2index.items():
+    for index in name2index.values():
         index.update(nodes)
 
 def _remove_from_indices(name2index: Dict[str, IndexBase], uids: List[str],
                          group_name: Optional[str] = None) -> None:
-    for _, index in name2index.items():
+    for index in name2index.values():
         index.remove(uids, group_name)
 
 class MapStore(StoreBase):
-    def __init__(self, node_groups: List[str]):
+    def __init__(self, node_groups: List[str], embed: Dict[str, Callable]):
         # Dict[group_name, Dict[uuid, DocNode]]
         self._group2docs: Dict[str, Dict[str, DocNode]] = {
             group: {} for group in node_groups
         }
-        self._name2index = {}
+
+        self._name2index = {
+            'default': DefaultIndex(embed, self),
+            'file_node_map': _FileNodeIndex(),
+        }
 
 @@ -66,6 +71,10 @@ def is_group_active(self, name: str) -> bool:
     def all_groups(self) -> List[str]:
         return self._group2docs.keys()
 
+    @override
+    def query(self, *args, **kwargs) -> List[DocNode]:
+        return get_index('default').query(*args, **kwargs)
+
     @override
     def 
register_index(self, type: str, index: IndexBase) -> None: self._name2index[type] = index @@ -76,10 +85,6 @@ def get_index(self, type: Optional[str] = None) -> Optional[IndexBase]: type = 'default' return self._name2index.get(type) - @override - def query(self, *args, **kwargs) -> List[DocNode]: - raise NotImplementedError('not implemented yet.') - def find_node_by_uid(self, uid: str) -> Optional[DocNode]: for docs in self._group2docs.values(): doc = docs.get(uid) diff --git a/lazyllm/tools/rag/milvus_store.py b/lazyllm/tools/rag/milvus_store.py index 080d8cfa..1314d5c5 100644 --- a/lazyllm/tools/rag/milvus_store.py +++ b/lazyllm/tools/rag/milvus_store.py @@ -122,8 +122,8 @@ def register_index(self, type: str, index: IndexBase) -> None: @override def get_index(self, type: Optional[str] = None) -> Optional[IndexBase]: - if type is None or type == 'default': - return WrapStoreToIndex(self) + if type is None: + type = 'default' return self._map_store.get_index(type) @override diff --git a/lazyllm/tools/rag/store.py b/lazyllm/tools/rag/store.py deleted file mode 100644 index 8db832c1..00000000 --- a/lazyllm/tools/rag/store.py +++ /dev/null @@ -1,6 +0,0 @@ -from lazyllm import config - -LAZY_ROOT_NAME = "lazyllm_root" -EMBED_DEFAULT_KEY = '__default__' -config.add("rag_store_type", str, "map", "RAG_STORE_TYPE") # "map", "chroma" -config.add("rag_persistent_path", str, "./lazyllm_chroma", "RAG_PERSISTENT_PATH") diff --git a/lazyllm/tools/rag/store_base.py b/lazyllm/tools/rag/store_base.py index 36b23b98..23e02f41 100644 --- a/lazyllm/tools/rag/store_base.py +++ b/lazyllm/tools/rag/store_base.py @@ -1,8 +1,18 @@ from abc import ABC, abstractmethod from typing import Optional, List +from lazyllm import config from .doc_node import DocNode from .index_base import IndexBase +# ---------------------------------------------------------------------------- # + +LAZY_ROOT_NAME = "lazyllm_root" +EMBED_DEFAULT_KEY = '__default__' +config.add("rag_store_type", str, "map", "RAG_STORE_TYPE") # "map", "chroma" +config.add("rag_persistent_path", str, "./lazyllm_chroma", "RAG_PERSISTENT_PATH") + +# ---------------------------------------------------------------------------- # + class StoreBase(ABC): @abstractmethod def update_nodes(self, nodes: List[DocNode]) -> None: diff --git a/lazyllm/tools/rag/utils.py b/lazyllm/tools/rag/utils.py index 66338ada..47f8e88d 100644 --- a/lazyllm/tools/rag/utils.py +++ b/lazyllm/tools/rag/utils.py @@ -472,3 +472,30 @@ def save_files_in_threads( if os.path.exists(cache_dir): shutil.rmtree(cache_dir) return (already_exist_files, new_add_files, overwritten_files) + + +class _FileNodeIndex(IndexBase): + def __init__(self): + self._file_node_map = {} + + @override + def update(self, nodes: List[DocNode]) -> None: + for node in nodes: + if node.group != LAZY_ROOT_NAME: + continue + file_name = node.metadata.get("file_name") + if file_name: + self._file_node_map[file_name] = node + + @override + def remove(self, uids: List[str], group_name: Optional[str] = None) -> None: + # group_name is ignored + left = {k: v for k, v in self._file_node_map.items() if v.uid not in uids} + self._file_node_map = left + + @override + def query(self, files: List[str]) -> List[DocNode]: + ret = [] + for file in files: + ret.append(self._file_node_map.get(file)) + return ret From bbec2376ed94870b42d697d648719a2e908ce007 Mon Sep 17 00:00:00 2001 From: ouguoyu Date: Fri, 1 Nov 2024 16:50:52 +0800 Subject: [PATCH 32/60] s --- lazyllm/tools/rag/chroma_store.py | 4 ++-- lazyllm/tools/rag/doc_impl.py | 11 
++++------- lazyllm/tools/rag/map_store.py | 4 ++-- lazyllm/tools/rag/milvus_store.py | 1 - lazyllm/tools/rag/retriever.py | 2 +- .../tools/rag/{index.py => smart_embedding_index.py} | 3 +-- lazyllm/tools/rag/utils.py | 4 ++++ 7 files changed, 14 insertions(+), 15 deletions(-) rename lazyllm/tools/rag/{index.py => smart_embedding_index.py} (91%) diff --git a/lazyllm/tools/rag/chroma_store.py b/lazyllm/tools/rag/chroma_store.py index aa27e9ee..a8a47c72 100644 --- a/lazyllm/tools/rag/chroma_store.py +++ b/lazyllm/tools/rag/chroma_store.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Callable import chromadb from lazyllm import LOG, config from lazyllm.common import override @@ -59,7 +59,7 @@ def all_groups(self) -> List[str]: @override def query(self, *args, **kwargs) -> List[DocNode]: - return get_index('default').query(*args, **kwargs) + return self.get_index('default').query(*args, **kwargs) @override def register_index(self, type: str, index: IndexBase) -> None: diff --git a/lazyllm/tools/rag/doc_impl.py b/lazyllm/tools/rag/doc_impl.py index 4aec4466..9f74d392 100644 --- a/lazyllm/tools/rag/doc_impl.py +++ b/lazyllm/tools/rag/doc_impl.py @@ -3,7 +3,6 @@ from functools import wraps from typing import Callable, Dict, List, Optional, Set, Union, Tuple from lazyllm import LOG, config, once_wrapper -from lazyllm.common import override from .transform import (NodeTransform, FuncNodeTransform, SentenceSplitter, LLMParser, AdaptiveTransform, make_transform, TransformArgs) from .store import LAZY_ROOT_NAME @@ -13,8 +12,7 @@ from .doc_node import DocNode from .data_loaders import DirectoryReader from .index_base import IndexBase -from .default_index import DefaultIndex -from .utils import DocListManager +from .utils import DocListManager, _FileNodeIndex import threading import time @@ -43,10 +41,9 @@ def _remove_from_indices(name2index: Dict[str, IndexBase], uids: List[str], def __init__(self, store: StoreBase): self._store = store - self._extra_indices = {} - - if not self._store.get_index(type='file_node_map'): - self.register_index(type='file_node_map', index=self._create_file_node_index(self._store)) + self._extra_indices = { + 'file_node_map': self._create_file_node_index(self._store) + } def update_nodes(self, nodes: List[DocNode]) -> None: self._store.update_nodes(nodes) diff --git a/lazyllm/tools/rag/map_store.py b/lazyllm/tools/rag/map_store.py index e00dd71a..447721aa 100644 --- a/lazyllm/tools/rag/map_store.py +++ b/lazyllm/tools/rag/map_store.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Callable from .index_base import IndexBase from .store_base import StoreBase from .doc_node import DocNode @@ -73,7 +73,7 @@ def all_groups(self) -> List[str]: @override def query(self, *args, **kwargs) -> List[DocNode]: - return get_index('default').query(*args, **kwargs) + return self.get_index('default').query(*args, **kwargs) @override def register_index(self, type: str, index: IndexBase) -> None: diff --git a/lazyllm/tools/rag/milvus_store.py b/lazyllm/tools/rag/milvus_store.py index 1314d5c5..72e33bbc 100644 --- a/lazyllm/tools/rag/milvus_store.py +++ b/lazyllm/tools/rag/milvus_store.py @@ -6,7 +6,6 @@ from .map_store import MapStore from .embed_utils import parallel_do_embedding from .index_base import IndexBase -from .index import WrapStoreToIndex from .store_base import StoreBase from lazyllm.common import override diff --git a/lazyllm/tools/rag/retriever.py 
b/lazyllm/tools/rag/retriever.py
index 0acf2883..0bfee49a 100644
--- a/lazyllm/tools/rag/retriever.py
+++ b/lazyllm/tools/rag/retriever.py
@@ -33,7 +33,7 @@ def __init__(
         self,
         doc: object,
         group_name: str,
-        similarity: str = "dummy",
+        similarity: Optional[str] = None,
         similarity_cut_off: Union[float, Dict[str, float]] = float("-inf"),
         index: str = "default",
         topk: int = 6,
diff --git a/lazyllm/tools/rag/index.py b/lazyllm/tools/rag/smart_embedding_index.py
similarity index 91%
rename from lazyllm/tools/rag/index.py
rename to lazyllm/tools/rag/smart_embedding_index.py
index 72f37445..c975991c 100644
--- a/lazyllm/tools/rag/index.py
+++ b/lazyllm/tools/rag/smart_embedding_index.py
@@ -1,10 +1,9 @@
 from typing import List, Optional
 from .doc_node import DocNode
-from .store_base import StoreBase
 from .index_base import IndexBase
 from lazyllm.common import override
 from .map_store import MapStore
-from .milvus_store import MilvusStore, MilvusField
+from .milvus_store import MilvusStore
 
 class SmartEmbeddingIndex(IndexBase):
     def __init__(self, backend_type: str, fields: List[str], *args, **kwargs):
diff --git a/lazyllm/tools/rag/utils.py b/lazyllm/tools/rag/utils.py
index 47f8e88d..2a8d3079 100644
--- a/lazyllm/tools/rag/utils.py
+++ b/lazyllm/tools/rag/utils.py
@@ -3,6 +3,10 @@
 import hashlib
 from typing import List, Callable, Generator, Dict, Any, Optional, Union, Tuple
 from abc import ABC, abstractmethod
+from .index_base import IndexBase
+from .store_base import LAZY_ROOT_NAME
+from .doc_node import DocNode
+from lazyllm.common import override
 import pydantic
 import sqlite3

From d90333fe6fe714f895d0857e848157262043f9a2 Mon Sep 17 00:00:00 2001
From: ouguoyu
Date: Fri, 1 Nov 2024 19:14:14 +0800
Subject: [PATCH 33/60] s

---
 lazyllm/tools/rag/chroma_store.py | 11 ++--
 lazyllm/tools/rag/doc_impl.py     | 96 +++++++------------------
 lazyllm/tools/rag/document.py     | 22 +++----
 lazyllm/tools/rag/milvus_store.py |  3 +
 lazyllm/tools/rag/store_base.py   |  3 -
 5 files changed, 42 insertions(+), 93 deletions(-)

diff --git a/lazyllm/tools/rag/chroma_store.py b/lazyllm/tools/rag/chroma_store.py
index a8a47c72..fe5bc0cf 100644
--- a/lazyllm/tools/rag/chroma_store.py
+++ b/lazyllm/tools/rag/chroma_store.py
@@ -1,6 +1,6 @@
 from typing import Any, Dict, List, Optional, Callable
 import chromadb
-from lazyllm import LOG, config
+from lazyllm import LOG
 from lazyllm.common import override
 from chromadb.api.models.Collection import Collection
 from .store_base import StoreBase
@@ -15,12 +15,11 @@
 # ---------------------------------------------------------------------------- #
 
 class ChromadbStore(StoreBase):
-    def __init__(
-        self, node_groups: List[str], embed: Dict[str, Callable], embed_dim: Dict[str, int]
-    ) -> None:
+    def __init__(self, node_groups: List[str], path: str, embed: Dict[str, Callable],
+                 embed_dim: Dict[str, int]) -> None:
         self._map_store = MapStore(node_groups)
-        self._db_client = chromadb.PersistentClient(path=config["rag_persistent_path"])
-        LOG.success(f"Initialzed chromadb in path: {config['rag_persistent_path']}")
+        self._db_client = chromadb.PersistentClient(path=path)
+        LOG.success(f"Initialized chromadb in path: {path}")
         self._collections: Dict[str, Collection] = {
             group: self._db_client.get_or_create_collection(group)
             for group in node_groups
diff --git a/lazyllm/tools/rag/doc_impl.py b/lazyllm/tools/rag/doc_impl.py
index 9f74d392..f6f587a9 100644
--- a/lazyllm/tools/rag/doc_impl.py
+++ b/lazyllm/tools/rag/doc_impl.py
@@ -2,7 +2,7 @@
 from collections import defaultdict
 from functools 
import wraps from typing import Callable, Dict, List, Optional, Set, Union, Tuple -from lazyllm import LOG, config, once_wrapper +from lazyllm import LOG, once_wrapper from .transform import (NodeTransform, FuncNodeTransform, SentenceSplitter, LLMParser, AdaptiveTransform, make_transform, TransformArgs) from .store import LAZY_ROOT_NAME @@ -11,68 +11,12 @@ from .chroma_store import ChromadbStore from .doc_node import DocNode from .data_loaders import DirectoryReader -from .index_base import IndexBase -from .utils import DocListManager, _FileNodeIndex +from .utils import DocListManager import threading import time _transmap = dict(function=FuncNodeTransform, sentencesplitter=SentenceSplitter, llm=LLMParser) -# ---------------------------------------------------------------------------- # - -class _DocStore(StoreBase): - @staticmethod - def _create_file_node_index(store) -> _FileNodeIndex: - index = _FileNodeIndex() - for group in store.all_groups(): - index.update(store.get_nodes(group)) - return index - - @staticmethod - def _update_indices(name2index: Dict[str, IndexBase], nodes: List[DocNode]) -> None: - for index in name2index.values(): - index.update(nodes) - - @staticmethod - def _remove_from_indices(name2index: Dict[str, IndexBase], uids: List[str], - group_name: Optional[str] = None) -> None: - for index in name2index.values(): - index.remove(uids, group_name) - - def __init__(self, store: StoreBase): - self._store = store - self._extra_indices = { - 'file_node_map': self._create_file_node_index(self._store) - } - - def update_nodes(self, nodes: List[DocNode]) -> None: - self._store.update_nodes(nodes) - self._update_indices(self._extra_indices, nodes) - - def get_nodes(self, group_name: str, uids: Optional[List[str]] = None) -> List[DocNode]: - return self._store.get_nodes(group_name, uids) - - def remove_nodes(self, group_name: str, uids: Optional[List[str]] = None) -> None: - self._store.remove_nodes(group_name, uids) - self._remove_from_indices(self._extra_indices, uids, group_name) - - def is_group_active(self, name: str) -> bool: - return self._store.is_group_active(name) - - def all_groups(self) -> List[str]: - return self._store.all_groups() - - def register_index(self, type: str, index: IndexBase) -> None: - self._extra_indices[type] = index - - def get_index(self, type: str = 'default') -> Optional[IndexBase]: - index = self._extra_indices.get(type) - if not index: - index = self._store.get_index(type) - return index - -# ---------------------------------------------------------------------------- # - def embed_wrapper(func): if not func: return None @@ -92,7 +36,7 @@ class DocImpl: def __init__(self, embed: Dict[str, Callable], dlm: Optional[DocListManager] = None, doc_files: Optional[str] = None, kb_group_name: Optional[str] = None, - store: Optional[StoreBase] = None): + store_conf: Optional[Dict] = None): super().__init__() assert (dlm is None) ^ (doc_files is None), 'Only one of dataset_path or doc_files should be provided' self._local_file_reader: Dict[str, Callable] = {} @@ -102,10 +46,7 @@ def __init__(self, embed: Dict[str, Callable], dlm: Optional[DocListManager] = N self.node_groups: Dict[str, Dict] = {LAZY_ROOT_NAME: {}} self.embed = {k: embed_wrapper(e) for k, e in embed.items()} self._embed_dim = None - if store: - self.store = _DocStore(store) - else: - self.store = None + self.store = store_conf # NOTE: will be initialized in _lazy_init() @once_wrapper(reset_on_pickle=True) def _lazy_init(self) -> None: @@ -116,8 +57,10 @@ def _lazy_init(self) -> None: 
self._embed_dim = {k: len(e('a')) for k, e in self.embed.items()} - if not self.store: - self.store = self._create_store() + if isinstance(self.store, Dict): + self.store = self._create_store(self.store) + else: + raise ValueError(f'store type [{type(self.store)}] is not a dict.') if not self.store.is_group_active(LAZY_ROOT_NAME): ids, pathes = self._list_files() @@ -132,17 +75,22 @@ def _lazy_init(self) -> None: self._daemon.daemon = True self._daemon.start() - def _create_store(self, rag_store_type: str = None) -> StoreBase: - if not rag_store_type: - rag_store_type = config["rag_store_type"] - if rag_store_type == "map": - store = MapStore(node_groups=self.node_groups.keys()) - elif rag_store_type == "chroma": - store = ChromadbStore(node_groups=self.node_groups.keys(), - embed=self.embed, embed_dim=self._embed_dim) + def _create_store(self, store_conf: Optional[Dict]) -> StoreBase: + type = store_conf.get('type') + if not type: + raise ValueError('store type is not specified.') + + kwargs = store_conf.get('kwargs') + if not isinstance(kwargs, Dict): + raise ValueError('`kwargs` in store conf is not a dict.') + + if type == "map": + store = MapStore(embed=self.embed, **kwargs) + elif type == "chroma": + store = ChromadbStore(embed=self.embed, **kwargs) else: raise NotImplementedError( - f"Not implemented store type for {rag_store_type}" + f"Not implemented store type for {type}" ) return store diff --git a/lazyllm/tools/rag/document.py b/lazyllm/tools/rag/document.py index d62688c8..f5a01214 100644 --- a/lazyllm/tools/rag/document.py +++ b/lazyllm/tools/rag/document.py @@ -9,7 +9,6 @@ from .doc_impl import DocImpl from .doc_node import DocNode from .store import LAZY_ROOT_NAME, EMBED_DEFAULT_KEY -from .store_base import StoreBase from .utils import DocListManager import copy import functools @@ -23,7 +22,7 @@ class Document(ModuleBase): class _Impl(ModuleBase): def __init__(self, dataset_path: str, embed: Optional[Union[Callable, Dict[str, Callable]]] = None, manager: bool = False, server: bool = False, name: Optional[str] = None, - launcher: Launcher = None, store: StoreBase = None): + launcher: Optional[Launcher] = None, store_conf: Optional[Dict] = None): super().__init__() if not os.path.exists(dataset_path): defatult_path = os.path.join(lazyllm.config["data_path"], dataset_path) @@ -38,15 +37,17 @@ def __init__(self, dataset_path: str, embed: Optional[Union[Callable, Dict[str, self._submodules.append(embed) self._dlm = DocListManager(dataset_path, name).init_tables() self._kbs = CallableDict({DocListManager.DEDAULT_GROUP_NAME: - DocImpl(embed=self._embed, dlm=self._dlm, store=store)}) + DocImpl(embed=self._embed, dlm=self._dlm, store_conf=store_conf)}) if manager: self._manager = ServerModule(DocManager(self._dlm)) if server: self._kbs = ServerModule(self._kbs) - def add_kb_group(self, name, store: StoreBase): + def add_kb_group(self, name, store_conf: Optional[Dict] = None): if isinstance(self._kbs, ServerModule): - self._kbs._impl._m[name] = DocImpl(dlm=self._dlm, embed=self._embed, kb_group_name=name, store=store) + self._kbs._impl._m[name] = DocImpl(dlm=self._dlm, embed=self._embed, kb_group_name=name, + store_conf=store_conf) else: - self._kbs[name] = DocImpl(dlm=self._dlm, embed=self._embed, kb_group_name=name, store=store) + self._kbs[name] = DocImpl(dlm=self._dlm, embed=self._embed, kb_group_name=name, + store_conf=store_conf) self._dlm.add_kb_group(name) def get_doc_by_kb_group(self, name): @@ -59,15 +60,16 @@ def __call__(self, *args, **kw): def __init__(self, 
dataset_path: str, embed: Optional[Union[Callable, Dict[str, Callable]]] = None, create_ui: bool = False, manager: bool = False, server: bool = False, - name: Optional[str] = None, launcher=None, store: StoreBase = None): + name: Optional[str] = None, launcher: Optional[Launcher] = None, + store_conf: Optional[Dict] = None): super().__init__() if create_ui: lazyllm.LOG.warning('`create_ui` for Document is deprecated, use `manager` instead') - self._impls = Document._Impl(dataset_path, embed, create_ui or manager, server, name, launcher, store) + self._impls = Document._Impl(dataset_path, embed, create_ui or manager, server, name, launcher, store_conf) self._curr_group = DocListManager.DEDAULT_GROUP_NAME - def create_kb_group(self, name: str, store: StoreBase) -> "Document": - self._impls.add_kb_group(name, store) + def create_kb_group(self, name: str, store_conf: Optional[Dict] = None) -> "Document": + self._impls.add_kb_group(name, store_conf) doc = copy.copy(self) doc._curr_group = name return doc diff --git a/lazyllm/tools/rag/milvus_store.py b/lazyllm/tools/rag/milvus_store.py index 72e33bbc..60795f3a 100644 --- a/lazyllm/tools/rag/milvus_store.py +++ b/lazyllm/tools/rag/milvus_store.py @@ -134,6 +134,9 @@ def query(self, topk: int = 10, embed_keys: Optional[List[str]] = None, **kwargs) -> List[DocNode]: + if similarity is not None: + raise ValueError('`similarity` MUST be None when Milvus backend is used.') + uidset = set() for key in embed_keys: embed_func = self._embed.get(key) diff --git a/lazyllm/tools/rag/store_base.py b/lazyllm/tools/rag/store_base.py index 23e02f41..08992ad9 100644 --- a/lazyllm/tools/rag/store_base.py +++ b/lazyllm/tools/rag/store_base.py @@ -1,6 +1,5 @@ from abc import ABC, abstractmethod from typing import Optional, List -from lazyllm import config from .doc_node import DocNode from .index_base import IndexBase @@ -8,8 +7,6 @@ LAZY_ROOT_NAME = "lazyllm_root" EMBED_DEFAULT_KEY = '__default__' -config.add("rag_store_type", str, "map", "RAG_STORE_TYPE") # "map", "chroma" -config.add("rag_persistent_path", str, "./lazyllm_chroma", "RAG_PERSISTENT_PATH") # ---------------------------------------------------------------------------- # From dcac229ff66a03c0961e35b49e2cb93510f7908a Mon Sep 17 00:00:00 2001 From: ouguoyu Date: Fri, 1 Nov 2024 19:55:59 +0800 Subject: [PATCH 34/60] s --- lazyllm/tools/rag/doc_impl.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lazyllm/tools/rag/doc_impl.py b/lazyllm/tools/rag/doc_impl.py index f6f587a9..519d0ab2 100644 --- a/lazyllm/tools/rag/doc_impl.py +++ b/lazyllm/tools/rag/doc_impl.py @@ -88,6 +88,8 @@ def _create_store(self, store_conf: Optional[Dict]) -> StoreBase: store = MapStore(embed=self.embed, **kwargs) elif type == "chroma": store = ChromadbStore(embed=self.embed, **kwargs) + elif type == "milvus": + store = MilvusStore(embed=self.embed, **kwargs) else: raise NotImplementedError( f"Not implemented store type for {type}" From a534624e24d98ba7aa7e6db51ec998a44c0a59fa Mon Sep 17 00:00:00 2001 From: ouguoyu Date: Mon, 4 Nov 2024 10:49:55 +0800 Subject: [PATCH 35/60] s --- lazyllm/__init__.py | 4 +--- lazyllm/tools/__init__.py | 4 +--- lazyllm/tools/rag/__init__.py | 8 ++++---- lazyllm/tools/rag/chroma_store.py | 9 ++++----- lazyllm/tools/rag/data_loaders.py | 2 +- lazyllm/tools/rag/doc_impl.py | 12 ++++++++++-- lazyllm/tools/rag/document.py | 2 +- lazyllm/tools/rag/milvus_store.py | 7 +++---- 8 files changed, 25 insertions(+), 23 deletions(-) diff --git a/lazyllm/__init__.py b/lazyllm/__init__.py index 
72fce42d..5ba10262 100644 --- a/lazyllm/__init__.py +++ b/lazyllm/__init__.py @@ -15,7 +15,7 @@ from .client import redis_client from .tools import (Document, Reranker, Retriever, WebModule, ToolManager, FunctionCall, FunctionCallAgent, fc_register, ReactAgent, PlanAndSolveAgent, ReWOOAgent, SentenceSplitter, - LLMParser, StoreBase, IndexBase) + LLMParser) from .docs import add_doc config.done() @@ -73,8 +73,6 @@ 'PlanAndSolveAgent', 'ReWOOAgent', 'SentenceSplitter', - 'StoreBase', - 'IndexBase', # docs 'add_doc', diff --git a/lazyllm/tools/__init__.py b/lazyllm/tools/__init__.py index 31eb249f..0df3c274 100644 --- a/lazyllm/tools/__init__.py +++ b/lazyllm/tools/__init__.py @@ -1,4 +1,4 @@ -from .rag import Document, Reranker, Retriever, SentenceSplitter, LLMParser, StoreBase, IndexBase +from .rag import Document, Reranker, Retriever, SentenceSplitter, LLMParser from .webpages import WebModule from .agent import ( ToolManager, @@ -32,6 +32,4 @@ "SqlManager", "SqlCall", "HttpTool", - 'StoreBase', - 'IndexBase', ] diff --git a/lazyllm/tools/rag/__init__.py b/lazyllm/tools/rag/__init__.py index 17b25b16..dfd5da76 100644 --- a/lazyllm/tools/rag/__init__.py +++ b/lazyllm/tools/rag/__init__.py @@ -8,8 +8,8 @@ MarkdownReader, MboxReader, PandasCSVReader, PandasExcelReader, VideoAudioReader) from .dataReader import SimpleDirectoryReader from .doc_manager import DocManager, DocListManager -from .store_base import StoreBase -from .index_base import IndexBase +from .store_base import EMBED_DEFAULT_KEY +from .milvus_store import MilvusField __all__ = [ @@ -39,6 +39,6 @@ "SimpleDirectoryReader", 'DocManager', 'DocListManager', - 'StoreBase', - 'IndexBase', + 'MilvusField', + 'EMBED_DEFAULT_KEY', ] diff --git a/lazyllm/tools/rag/chroma_store.py b/lazyllm/tools/rag/chroma_store.py index fe5bc0cf..f9087cec 100644 --- a/lazyllm/tools/rag/chroma_store.py +++ b/lazyllm/tools/rag/chroma_store.py @@ -3,9 +3,8 @@ from lazyllm import LOG from lazyllm.common import override from chromadb.api.models.Collection import Collection -from .store_base import StoreBase +from .store_base import StoreBase, LAZY_ROOT_NAME from .doc_node import DocNode -from .store import LAZY_ROOT_NAME from .index_base import IndexBase from .utils import _FileNodeIndex from .default_index import DefaultIndex @@ -15,9 +14,9 @@ # ---------------------------------------------------------------------------- # class ChromadbStore(StoreBase): - def __init__(self, node_groups: List[str], path: str, embed: Dict[str, Callable], - embed_dim: Dict[str, int]) -> None: - self._map_store = MapStore(node_groups) + def __init__(self, node_groups: List[str], path: str, embed_dim: Dict[str, int], + embed: Dict[str, Callable]) -> None: + self._map_store = MapStore(node_groups=node_groups, embed=embed) self._db_client = chromadb.PersistentClient(path=path) LOG.success(f"Initialzed chromadb in path: {path}") self._collections: Dict[str, Collection] = { diff --git a/lazyllm/tools/rag/data_loaders.py b/lazyllm/tools/rag/data_loaders.py index 0212fc17..f19cf5ce 100644 --- a/lazyllm/tools/rag/data_loaders.py +++ b/lazyllm/tools/rag/data_loaders.py @@ -1,6 +1,6 @@ from typing import List, Optional, Dict from .doc_node import DocNode -from .store import LAZY_ROOT_NAME +from .store_base import LAZY_ROOT_NAME from lazyllm import LOG from .dataReader import SimpleDirectoryReader diff --git a/lazyllm/tools/rag/doc_impl.py b/lazyllm/tools/rag/doc_impl.py index 519d0ab2..f56e7e7c 100644 --- a/lazyllm/tools/rag/doc_impl.py +++ b/lazyllm/tools/rag/doc_impl.py @@ -5,10 +5,10 
@@ from lazyllm import LOG, once_wrapper from .transform import (NodeTransform, FuncNodeTransform, SentenceSplitter, LLMParser, AdaptiveTransform, make_transform, TransformArgs) -from .store import LAZY_ROOT_NAME -from .store_base import StoreBase +from .store_base import StoreBase, LAZY_ROOT_NAME from .map_store import MapStore from .chroma_store import ChromadbStore +from .milvus_store import MilvusStore from .doc_node import DocNode from .data_loaders import DirectoryReader from .utils import DocListManager @@ -57,6 +57,14 @@ def _lazy_init(self) -> None: self._embed_dim = {k: len(e('a')) for k, e in self.embed.items()} + if self.store is None: + self.store = { + 'type': 'map', + 'kwargs': { + 'node_groups': self.node_groups, + }, + } + if isinstance(self.store, Dict): self.store = self._create_store(self.store) else: diff --git a/lazyllm/tools/rag/document.py b/lazyllm/tools/rag/document.py index f5a01214..346b67d5 100644 --- a/lazyllm/tools/rag/document.py +++ b/lazyllm/tools/rag/document.py @@ -8,7 +8,7 @@ from .doc_manager import DocManager from .doc_impl import DocImpl from .doc_node import DocNode -from .store import LAZY_ROOT_NAME, EMBED_DEFAULT_KEY +from .store_base import LAZY_ROOT_NAME, EMBED_DEFAULT_KEY from .utils import DocListManager import copy import functools diff --git a/lazyllm/tools/rag/milvus_store.py b/lazyllm/tools/rag/milvus_store.py index 60795f3a..d737fdeb 100644 --- a/lazyllm/tools/rag/milvus_store.py +++ b/lazyllm/tools/rag/milvus_store.py @@ -32,9 +32,8 @@ class MilvusStore(StoreBase): pymilvus.DataType.SPARSE_FLOAT_VECTOR, # DTYPE_SPARSE_FLOAT_VECTOR ] - def __init__(self, uri: str, embed: Dict[str, Callable], - # a field is either an embedding key or a metadata key - group_fields: Dict[str, List[MilvusField]]): + def __init__(self, uri: str, group_fields: Dict[str, List[MilvusField]], + embed: Dict[str, Callable]): self._primary_key = 'uid' self._embedding_keys = embed.keys() self._embed = embed @@ -81,7 +80,7 @@ def __init__(self, uri: str, embed: Dict[str, Callable], self._client.create_collection(collection_name=group_name, schema=schema, index_params=index_params) - self._map_store = MapStore(list(group_fields.keys())) + self._map_store = MapStore(node_groups=list(group_fields.keys()), embed=embed) self._load_all_nodes_to(self._map_store) @override From aaa7f67e7357b7d483413c2c8dcc97072f452fac Mon Sep 17 00:00:00 2001 From: ouguoyu Date: Mon, 4 Nov 2024 11:44:49 +0800 Subject: [PATCH 36/60] s --- lazyllm/tools/rag/chroma_store.py | 5 +++-- lazyllm/tools/rag/doc_impl.py | 10 ++++------ lazyllm/tools/rag/map_store.py | 2 +- lazyllm/tools/rag/milvus_store.py | 12 ++++++++++-- 4 files changed, 18 insertions(+), 11 deletions(-) diff --git a/lazyllm/tools/rag/chroma_store.py b/lazyllm/tools/rag/chroma_store.py index f9087cec..84e8b8d3 100644 --- a/lazyllm/tools/rag/chroma_store.py +++ b/lazyllm/tools/rag/chroma_store.py @@ -14,8 +14,9 @@ # ---------------------------------------------------------------------------- # class ChromadbStore(StoreBase): - def __init__(self, node_groups: List[str], path: str, embed_dim: Dict[str, int], - embed: Dict[str, Callable]) -> None: + def __init__(self, path: str, embed_dim: Dict[str, int], + node_groups: List[str], embed: Dict[str, Callable], + **kwargs) -> None: self._map_store = MapStore(node_groups=node_groups, embed=embed) self._db_client = chromadb.PersistentClient(path=path) LOG.success(f"Initialzed chromadb in path: {path}") diff --git a/lazyllm/tools/rag/doc_impl.py b/lazyllm/tools/rag/doc_impl.py index 
f56e7e7c..29011c0f 100644 --- a/lazyllm/tools/rag/doc_impl.py +++ b/lazyllm/tools/rag/doc_impl.py @@ -60,9 +60,7 @@ def _lazy_init(self) -> None: if self.store is None: self.store = { 'type': 'map', - 'kwargs': { - 'node_groups': self.node_groups, - }, + 'kwargs': {}, } if isinstance(self.store, Dict): @@ -93,11 +91,11 @@ def _create_store(self, store_conf: Optional[Dict]) -> StoreBase: raise ValueError('`kwargs` in store conf is not a dict.') if type == "map": - store = MapStore(embed=self.embed, **kwargs) + store = MapStore(embed=self.embed, node_groups=self.node_groups, **kwargs) elif type == "chroma": - store = ChromadbStore(embed=self.embed, **kwargs) + store = ChromadbStore(embed=self.embed, node_groups=self.node_groups, **kwargs) elif type == "milvus": - store = MilvusStore(embed=self.embed, **kwargs) + store = MilvusStore(embed=self.embed, node_groups=self.node_groups, **kwargs) else: raise NotImplementedError( f"Not implemented store type for {type}" diff --git a/lazyllm/tools/rag/map_store.py b/lazyllm/tools/rag/map_store.py index 447721aa..52e2d0b6 100644 --- a/lazyllm/tools/rag/map_store.py +++ b/lazyllm/tools/rag/map_store.py @@ -16,7 +16,7 @@ def _remove_from_indices(name2index: Dict[str, IndexBase], uids: List[str], index.remove(uids, group_name) class MapStore(StoreBase): - def __init__(self, node_groups: List[str], embed: Dict[str, Callable]): + def __init__(self, node_groups: List[str], embed: Dict[str, Callable], **kwargs): # Dict[group_name, Dict[uuid, DocNode]] self._group2docs: Dict[str, Dict[str, DocNode]] = { group: {} for group in node_groups diff --git a/lazyllm/tools/rag/milvus_store.py b/lazyllm/tools/rag/milvus_store.py index d737fdeb..4e119772 100644 --- a/lazyllm/tools/rag/milvus_store.py +++ b/lazyllm/tools/rag/milvus_store.py @@ -33,7 +33,14 @@ class MilvusStore(StoreBase): ] def __init__(self, uri: str, group_fields: Dict[str, List[MilvusField]], - embed: Dict[str, Callable]): + node_groups: List[str], embed: Dict[str, Callable], **kwargs): + new_copy = copy.copy(group_fields) + for g in node_groups: + if g not in new_copy: + new_copy[g] = [] + group_fields = new_copy + print(f'debug!!! uri -> {uri}, group_fields -> {group_fields}') + self._primary_key = 'uid' self._embedding_keys = embed.keys() self._embed = embed @@ -76,7 +83,7 @@ def __init__(self, uri: str, group_fields: Dict[str, List[MilvusField]], metric_type=field.metric_type, params=field.index_params) - schema = CollectionSchema(fields=field_schema_list) + schema = CollectionSchema(fields=field_schema_list, enable_dynamic_field=True) self._client.create_collection(collection_name=group_name, schema=schema, index_params=index_params) @@ -88,6 +95,7 @@ def update_nodes(self, nodes: List[DocNode]) -> None: parallel_do_embedding(self._embed, nodes) for node in nodes: data = self._serialize_node_partial(node) + print(f'debug!!! 
update group [{node.group}]') self._client.upsert(collection_name=node.group, data=[data]) self._map_store.update_nodes(nodes) From bfac3a8d88eda7f0b3ae357f87b95f07e207e6ca Mon Sep 17 00:00:00 2001 From: ouguoyu Date: Mon, 4 Nov 2024 11:45:03 +0800 Subject: [PATCH 37/60] s --- lazyllm/tools/rag/milvus_store.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/lazyllm/tools/rag/milvus_store.py b/lazyllm/tools/rag/milvus_store.py index 4e119772..1cbfaca6 100644 --- a/lazyllm/tools/rag/milvus_store.py +++ b/lazyllm/tools/rag/milvus_store.py @@ -39,7 +39,6 @@ def __init__(self, uri: str, group_fields: Dict[str, List[MilvusField]], if g not in new_copy: new_copy[g] = [] group_fields = new_copy - print(f'debug!!! uri -> {uri}, group_fields -> {group_fields}') self._primary_key = 'uid' self._embedding_keys = embed.keys() @@ -95,7 +94,6 @@ def update_nodes(self, nodes: List[DocNode]) -> None: parallel_do_embedding(self._embed, nodes) for node in nodes: data = self._serialize_node_partial(node) - print(f'debug!!! update group [{node.group}]') self._client.upsert(collection_name=node.group, data=[data]) self._map_store.update_nodes(nodes) From 97ac0f0cb27bf76a2ff999acd86c3680b9caff61 Mon Sep 17 00:00:00 2001 From: ouguoyu Date: Mon, 4 Nov 2024 12:27:54 +0800 Subject: [PATCH 38/60] s --- lazyllm/tools/rag/doc_impl.py | 25 +++++++++++++++------- lazyllm/tools/rag/smart_embedding_index.py | 6 +++--- 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/lazyllm/tools/rag/doc_impl.py b/lazyllm/tools/rag/doc_impl.py index 29011c0f..12cf8287 100644 --- a/lazyllm/tools/rag/doc_impl.py +++ b/lazyllm/tools/rag/doc_impl.py @@ -9,6 +9,7 @@ from .map_store import MapStore from .chroma_store import ChromadbStore from .milvus_store import MilvusStore +from .smart_embedding_index import SmartEmbeddingIndex from .doc_node import DocNode from .data_loaders import DirectoryReader from .utils import DocListManager @@ -60,7 +61,6 @@ def _lazy_init(self) -> None: if self.store is None: self.store = { 'type': 'map', - 'kwargs': {}, } if isinstance(self.store, Dict): @@ -82,25 +82,34 @@ def _lazy_init(self) -> None: self._daemon.start() def _create_store(self, store_conf: Optional[Dict]) -> StoreBase: - type = store_conf.get('type') - if not type: + store_type = store_conf.get('type') + if not store_type: raise ValueError('store type is not specified.') - kwargs = store_conf.get('kwargs') + kwargs = store_conf.get('kwargs', {}) if not isinstance(kwargs, Dict): raise ValueError('`kwargs` in store conf is not a dict.') - if type == "map": + if store_type == "map": store = MapStore(embed=self.embed, node_groups=self.node_groups, **kwargs) - elif type == "chroma": + elif store_type == "chroma": store = ChromadbStore(embed=self.embed, node_groups=self.node_groups, **kwargs) - elif type == "milvus": + elif store_type == "milvus": store = MilvusStore(embed=self.embed, node_groups=self.node_groups, **kwargs) else: raise NotImplementedError( - f"Not implemented store type for {type}" + f"Not implemented store type for {store_type}" ) + indices_conf = store_conf.get('indices', {}) + if not isinstance(indices_conf, Dict): + raise ValueError(f"`indices`'s type [{type(indices_conf)}] is not a dict") + + for backend_type, kwargs in indices_conf.items(): + index = SmartEmbeddingIndex(backend_type=backend_type, embed=self.embed, + node_groups=self.node_groups, **kwargs) + store.register_index(type=backend_type, index=index) + return store @staticmethod diff --git a/lazyllm/tools/rag/smart_embedding_index.py 
b/lazyllm/tools/rag/smart_embedding_index.py
index c975991c..90dc6b0c 100644
--- a/lazyllm/tools/rag/smart_embedding_index.py
+++ b/lazyllm/tools/rag/smart_embedding_index.py
@@ -6,11 +6,11 @@
 from .milvus_store import MilvusStore
 
 class SmartEmbeddingIndex(IndexBase):
-    def __init__(self, backend_type: str, fields: List[str], *args, **kwargs):
+    def __init__(self, backend_type: str, **kwargs):
         if backend_type == 'milvus':
-            self._store = MilvusStore(*args, **kwargs)
+            self._store = MilvusStore(**kwargs)
         elif backend_type == 'map':
-            self._store = MapStore(*args, **kwargs)
+            self._store = MapStore(**kwargs)
         else:
             raise ValueError(f'unsupported backend [{backend_type}]')

From 0e9f11d738235446f64513e6f213c4b4e48bc519 Mon Sep 17 00:00:00 2001
From: ouguoyu
Date: Mon, 4 Nov 2024 14:51:14 +0800
Subject: [PATCH 39/60] s

---
 lazyllm/tools/rag/milvus_store.py | 67 ++++++++++++++++++-------------
 tests/basic_tests/test_store.py   | 45 +++++++++++----------
 2 files changed, 62 insertions(+), 50 deletions(-)

diff --git a/lazyllm/tools/rag/milvus_store.py b/lazyllm/tools/rag/milvus_store.py
index 1cbfaca6..0ce3c277 100644
--- a/lazyllm/tools/rag/milvus_store.py
+++ b/lazyllm/tools/rag/milvus_store.py
@@ -14,10 +14,9 @@ class MilvusField:
     DTYPE_FLOAT_VECTOR = 1
     DTYPE_SPARSE_FLOAT_VECTOR = 2
 
-    def __init__(self, name: str, data_type: int, index_type: Optional[str] = None,
+    def __init__(self, data_type: int, index_type: Optional[str] = None,
                  metric_type: Optional[str] = "", index_params: Dict = {},
                  max_length: Optional[int] = None):
-        self.name = name
         self.data_type = data_type
         self.index_type = index_type
         self.metric_type = metric_type
@@ -32,12 +31,12 @@ class MilvusStore(StoreBase):
         pymilvus.DataType.SPARSE_FLOAT_VECTOR,  # DTYPE_SPARSE_FLOAT_VECTOR
     ]
 
-    def __init__(self, uri: str, group_fields: Dict[str, List[MilvusField]],
+    def __init__(self, uri: str, group_fields: Dict[str, Dict[str, MilvusField]],
                  node_groups: List[str], embed: Dict[str, Callable], **kwargs):
         new_copy = copy.copy(group_fields)
         for g in node_groups:
             if g not in new_copy:
-                new_copy[g] = []
+                new_copy[g] = {}
         group_fields = new_copy
 
         self._primary_key = 'uid'
@@ -46,43 +45,53 @@ def __init__(self, uri: str, group_fields: Dict[str, List[MilvusField]],
         self._client = MilvusClient(uri=uri)
 
         embed_dim = {k: len(e('a')) for k, e in embed.items()}
-        builtin_fields = [
-            FieldSchema(name=self._primary_key, dtype=pymilvus.DataType.VARCHAR,
-                        max_length=128, is_primary=True),
-            FieldSchema(name='text', dtype=pymilvus.DataType.VARCHAR,
-                        max_length=65535),
-            FieldSchema(name='parent', dtype=pymilvus.DataType.VARCHAR,
-                        max_length=256),
-        ]
-
-        for group_name, field_list in group_fields.items():
+        builtin_fields = {
+            self._primary_key: {
+                'datatype': pymilvus.DataType.VARCHAR,
+                'max_length': 128,
+                'is_primary': True,
+            },
+            'text': {
+                'datatype': pymilvus.DataType.VARCHAR,
+                'max_length': 65535,
+            },
+            'parent': {
+                'datatype': pymilvus.DataType.VARCHAR,
+                'max_length': 256,
+            },
+        }
+
+        for group_name, fields in group_fields.items():
             if group_name in self._client.list_collections():
                 continue
 
             index_params = self._client.prepare_index_params()
-            field_schema_list = copy.copy(builtin_fields)
-
-            for field in field_list:
-                field_schema = None
-                if field.name in self._embedding_keys:
-                    field_schema = FieldSchema(
-                        name=self._gen_embedding_key(field.name),
-                        dtype=self._type2milvus[field.data_type],
-                        dim=embed_dim.get(field.name))
+            schema = self._client.create_schema(auto_id=False, enable_dynamic_field=True)
+
+            for name, field in 
builtin_fields.items(): + schema.add_field(field_name=name, **field) + + for name, field in fields.items(): + field_name = None + if name in self._embedding_keys: + field_name = self._gen_embedding_key(name) + schema.add_field( + field_name=field_name, + datatype=self._type2milvus[field.data_type], + dim=embed_dim.get(name)) else: - field_schema = FieldSchema( - name=self._gen_metadata_key(field.name), - dtype=self._type2milvus[field.data_type], + field_name = self._gen_metadata_key(name) + schema.add_field( + field_name=field_name, + datatype=self._type2milvus[field.data_type], max_length=field.max_length) - field_schema_list.append(field_schema) if field.index_type is not None: - index_params.add_index(field_name=field_schema.name, + index_params.add_index(field_name=field_name, index_type=field.index_type, metric_type=field.metric_type, params=field.index_params) - schema = CollectionSchema(fields=field_schema_list, enable_dynamic_field=True) self._client.create_collection(collection_name=group_name, schema=schema, index_params=index_params) diff --git a/tests/basic_tests/test_store.py b/tests/basic_tests/test_store.py index 20879c2c..15e8950f 100644 --- a/tests/basic_tests/test_store.py +++ b/tests/basic_tests/test_store.py @@ -4,7 +4,7 @@ import unittest from unittest.mock import MagicMock import lazyllm -from lazyllm.tools.rag.store import LAZY_ROOT_NAME +from lazyllm.tools.rag.store_base import LAZY_ROOT_NAME from lazyllm.tools.rag.map_store import MapStore from lazyllm.tools.rag.chroma_store import ChromadbStore from lazyllm.tools.rag.milvus_store import MilvusStore, MilvusField @@ -30,14 +30,18 @@ class TestChromadbStore(unittest.TestCase): def setUp(self): self.node_groups = [LAZY_ROOT_NAME, "group1", "group2"] self.embed_dim = {"default": 3} - self.store = ChromadbStore(self.node_groups, self.embed_dim) + self.store_dir = tempfile.mkdtemp() + self.mock_embed = { + 'default': MagicMock(return_value=[1.0, 2.0, 3.0]), + } + self.store = ChromadbStore(path=self.store_dir, node_groups=self.node_groups, + embed=self.mock_embed, embed_dim=self.embed_dim) self.store.update_nodes( [DocNode(uid="1", text="text1", group=LAZY_ROOT_NAME, parent=None)], ) - @classmethod - def tearDownClass(cls): - clear_directory(lazyllm.config['rag_persistent_path']) + def tearDown(self): + clear_directory(self.store_dir) def test_initialization(self): self.assertEqual(set(self.store._collections.keys()), set(self.node_groups)) @@ -108,7 +112,7 @@ def test_group_others(self): class TestMapStore(unittest.TestCase): def setUp(self): self.node_groups = [LAZY_ROOT_NAME, "group1", "group2"] - self.store = MapStore(self.node_groups) + self.store = MapStore(node_groups=self.node_groups, embed={}) self.node1 = DocNode(uid="1", text="text1", group="group1", parent=None) self.node2 = DocNode(uid="2", text="text2", group="group1", parent=self.node1) @@ -155,16 +159,16 @@ def test_group_others(self): class TestMilvusStore(unittest.TestCase): def setUp(self): - field_list = [ - MilvusField(name="comment", data_type=MilvusField.DTYPE_VARCHAR, max_length=128), - MilvusField(name="vec1", data_type=MilvusField.DTYPE_FLOAT_VECTOR, - index_type='HNSW', metric_type='COSINE'), - MilvusField(name="vec2", data_type=MilvusField.DTYPE_FLOAT_VECTOR, - index_type='HNSW', metric_type='COSINE'), - ] + fields = { + 'comment': MilvusField(data_type=MilvusField.DTYPE_VARCHAR, max_length=128), + 'vec1': MilvusField(data_type=MilvusField.DTYPE_FLOAT_VECTOR, + index_type='HNSW', metric_type='COSINE'), + 'vec2': 
MilvusField(data_type=MilvusField.DTYPE_FLOAT_VECTOR,
+                                index_type='HNSW', metric_type='COSINE'),
+        }
         group_fields = {
-            "group1": field_list,
-            "group2": field_list,
+            "group1": fields,
+            "group2": fields,
         }
 
         self.mock_embed = {
@@ -176,8 +180,7 @@ def setUp(self):
         _, self.store_file = tempfile.mkstemp(suffix=".db")
 
         self.store = MilvusStore(uri=self.store_file, embed=self.mock_embed,
-                                 group_fields=group_fields)
-        self.index = self.store.get_index()
+                                 node_groups=self.node_groups, group_fields=group_fields)
 
         self.node1 = DocNode(uid="1", text="text1", group="group1", parent=None,
                              embedding={"vec1": [8.0, 9.0, 10.0], "vec2": [11.0, 12.0, 13.0, 14.0, 15.0]},
@@ -191,22 +194,22 @@ def tearDown(self):
 
     def test_update_and_query(self):
         self.store.update_nodes([self.node1])
-        ret = self.index.query(query='text1', group_name='group1', embed_keys=['vec2'], topk=1)
+        ret = self.store.query(query='text1', group_name='group1', embed_keys=['vec2'], topk=1)
         self.assertEqual(len(ret), 1)
         self.assertEqual(ret[0].uid, self.node1.uid)
 
         self.store.update_nodes([self.node2])
-        ret = self.index.query(query='text2', group_name='group1', embed_keys=['vec2'], topk=1)
+        ret = self.store.query(query='text2', group_name='group1', embed_keys=['vec2'], topk=1)
         self.assertEqual(len(ret), 1)
         self.assertEqual(ret[0].uid, self.node2.uid)
 
     def test_remove_and_query(self):
         self.store.update_nodes([self.node1, self.node2])
-        ret = self.index.query(query='test', group_name='group1', embed_keys=['vec2'], topk=1)
+        ret = self.store.query(query='test', group_name='group1', embed_keys=['vec2'], topk=1)
         self.assertEqual(len(ret), 1)
         self.assertEqual(ret[0].uid, self.node2.uid)
 
         self.store.remove_nodes("group1", [self.node2.uid])
-        ret = self.index.query(query='test', group_name='group1', embed_keys=['vec2'], topk=1)
+        ret = self.store.query(query='test', group_name='group1', embed_keys=['vec2'], topk=1)
         self.assertEqual(len(ret), 1)
         self.assertEqual(ret[0].uid, self.node1.uid)

From 5b1f8c65e5bc2302d59cf74defc2b35096d275c3 Mon Sep 17 00:00:00 2001
From: ouguoyu
Date: Mon, 4 Nov 2024 14:53:26 +0800
Subject: [PATCH 40/60] s

---
 examples/rag_map_store_with_milvus_index.py | 66 ++++++++++++++++++++
 examples/rag_milvus_store.py                | 67 +++++++++++++++++++++
 2 files changed, 133 insertions(+)
 create mode 100644 examples/rag_map_store_with_milvus_index.py
 create mode 100644 examples/rag_milvus_store.py

diff --git a/examples/rag_map_store_with_milvus_index.py b/examples/rag_map_store_with_milvus_index.py
new file mode 100644
index 00000000..61832f08
--- /dev/null
+++ b/examples/rag_map_store_with_milvus_index.py
@@ -0,0 +1,66 @@
+# -*- coding: utf-8 -*-
+
+import os
+import lazyllm
+import tempfile
+from lazyllm.tools.rag import MilvusField, EMBED_DEFAULT_KEY
+
+_, store_file = tempfile.mkstemp(suffix=".db")
+
+fields = {
+    EMBED_DEFAULT_KEY: MilvusField(data_type=MilvusField.DTYPE_FLOAT_VECTOR,
+                                   index_type='HNSW', metric_type='COSINE'),
+}
+
+milvus_store_conf = {
+    'type': 'map',
+    'indices': {
+        'milvus': {
+            'uri': store_file,
+            'group_fields': {
+                'sentences': fields,
+            },
+        },
+    },
+}
+
+documents = lazyllm.Document(dataset_path="rag_master",
+                             embed=lazyllm.TrainableModule("bge-large-zh-v1.5"),
+                             manager=False,
+                             store_conf=milvus_store_conf)
+
+documents.create_node_group(name="sentences",
+                            transform=lambda s: s.split('。'))
+
+prompt = 'You will play the role of an AI Q&A assistant and complete a dialogue task. In this task, you need to provide your answer based on the given context and question.'
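+
+# A note on the configuration above (a sketch of the intended flow, based on
+# DocImpl._create_store earlier in this series): the 'map' store keeps document
+# nodes in memory, while the 'milvus' entry under 'indices' registers a
+# Milvus-backed embedding index that serves the vector retrieval below.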
+
+with lazyllm.pipeline() as ppl:
+    with lazyllm.parallel().sum as ppl.prl:
+        prl.retriever1 = lazyllm.Retriever(doc=documents,
+                                           group_name="CoarseChunk",
+                                           similarity="bm25_chinese",
+                                           topk=3)
+        prl.retriever2 = lazyllm.Retriever(doc=documents,
+                                           group_name="sentences",
+                                           similarity="cosine",
+                                           topk=3)
+
+    ppl.reranker = lazyllm.Reranker(name='ModuleReranker',
+                                    model="bge-reranker-large",
+                                    topk=1,
+                                    output_format='content',
+                                    join=True) | bind(query=ppl.input)
+
+    ppl.formatter = (
+        lambda nodes, query: dict(context_str=nodes, query=query)
+    ) | bind(query=ppl.input)
+
+    ppl.llm = lazyllm.TrainableModule('internlm2-chat-7b').prompt(
+        lazyllm.ChatPrompter(instruction=prompt, extro_keys=['context_str']))
+
+rag = lazyllm.ActionModule(ppl)
+rag.start()
+
+print("answer: ", rag('who are you?'))
+
+os.remove(store_file)
diff --git a/examples/rag_milvus_store.py b/examples/rag_milvus_store.py
new file mode 100644
index 00000000..bf6a3f1d
--- /dev/null
+++ b/examples/rag_milvus_store.py
@@ -0,0 +1,67 @@
+# -*- coding: utf-8 -*-
+
+import os
+import lazyllm
+import tempfile
+from lazyllm.tools.rag import MilvusField, EMBED_DEFAULT_KEY
+
+_, store_file = tempfile.mkstemp(suffix=".db")
+
+fields = {
+    EMBED_DEFAULT_KEY: MilvusField(data_type=MilvusField.DTYPE_FLOAT_VECTOR,
+                                   index_type='HNSW', metric_type='COSINE'),
+}
+
+milvus_store_conf = {
+    'type': 'milvus',
+    'kwargs': {
+        'uri': store_file,
+        'group_fields': {
+            'sentences': fields,
+        },
+    },
+    'indices': {
+        'map': {}
+    },
+}
+
+documents = lazyllm.Document(dataset_path="rag_master",
+                             embed=lazyllm.TrainableModule("bge-large-zh-v1.5"),
+                             manager=False,
+                             store_conf=milvus_store_conf)
+
+documents.create_node_group(name="sentences",
+                            transform=lambda s: s.split('。'))
+
+prompt = 'You will play the role of an AI Q&A assistant and complete a dialogue task. In this task, you need to provide your answer based on the given context and question.'
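+
+# Unlike rag_map_store_with_milvus_index.py, here Milvus itself is the primary
+# store ('type': 'milvus'), so nodes are persisted in Milvus, and the 'map'
+# entry under 'indices' layers an in-memory index on top (again a sketch based
+# on DocImpl._create_store earlier in this series).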
+ +with lazyllm.pipeline() as ppl: + with lazyllm.parallel().sum as ppl.prl: + prl.retriever1 = lazyllm.Retriever(doc=documents, + group_name="CoarseChunk", + similarity="bm25_chinese", + topk=3) + prl.retriever2 = lazyllm.Retriever(doc=documents, + group_name="sentences", + similarity="cosine", + topk=3) + + ppl.reranker = lazyllm.Reranker(name='ModuleReranker', + model="bge-reranker-large", + topk=1, + output_format='content', + join=True) | bind(query=ppl.input) + + ppl.formatter = ( + lambda nodes, query: dict(context_str=nodes, query=query) + ) | bind(query=ppl.input) + + ppl.llm = lazyllm.TrainableModule('internlm2-chat-7b').prompt( + lazyllm.ChatPrompter(instruction=prompt, extro_keys=['context_str'])) + +rag = lazyllm.ActionModule(ppl) +rag.start() + +print("answer: ", rag('who are you?')) + +os.remove(store_file) From 073150702650f1235036a8934836d10a58bc4641 Mon Sep 17 00:00:00 2001 From: ouguoyu Date: Mon, 4 Nov 2024 15:18:05 +0800 Subject: [PATCH 41/60] s --- tests/basic_tests/test_store.py | 34 ++++++++++++++++++++++++++++++--- 1 file changed, 31 insertions(+), 3 deletions(-) diff --git a/tests/basic_tests/test_store.py b/tests/basic_tests/test_store.py index 15e8950f..da43ef23 100644 --- a/tests/basic_tests/test_store.py +++ b/tests/basic_tests/test_store.py @@ -99,9 +99,18 @@ def test_insert_dict_as_sparse_embedding(self): for uid, node in nodes_dict.items(): assert node.embedding['default'] == orig_embedding_dict.get(uid) - def test_group_names(self): + def test_all_groups(self): self.assertEqual(set(self.store.all_groups()), set(self.node_groups)) + def test_query(self): + node1 = DocNode(uid="1", text="text1", group="group1", parent=None) + node2 = DocNode(uid="2", text="text2", group="group1", parent=node1) + self.store.update_nodes([node1, node2]) + index = self.store.get_index() + res = self.store.query(query='text1', group_name='group1', embed_keys=['default'], topk=2, + similarity_name='cosine', similarity_cut_off=0.000001) + self.assertEqual(set([node1, node2]), set(res)) + def test_group_others(self): node1 = DocNode(uid="1", text="text1", group="group1", parent=None) node2 = DocNode(uid="2", text="text2", group="group1", parent=node1) @@ -111,8 +120,12 @@ def test_group_others(self): class TestMapStore(unittest.TestCase): def setUp(self): + self.mock_embed = { + 'default': MagicMock(return_value=[1.0, 2.0, 3.0]), + } + self.embed_dim = {"default": 3} self.node_groups = [LAZY_ROOT_NAME, "group1", "group2"] - self.store = MapStore(node_groups=self.node_groups, embed={}) + self.store = MapStore(node_groups=self.node_groups, embed=self.mock_embed) self.node1 = DocNode(uid="1", text="text1", group="group1", parent=None) self.node2 = DocNode(uid="2", text="text2", group="group1", parent=self.node1) @@ -149,9 +162,16 @@ def test_remove_group_nodes(self): n2 = self.store.get_nodes("group1", ["2"]) assert not n2 - def test_group_names(self): + def test_all_groups(self): self.assertEqual(set(self.store.all_groups()), set(self.node_groups)) + def test_query(self): + self.store.update_nodes([self.node1, self.node2]) + index = self.store.get_index() + res = self.store.query(query='text1', group_name='group1', embed_keys=['default'], topk=2, + similarity_name='cosine', similarity_cut_off=0.000001) + self.assertEqual(set([self.node1, self.node2]), set(res)) + def test_group_others(self): self.store.update_nodes([self.node1, self.node2]) self.assertEqual(self.store.is_group_active("group1"), True) @@ -213,3 +233,11 @@ def test_remove_and_query(self): ret = 
self.store.query(query='test', group_name='group1', embed_keys=['vec2'], topk=1)
         self.assertEqual(len(ret), 1)
         self.assertEqual(ret[0].uid, self.node1.uid)
+
+    def test_all_groups(self):
+        self.assertEqual(set(self.store.all_groups()), set(self.node_groups))
+
+    def test_group_others(self):
+        self.store.update_nodes([self.node1, self.node2])
+        self.assertEqual(self.store.is_group_active("group1"), True)
+        self.assertEqual(self.store.is_group_active("group2"), False)

From 7e2aab2269c2d45f7fe41d575fd131e97762364a Mon Sep 17 00:00:00 2001
From: ouguoyu
Date: Mon, 4 Nov 2024 15:20:12 +0800
Subject: [PATCH 42/60] s

---
 lazyllm/tools/rag/__init__.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/lazyllm/tools/rag/__init__.py b/lazyllm/tools/rag/__init__.py
index dfd5da76..62531b73 100644
--- a/lazyllm/tools/rag/__init__.py
+++ b/lazyllm/tools/rag/__init__.py
@@ -8,7 +8,6 @@
     MarkdownReader, MboxReader, PandasCSVReader, PandasExcelReader, VideoAudioReader)
 from .dataReader import SimpleDirectoryReader
 from .doc_manager import DocManager, DocListManager
-from .store_base import EMBED_DEFAULT_KEY
 from .milvus_store import MilvusField
 
 
@@ -40,5 +39,4 @@
     'DocManager',
     'DocListManager',
     'MilvusField',
-    'EMBED_DEFAULT_KEY',
 ]

From 91f8d4d41fabc6cecd0c87b55d62ab63b23ed087 Mon Sep 17 00:00:00 2001
From: ouguoyu
Date: Mon, 4 Nov 2024 15:21:20 +0800
Subject: [PATCH 43/60] s

---
 lazyllm/tools/rag/__init__.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/lazyllm/tools/rag/__init__.py b/lazyllm/tools/rag/__init__.py
index 62531b73..dfd5da76 100644
--- a/lazyllm/tools/rag/__init__.py
+++ b/lazyllm/tools/rag/__init__.py
@@ -8,6 +8,7 @@
     MarkdownReader, MboxReader, PandasCSVReader, PandasExcelReader, VideoAudioReader)
 from .dataReader import SimpleDirectoryReader
 from .doc_manager import DocManager, DocListManager
+from .store_base import EMBED_DEFAULT_KEY
 from .milvus_store import MilvusField
 
 
@@ -39,4 +40,5 @@
     'DocManager',
     'DocListManager',
     'MilvusField',
+    'EMBED_DEFAULT_KEY',
 ]

From 947ee9273c36e25357390358d7d0ebd291764cdb Mon Sep 17 00:00:00 2001
From: ouguoyu
Date: Mon, 4 Nov 2024 15:27:42 +0800
Subject: [PATCH 44/60] s

---
 examples/rag_map_store_with_milvus_index.py | 13 +++++++------
 examples/rag_milvus_store.py                | 11 ++++++-----
 lazyllm/tools/rag/milvus_store.py           |  2 +-
 tests/basic_tests/test_store.py             |  3 ---
 4 files changed, 14 insertions(+), 15 deletions(-)

diff --git a/examples/rag_map_store_with_milvus_index.py b/examples/rag_map_store_with_milvus_index.py
index 61832f08..4e5c1c10 100644
--- a/examples/rag_map_store_with_milvus_index.py
+++ b/examples/rag_map_store_with_milvus_index.py
@@ -32,15 +32,16 @@
 documents.create_node_group(name="sentences",
                             transform=lambda s: s.split('。'))
 
-prompt = 'You will play the role of an AI Q&A assistant and complete a dialogue task. In this task, you need to provide your answer based on the given context and question.'
+prompt = 'You will play the role of an AI Q&A assistant and complete a dialogue task.'\
+         ' In this task, you need to provide your answer based on the given context and question.'
with lazyllm.pipeline() as ppl: - with lazyllm.parallel().sum as ppl.prl: - prl.retriever1 = lazyllm.Retriever(doc=documents, + with lazyllm.parallel().sum as ppl.prl: # noqa F821 + prl.retriever1 = lazyllm.Retriever(doc=documents, # noqa F821 group_name="CoarseChunk", similarity="bm25_chinese", topk=3) - prl.retriever2 = lazyllm.Retriever(doc=documents, + prl.retriever2 = lazyllm.Retriever(doc=documents, # noqa F821 group_name="sentences", similarity="cosine", topk=3) @@ -49,11 +50,11 @@ model="bge-reranker-large", topk=1, output_format='content', - join=True) | bind(query=ppl.input) + join=True) | bind(query=ppl.input) # noqa F821 ppl.formatter = ( lambda nodes, query: dict(context_str=nodes, query=query) - ) | bind(query=ppl.input) + ) | bind(query=ppl.input) # noqa F821 ppl.llm = lazyllm.TrainableModule('internlm2-chat-7b').prompt( lazyllm.ChatPrompter(instruction=prompt, extro_keys=['context_str'])) diff --git a/examples/rag_milvus_store.py b/examples/rag_milvus_store.py index bf6a3f1d..fa70f0f6 100644 --- a/examples/rag_milvus_store.py +++ b/examples/rag_milvus_store.py @@ -33,15 +33,16 @@ documents.create_node_group(name="sentences", transform=lambda s: '。'.split(s)) -prompt = 'You will play the role of an AI Q&A assistant and complete a dialogue task. In this task, you need to provide your answer based on the given context and question.' +prompt = 'You will play the role of an AI Q&A assistant and complete a dialogue task.'\ + ' In this task, you need to provide your answer based on the given context and question.' with lazyllm.pipeline() as ppl: with lazyllm.parallel().sum as ppl.prl: - prl.retriever1 = lazyllm.Retriever(doc=documents, + prl.retriever1 = lazyllm.Retriever(doc=documents, # noqa F821 group_name="CoarseChunk", similarity="bm25_chinese", topk=3) - prl.retriever2 = lazyllm.Retriever(doc=documents, + prl.retriever2 = lazyllm.Retriever(doc=documents, # noqa F821 group_name="sentences", similarity="cosine", topk=3) @@ -50,11 +51,11 @@ model="bge-reranker-large", topk=1, output_format='content', - join=True) | bind(query=ppl.input) + join=True) | bind(query=ppl.input) # noqa F821 ppl.formatter = ( lambda nodes, query: dict(context_str=nodes, query=query) - ) | bind(query=ppl.input) + ) | bind(query=ppl.input) # noqa F821 ppl.llm = lazyllm.TrainableModule('internlm2-chat-7b').prompt( lazyllm.ChatPrompter(instruction=prompt, extro_keys=['context_str'])) diff --git a/lazyllm/tools/rag/milvus_store.py b/lazyllm/tools/rag/milvus_store.py index 0ce3c277..d9dfb9f6 100644 --- a/lazyllm/tools/rag/milvus_store.py +++ b/lazyllm/tools/rag/milvus_store.py @@ -1,7 +1,7 @@ import copy from typing import Dict, List, Optional, Union, Callable import pymilvus -from pymilvus import MilvusClient, FieldSchema, CollectionSchema +from pymilvus import MilvusClient from .doc_node import DocNode from .map_store import MapStore from .embed_utils import parallel_do_embedding diff --git a/tests/basic_tests/test_store.py b/tests/basic_tests/test_store.py index da43ef23..0f65edc4 100644 --- a/tests/basic_tests/test_store.py +++ b/tests/basic_tests/test_store.py @@ -3,7 +3,6 @@ import tempfile import unittest from unittest.mock import MagicMock -import lazyllm from lazyllm.tools.rag.store_base import LAZY_ROOT_NAME from lazyllm.tools.rag.map_store import MapStore from lazyllm.tools.rag.chroma_store import ChromadbStore @@ -106,7 +105,6 @@ def test_query(self): node1 = DocNode(uid="1", text="text1", group="group1", parent=None) node2 = DocNode(uid="2", text="text2", group="group1", parent=node1) 
self.store.update_nodes([node1, node2]) - index = self.store.get_index() res = self.store.query(query='text1', group_name='group1', embed_keys=['default'], topk=2, similarity_name='cosine', similarity_cut_off=0.000001) self.assertEqual(set([node1, node2]), set(res)) @@ -167,7 +165,6 @@ def test_all_groups(self): def test_query(self): self.store.update_nodes([self.node1, self.node2]) - index = self.store.get_index() res = self.store.query(query='text1', group_name='group1', embed_keys=['default'], topk=2, similarity_name='cosine', similarity_cut_off=0.000001) self.assertEqual(set([self.node1, self.node2]), set(res)) From 0edd72548e841f89145d7e0adf919c8bdb130546 Mon Sep 17 00:00:00 2001 From: ouguoyu Date: Mon, 4 Nov 2024 15:47:40 +0800 Subject: [PATCH 45/60] s --- lazyllm/tools/rag/default_index.py | 2 +- lazyllm/tools/rag/embed_utils.py | 30 ------------------------------ lazyllm/tools/rag/milvus_store.py | 2 +- lazyllm/tools/rag/utils.py | 28 ++++++++++++++++++++++++++++ 4 files changed, 30 insertions(+), 32 deletions(-) diff --git a/lazyllm/tools/rag/default_index.py b/lazyllm/tools/rag/default_index.py index 225983ea..fcd4fbfa 100644 --- a/lazyllm/tools/rag/default_index.py +++ b/lazyllm/tools/rag/default_index.py @@ -6,7 +6,7 @@ from .component.bm25 import BM25 from lazyllm import LOG from lazyllm.common import override -from .embed_utils import parallel_do_embedding +from .utils import parallel_do_embedding # ---------------------------------------------------------------------------- # diff --git a/lazyllm/tools/rag/embed_utils.py b/lazyllm/tools/rag/embed_utils.py index 008accbb..f0a6250d 100644 --- a/lazyllm/tools/rag/embed_utils.py +++ b/lazyllm/tools/rag/embed_utils.py @@ -3,33 +3,3 @@ from typing import Dict, Callable, List from lazyllm import config, ThreadPoolExecutor from .doc_node import DocNode - -# min(32, (os.cpu_count() or 1) + 4) is the default number of workers for ThreadPoolExecutor -config.add( - "max_embedding_workers", - int, - min(32, (os.cpu_count() or 1) + 4), - "MAX_EMBEDDING_WORKERS", -) - -# returns a list of modified nodes -def parallel_do_embedding(embed: Dict[str, Callable], nodes: List[DocNode]) -> List[DocNode]: - modified_nodes = [] - with ThreadPoolExecutor(config["max_embedding_workers"]) as executor: - futures = [] - for node in nodes: - miss_keys = node.has_missing_embedding(embed.keys()) - if not miss_keys: - continue - modified_nodes.append(node) - for k in miss_keys: - with node._lock: - if node.has_missing_embedding(k): - future = executor.submit(node.do_embedding, {k: embed[k]}) \ - if k not in node._embedding_state else executor.submit(node.check_embedding_state, k) - node._embedding_state.add(k) - futures.append(future) - if len(futures) > 0: - for future in concurrent.futures.as_completed(futures): - future.result() - return modified_nodes diff --git a/lazyllm/tools/rag/milvus_store.py b/lazyllm/tools/rag/milvus_store.py index d9dfb9f6..50a9dfc5 100644 --- a/lazyllm/tools/rag/milvus_store.py +++ b/lazyllm/tools/rag/milvus_store.py @@ -4,7 +4,7 @@ from pymilvus import MilvusClient from .doc_node import DocNode from .map_store import MapStore -from .embed_utils import parallel_do_embedding +from .utils import parallel_do_embedding from .index_base import IndexBase from .store_base import StoreBase from lazyllm.common import override diff --git a/lazyllm/tools/rag/utils.py b/lazyllm/tools/rag/utils.py index 2a8d3079..15c80727 100644 --- a/lazyllm/tools/rag/utils.py +++ b/lazyllm/tools/rag/utils.py @@ -19,6 +19,13 @@ import lazyllm from lazyllm 
import config +# min(32, (os.cpu_count() or 1) + 4) is the default number of workers for ThreadPoolExecutor +config.add( + "max_embedding_workers", + int, + min(32, (os.cpu_count() or 1) + 4), + "MAX_EMBEDDING_WORKERS", +) config.add("default_dlmanager", str, "sqlite", "DEFAULT_DOCLIST_MANAGER") @@ -477,6 +484,27 @@ def save_files_in_threads( shutil.rmtree(cache_dir) return (already_exist_files, new_add_files, overwritten_files) +# returns a list of modified nodes +def parallel_do_embedding(embed: Dict[str, Callable], nodes: List[DocNode]) -> List[DocNode]: + modified_nodes = [] + with ThreadPoolExecutor(config["max_embedding_workers"]) as executor: + futures = [] + for node in nodes: + miss_keys = node.has_missing_embedding(embed.keys()) + if not miss_keys: + continue + modified_nodes.append(node) + for k in miss_keys: + with node._lock: + if node.has_missing_embedding(k): + future = executor.submit(node.do_embedding, {k: embed[k]}) \ + if k not in node._embedding_state else executor.submit(node.check_embedding_state, k) + node._embedding_state.add(k) + futures.append(future) + if len(futures) > 0: + for future in concurrent.futures.as_completed(futures): + future.result() + return modified_nodes class _FileNodeIndex(IndexBase): def __init__(self): From a6b1abbd632e75e04ecb4e5827ab0266ae945b4f Mon Sep 17 00:00:00 2001 From: ouguoyu Date: Mon, 4 Nov 2024 15:47:57 +0800 Subject: [PATCH 46/60] s --- tests/basic_tests/test_index.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/basic_tests/test_index.py b/tests/basic_tests/test_index.py index 3cd29447..526224ac 100644 --- a/tests/basic_tests/test_index.py +++ b/tests/basic_tests/test_index.py @@ -4,7 +4,7 @@ from lazyllm.tools.rag.map_store import MapStore from lazyllm.tools.rag.doc_node import DocNode from lazyllm.tools.rag.default_index import DefaultIndex, register_similarity -from lazyllm.tools.rag.embed_utils import parallel_do_embedding +from lazyllm.tools.rag.utils import parallel_do_embedding class TestDefaultIndex(unittest.TestCase): def setUp(self): From cb048081d5acd79450590c6aa9603c804af0f67f Mon Sep 17 00:00:00 2001 From: ouguoyu Date: Mon, 4 Nov 2024 15:52:27 +0800 Subject: [PATCH 47/60] s --- tests/basic_tests/test_index.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/tests/basic_tests/test_index.py b/tests/basic_tests/test_index.py index 526224ac..107d1bc6 100644 --- a/tests/basic_tests/test_index.py +++ b/tests/basic_tests/test_index.py @@ -8,16 +8,15 @@ class TestDefaultIndex(unittest.TestCase): def setUp(self): - self.mock_embed = MagicMock(side_effect=self.delayed_embed) - self.mock_embed1 = MagicMock(return_value=[0, 1, 0]) - self.mock_embed2 = MagicMock(return_value=[0, 0, 1]) - self.mock_store = MapStore(node_groups=['group1']) + self.mock_embed = { + 'default': MagicMock(side_effect=self.delayed_embed), + 'test1': MagicMock(return_value=[0, 1, 0]), + 'test2': MagicMock(return_value=[0, 0, 1]), + } + self.mock_store = MapStore(node_groups=['group1'], embed=self.mock_embed) # Create instance of DefaultIndex - self.index = DefaultIndex(embed={"default": self.mock_embed, - "test1": self.mock_embed1, - "test2": self.mock_embed2}, - store=self.mock_store) + self.index = DefaultIndex(embed=self.mock_embed, store=self.mock_store) # Create mock DocNodes self.doc_node_1 = DocNode(uid="text1", group="group1") From f895488920d24d5ce0308b37ea92ba3bf272b97b Mon Sep 17 00:00:00 2001 From: ouguoyu Date: Mon, 4 Nov 2024 17:40:01 +0800 Subject: [PATCH 48/60] s --- 
lazyllm/tools/rag/embed_utils.py | 5 ----- 1 file changed, 5 deletions(-) delete mode 100644 lazyllm/tools/rag/embed_utils.py diff --git a/lazyllm/tools/rag/embed_utils.py b/lazyllm/tools/rag/embed_utils.py deleted file mode 100644 index f0a6250d..00000000 --- a/lazyllm/tools/rag/embed_utils.py +++ /dev/null @@ -1,5 +0,0 @@ -import os -import concurrent -from typing import Dict, Callable, List -from lazyllm import config, ThreadPoolExecutor -from .doc_node import DocNode From a8ac8e5b70d91a6b26d3721f4738cd635379f202 Mon Sep 17 00:00:00 2001 From: ouguoyu Date: Mon, 4 Nov 2024 19:48:07 +0800 Subject: [PATCH 49/60] s --- lazyllm/tools/rag/map_store.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/lazyllm/tools/rag/map_store.py b/lazyllm/tools/rag/map_store.py index 52e2d0b6..31fa36d8 100644 --- a/lazyllm/tools/rag/map_store.py +++ b/lazyllm/tools/rag/map_store.py @@ -52,9 +52,12 @@ def get_nodes(self, group_name: str, uids: List[str] = None) -> List[DocNode]: if not docs: return [] - if not uids: + if uids is None: return list(docs.values()) + if len(uids) == 0: + return [] + ret = [] for uid in uids: doc = docs.get(uid) From ed2a66cce61c514640789633ed233f342a4db736 Mon Sep 17 00:00:00 2001 From: ouguoyu Date: Mon, 4 Nov 2024 20:08:01 +0800 Subject: [PATCH 50/60] s --- tests/basic_tests/test_document.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/basic_tests/test_document.py b/tests/basic_tests/test_document.py index 9ce0d95f..2ccc8b1f 100644 --- a/tests/basic_tests/test_document.py +++ b/tests/basic_tests/test_document.py @@ -1,7 +1,8 @@ import lazyllm -from lazyllm.tools.rag.doc_impl import DocImpl, _FileNodeIndex +from lazyllm.tools.rag.doc_impl import DocImpl +from .utils import _FileNodeIndex from lazyllm.tools.rag.transform import SentenceSplitter -from lazyllm.tools.rag.store import LAZY_ROOT_NAME +from lazyllm.tools.rag.store_base import LAZY_ROOT_NAME from lazyllm.tools.rag.doc_node import DocNode from lazyllm.tools.rag import Document, Retriever, TransformArgs, AdaptiveTransform from lazyllm.launcher import cleanup From c6e36da877ffdd9ba40919f2f050990a62e5e917 Mon Sep 17 00:00:00 2001 From: ouguoyu Date: Mon, 4 Nov 2024 20:12:07 +0800 Subject: [PATCH 51/60] s --- tests/basic_tests/test_document.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/basic_tests/test_document.py b/tests/basic_tests/test_document.py index 2ccc8b1f..0ba38050 100644 --- a/tests/basic_tests/test_document.py +++ b/tests/basic_tests/test_document.py @@ -1,6 +1,6 @@ import lazyllm from lazyllm.tools.rag.doc_impl import DocImpl -from .utils import _FileNodeIndex +from lazyllm.tools.rag.utils import _FileNodeIndex from lazyllm.tools.rag.transform import SentenceSplitter from lazyllm.tools.rag.store_base import LAZY_ROOT_NAME from lazyllm.tools.rag.doc_node import DocNode From 5ae24748a5802e26150241ff831f7c42dcc7a73c Mon Sep 17 00:00:00 2001 From: ouguoyu Date: Mon, 4 Nov 2024 20:43:56 +0800 Subject: [PATCH 52/60] s --- examples/rag_map_store_with_milvus_index.py | 7 +++-- examples/rag_milvus_store.py | 7 +++-- lazyllm/tools/rag/milvus_store.py | 35 ++++++++++++--------- tests/basic_tests/test_store.py | 14 ++++++--- 4 files changed, 39 insertions(+), 24 deletions(-) diff --git a/examples/rag_map_store_with_milvus_index.py b/examples/rag_map_store_with_milvus_index.py index 4e5c1c10..0f6ed260 100644 --- a/examples/rag_map_store_with_milvus_index.py +++ b/examples/rag_map_store_with_milvus_index.py @@ -8,8 +8,11 @@ _, store_file = 
tempfile.mkstemp(suffix=".db") fields = { - EMBED_DEFAULT_KEY: MilvusField(data_type=MilvusField.DTYPE_FLOAT_VECTOR, - index_type='HNSW', metric_type='COSINE'), + 'embedding': { + EMBED_DEFAULT_KEY: MilvusField(data_type=MilvusField.DTYPE_FLOAT_VECTOR, + index_type='HNSW', metric_type='COSINE'), + }, + 'metadata': {} } milvus_store_conf = { diff --git a/examples/rag_milvus_store.py b/examples/rag_milvus_store.py index fa70f0f6..b05119fb 100644 --- a/examples/rag_milvus_store.py +++ b/examples/rag_milvus_store.py @@ -8,8 +8,11 @@ _, store_file = tempfile.mkstemp(suffix=".db") fields = { - EMBED_DEFAULT_KEY: MilvusField(data_type=MilvusField.DTYPE_FLOAT_VECTOR, - index_type='HNSW', metric_type='COSINE'), + 'embedding': { + EMBED_DEFAULT_KEY: MilvusField(data_type=MilvusField.DTYPE_FLOAT_VECTOR, + index_type='HNSW', metric_type='COSINE'), + }, + 'metadata': {} } milvus_store_conf = { diff --git a/lazyllm/tools/rag/milvus_store.py b/lazyllm/tools/rag/milvus_store.py index 50a9dfc5..86b32d81 100644 --- a/lazyllm/tools/rag/milvus_store.py +++ b/lazyllm/tools/rag/milvus_store.py @@ -31,7 +31,7 @@ class MilvusStore(StoreBase): pymilvus.DataType.SPARSE_FLOAT_VECTOR, # DTYPE_SPARSE_FLOAT_VECTOR ] - def __init__(self, uri: str, group_fields: Dict[str, Dict[str, MilvusField]], + def __init__(self, uri: str, group_fields: Dict[str, Dict[str, Dict[str, MilvusField]]], node_groups: List[str], embed: Dict[str, Callable], **kwargs): new_copy = copy.copy(group_fields) for g in node_groups: @@ -71,21 +71,26 @@ def __init__(self, uri: str, group_fields: Dict[str, Dict[str, MilvusField]], for name, field in builtin_fields.items(): schema.add_field(field_name=name, **field) - for name, field in fields.items(): - field_name = None - if name in self._embedding_keys: - field_name = self._gen_embedding_key(name) - schema.add_field( - field_name=field_name, - datatype=self._type2milvus[field.data_type], - dim=embed_dim.get(name)) - else: - field_name = self._gen_metadata_key(name) - schema.add_field( - field_name=field_name, - datatype=self._type2milvus[field.data_type], - max_length=field.max_length) + embedding_fields = fields.get('embedding', {}) + for name, field in embedding_fields.items(): + field_name = self._gen_embedding_key(name) + schema.add_field( + field_name=field_name, + datatype=self._type2milvus[field.data_type], + dim=embed_dim.get(name)) + if field.index_type is not None: + index_params.add_index(field_name=field_name, + index_type=field.index_type, + metric_type=field.metric_type, + params=field.index_params) + metadata_fields = fields.get('metadata', {}) + for name, field in metadata_fields.items(): + field_name = self._gen_metadata_key(name) + schema.add_field( + field_name=field_name, + datatype=self._type2milvus[field.data_type], + max_length=field.max_length) if field.index_type is not None: index_params.add_index(field_name=field_name, index_type=field.index_type, diff --git a/tests/basic_tests/test_store.py b/tests/basic_tests/test_store.py index 0f65edc4..6136b81c 100644 --- a/tests/basic_tests/test_store.py +++ b/tests/basic_tests/test_store.py @@ -177,11 +177,15 @@ def test_group_others(self): class TestMilvusStore(unittest.TestCase): def setUp(self): fields = { - 'comment': MilvusField(data_type=MilvusField.DTYPE_VARCHAR, max_length=128), - 'vec1': MilvusField(data_type=MilvusField.DTYPE_FLOAT_VECTOR, - index_type='HNSW', metric_type='COSINE'), - 'vec2': MilvusField(data_type=MilvusField.DTYPE_FLOAT_VECTOR, - index_type='HNSW', metric_type='COSINE'), + 'embedding': { + 'vec1': 
MilvusField(data_type=MilvusField.DTYPE_FLOAT_VECTOR, + index_type='HNSW', metric_type='COSINE'), + 'vec2': MilvusField(data_type=MilvusField.DTYPE_FLOAT_VECTOR, + index_type='HNSW', metric_type='COSINE'), + }, + 'metadata': { + 'comment': MilvusField(data_type=MilvusField.DTYPE_VARCHAR, max_length=128), + }, } group_fields = { "group1": fields, From 38f873f476847f0da173c8886cbd68b1ac7982d3 Mon Sep 17 00:00:00 2001 From: ouguoyu Date: Mon, 4 Nov 2024 23:00:58 +0800 Subject: [PATCH 53/60] s --- examples/rag_milvus_store.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/examples/rag_milvus_store.py b/examples/rag_milvus_store.py index b05119fb..ecaa57ba 100644 --- a/examples/rag_milvus_store.py +++ b/examples/rag_milvus_store.py @@ -23,9 +23,6 @@ 'sentences': fields, }, }, - 'indices': { - 'map': {} - }, } documents = lazyllm.Document(dataset_path="rag_master", From 4fa60dc4f584ff705e85fef4239bae849e56d694 Mon Sep 17 00:00:00 2001 From: ouguoyu Date: Tue, 5 Nov 2024 11:40:19 +0800 Subject: [PATCH 54/60] review begins --- examples/rag_map_store_with_milvus_index.py | 14 +- examples/rag_milvus_store.py | 14 +- lazyllm/tools/rag/__init__.py | 4 - lazyllm/tools/rag/chroma_store.py | 48 +++-- lazyllm/tools/rag/doc_field_info.py | 6 + lazyllm/tools/rag/doc_impl.py | 16 +- lazyllm/tools/rag/doc_node.py | 11 +- lazyllm/tools/rag/document.py | 10 +- lazyllm/tools/rag/map_store.py | 10 +- lazyllm/tools/rag/milvus_store.py | 192 +++++++++----------- lazyllm/tools/rag/store_base.py | 4 + lazyllm/tools/rag/utils.py | 1 + tests/basic_tests/test_store.py | 37 ++-- 13 files changed, 173 insertions(+), 194 deletions(-) create mode 100644 lazyllm/tools/rag/doc_field_info.py diff --git a/examples/rag_map_store_with_milvus_index.py b/examples/rag_map_store_with_milvus_index.py index 0f6ed260..b12be866 100644 --- a/examples/rag_map_store_with_milvus_index.py +++ b/examples/rag_map_store_with_milvus_index.py @@ -3,26 +3,16 @@ import os import lazyllm import tempfile -from lazyllm.tools.rag import MilvusField, EMBED_DEFAULT_KEY _, store_file = tempfile.mkstemp(suffix=".db") -fields = { - 'embedding': { - EMBED_DEFAULT_KEY: MilvusField(data_type=MilvusField.DTYPE_FLOAT_VECTOR, - index_type='HNSW', metric_type='COSINE'), - }, - 'metadata': {} -} - milvus_store_conf = { 'type': 'map', 'indices': { 'milvus': { 'uri': store_file, - 'group_fields': { - 'sentences': fields, - }, + 'embedding_index_type': 'HNSW', + 'embedding_metric_type': 'COSINE', }, }, } diff --git a/examples/rag_milvus_store.py b/examples/rag_milvus_store.py index ecaa57ba..c07488b3 100644 --- a/examples/rag_milvus_store.py +++ b/examples/rag_milvus_store.py @@ -3,25 +3,15 @@ import os import lazyllm import tempfile -from lazyllm.tools.rag import MilvusField, EMBED_DEFAULT_KEY _, store_file = tempfile.mkstemp(suffix=".db") -fields = { - 'embedding': { - EMBED_DEFAULT_KEY: MilvusField(data_type=MilvusField.DTYPE_FLOAT_VECTOR, - index_type='HNSW', metric_type='COSINE'), - }, - 'metadata': {} -} - milvus_store_conf = { 'type': 'milvus', 'kwargs': { 'uri': store_file, - 'group_fields': { - 'sentences': fields, - }, + 'embedding_index_type': 'HNSW', + 'embedding_metric_type': 'COSINE', }, } diff --git a/lazyllm/tools/rag/__init__.py b/lazyllm/tools/rag/__init__.py index dfd5da76..1c4c39ae 100644 --- a/lazyllm/tools/rag/__init__.py +++ b/lazyllm/tools/rag/__init__.py @@ -8,8 +8,6 @@ MarkdownReader, MboxReader, PandasCSVReader, PandasExcelReader, VideoAudioReader) from .dataReader import SimpleDirectoryReader from .doc_manager import DocManager, 
DocListManager -from .store_base import EMBED_DEFAULT_KEY -from .milvus_store import MilvusField __all__ = [ @@ -39,6 +37,4 @@ "SimpleDirectoryReader", 'DocManager', 'DocListManager', - 'MilvusField', - 'EMBED_DEFAULT_KEY', ] diff --git a/lazyllm/tools/rag/chroma_store.py b/lazyllm/tools/rag/chroma_store.py index 84e8b8d3..c6d57c5a 100644 --- a/lazyllm/tools/rag/chroma_store.py +++ b/lazyllm/tools/rag/chroma_store.py @@ -8,29 +8,25 @@ from .index_base import IndexBase from .utils import _FileNodeIndex from .default_index import DefaultIndex -import json from .map_store import MapStore +import pickle # ---------------------------------------------------------------------------- # class ChromadbStore(StoreBase): - def __init__(self, path: str, embed_dim: Dict[str, int], - node_groups: List[str], embed: Dict[str, Callable], - **kwargs) -> None: - self._map_store = MapStore(node_groups=node_groups, embed=embed) - self._db_client = chromadb.PersistentClient(path=path) - LOG.success(f"Initialzed chromadb in path: {path}") - self._collections: Dict[str, Collection] = { - group: self._db_client.get_or_create_collection(group) - for group in node_groups - } - self._embed_dim = embed_dim + def __init__(self, dir: str, embed: Dict[str, Callable], embed_dim: Dict[str, int], **kwargs) -> None: + self._db_client = chromadb.PersistentClient(path=dir) + LOG.success(f"Initialzed chromadb in path: {dir}") + self._collections: Dict[str, Collection] = {} self._name2index = { 'default': DefaultIndex(embed, self._map_store), 'file_node_map': _FileNodeIndex(), } + self._map_store = MapStore(embed=embed) + self._load_store(embed_dim) + @override def update_nodes(self, nodes: List[DocNode]) -> None: self._map_store.update_nodes(nodes) @@ -56,6 +52,11 @@ def is_group_active(self, name: str) -> bool: def all_groups(self) -> List[str]: return self._map_store.all_groups() + @override + def add_group(self, name: str, embed_keys: Optional[List[str]] = None) -> None: + self._collections[name] = self._db_client.get_or_create_collection(name) + self._map_store.add_group(name, embed_keys) + @override def query(self, *args, **kwargs) -> List[DocNode]: return self.get_index('default').query(*args, **kwargs) @@ -70,7 +71,7 @@ def get_index(self, type: Optional[str] = None) -> Optional[IndexBase]: type = 'default' return self._name2index.get(type) - def _load_store(self) -> None: + def _load_store(self, embed_dim: Dict[str, int]) -> None: if not self._collections[LAZY_ROOT_NAME].peek(1)["ids"]: LOG.info("No persistent data found, skip the rebuilding phrase.") return @@ -78,7 +79,7 @@ def _load_store(self) -> None: # Restore all nodes for group in self._collections.keys(): results = self._peek_all_documents(group) - nodes = self._build_nodes_from_chroma(results) + nodes = self._build_nodes_from_chroma(results, embed_dim) self._map_store.update_nodes(nodes) # Rebuild relationships @@ -107,7 +108,6 @@ def _save_nodes(self, nodes: List[DocNode]) -> None: if node.is_saved: continue metadata = self._make_chroma_metadata(node) - metadata["embedding"] = json.dumps(node.embedding) ids.append(node.uid) embeddings.append([0]) # we don't use chroma for retrieving metadatas.append(metadata) @@ -127,16 +127,21 @@ def _delete_group_nodes(self, group_name: str, uids: List[str]) -> None: if collection: collection.delete(ids=uids) - def _build_nodes_from_chroma(self, results: Dict[str, List]) -> List[DocNode]: + def _build_nodes_from_chroma(self, results: Dict[str, List], embed_dim: Dict[str, int]) -> List[DocNode]: nodes: List[DocNode] = [] for 
i, uid in enumerate(results['ids']): chroma_metadata = results['metadatas'][i] + + parent = chroma_metadata['parent'] + fields = pickle.loads(chroma_metadata['fields']) if parent else None + node = DocNode( uid=uid, text=results["documents"][i], group=chroma_metadata["group"], - embedding=json.loads(chroma_metadata['embedding']), - parent=chroma_metadata["parent"], + embedding=pickle.loads(chroma_metadata['embedding']), + parent=parent, + fields=fields, ) if node.embedding: @@ -144,7 +149,7 @@ def _build_nodes_from_chroma(self, results: Dict[str, List]) -> List[DocNode]: new_embedding_dict = {} for key, embedding in node.embedding.items(): if isinstance(embedding, dict): - dim = self._embed_dim.get(key) + dim = embed_dim.get(key) if not dim: raise ValueError(f'dim of embed [{key}] is not determined.') new_embedding = [0] * dim @@ -163,7 +168,12 @@ def _make_chroma_metadata(self, node: DocNode) -> Dict[str, Any]: metadata = { "group": node.group, "parent": node.parent.uid if node.parent else "", + "embedding": pickle.dumps(node.embedding), } + + if node.parent: + metadata["fields"] = pickle.dumps(node.fields) + return metadata def _peek_all_documents(self, group: str) -> Dict[str, List]: diff --git a/lazyllm/tools/rag/doc_field_info.py b/lazyllm/tools/rag/doc_field_info.py new file mode 100644 index 00000000..3d233000 --- /dev/null +++ b/lazyllm/tools/rag/doc_field_info.py @@ -0,0 +1,6 @@ +class DocFieldInfo: + DTYPE_UNKNOWN = 0 + DTYPE_VARCHAR = 1 + + def __init__(self, data_type: DTYPE_UNKNOWN): + self.data_type = data_type diff --git a/lazyllm/tools/rag/doc_impl.py b/lazyllm/tools/rag/doc_impl.py index 12cf8287..6b00ebeb 100644 --- a/lazyllm/tools/rag/doc_impl.py +++ b/lazyllm/tools/rag/doc_impl.py @@ -13,6 +13,7 @@ from .doc_node import DocNode from .data_loaders import DirectoryReader from .utils import DocListManager +from .doc_field_info import DocFieldInfo import threading import time @@ -37,7 +38,7 @@ class DocImpl: def __init__(self, embed: Dict[str, Callable], dlm: Optional[DocListManager] = None, doc_files: Optional[str] = None, kb_group_name: Optional[str] = None, - store_conf: Optional[Dict] = None): + fields_info: Dict[str, DocFieldInfo] = None, store_conf: Optional[Dict] = None): super().__init__() assert (dlm is None) ^ (doc_files is None), 'Only one of dataset_path or doc_files should be provided' self._local_file_reader: Dict[str, Callable] = {} @@ -47,6 +48,7 @@ def __init__(self, embed: Dict[str, Callable], dlm: Optional[DocListManager] = N self.node_groups: Dict[str, Dict] = {LAZY_ROOT_NAME: {}} self.embed = {k: embed_wrapper(e) for k, e in embed.items()} self._embed_dim = None + self._fields_info = fields_info self.store = store_conf # NOTE: will be initialized in _lazy_init() @once_wrapper(reset_on_pickle=True) @@ -69,6 +71,7 @@ def _lazy_init(self) -> None: raise ValueError(f'store type [{type(self.store)}] is not a dict.') if not self.store.is_group_active(LAZY_ROOT_NAME): + self.store.add_group(name=LAZY_ROOT_NAME, fields_info=self._fields_info, embed_keys={}) ids, pathes = self._list_files() root_nodes = self._reader.load_data(pathes) self.store.update_nodes(root_nodes) @@ -91,11 +94,11 @@ def _create_store(self, store_conf: Optional[Dict]) -> StoreBase: raise ValueError('`kwargs` in store conf is not a dict.') if store_type == "map": - store = MapStore(embed=self.embed, node_groups=self.node_groups, **kwargs) + store = MapStore(embed=self.embed, **kwargs) elif store_type == "chroma": - store = ChromadbStore(embed=self.embed, node_groups=self.node_groups, 
**kwargs) + store = ChromadbStore(embed_dim=self.embed_dim, embed=self.embed, **kwargs) elif store_type == "milvus": - store = MilvusStore(embed=self.embed, node_groups=self.node_groups, **kwargs) + store = MilvusStore(embed=self.embed, fields_info=self._fields_info, **kwargs) else: raise NotImplementedError( f"Not implemented store type for {store_type}" @@ -212,7 +215,6 @@ def _list_files(self, status: Union[str, List[str]] = DocListManager.Status.all, def _add_files(self, input_files: List[str]): if len(input_files) == 0: return - self._lazy_init() root_nodes = self._reader.load_data(input_files) temp_store = self._create_store("map") temp_store.update_nodes(root_nodes) @@ -227,7 +229,6 @@ def _add_files(self, input_files: List[str]): LOG.debug(f"Merge {group} with {nodes}") def _delete_files(self, input_files: List[str]) -> None: - self._lazy_init() docs = self.store.get_index(type='file_node_map').query(input_files) LOG.info(f"delete_files: removing documents {input_files} and nodes {docs}") if len(docs) == 0: @@ -276,6 +277,9 @@ def retrieve(self, query: str, group_name: str, similarity: str, similarity_cut_ index: str, topk: int, similarity_kws: dict, embed_keys: Optional[List[str]] = None) -> List[DocNode]: self._lazy_init() + if not self.store.is_group_active(group_name): + self.store.add_group(group_name, embed_keys=embed_keys) + if type is None or type == 'default': return self.store.query(query=query, group_name=group_name, similarity_name=similarity, similarity_cut_off=similarity_cut_off, topk=topk, diff --git a/lazyllm/tools/rag/doc_node.py b/lazyllm/tools/rag/doc_node.py index 929f5570..2510722f 100644 --- a/lazyllm/tools/rag/doc_node.py +++ b/lazyllm/tools/rag/doc_node.py @@ -17,7 +17,7 @@ class MetadataMode(str, Enum): class DocNode: def __init__(self, uid: Optional[str] = None, text: Optional[str] = None, group: Optional[str] = None, embedding: Optional[Dict[str, List[float]]] = None, parent: Optional["DocNode"] = None, - metadata: Optional[Dict[str, Any]] = None): + metadata: Optional[Dict[str, Any]] = None, fields: Optional[Dict[str, Any]] = None): self.uid: str = uid if uid else str(uuid.uuid4()) self.text: Optional[str] = text self.group: Optional[str] = group @@ -34,6 +34,10 @@ def __init__(self, uid: Optional[str] = None, text: Optional[str] = None, group: self._lock = threading.Lock() self._embedding_state = set() + if fields and parent: + raise ValueError('only ROOT node can set fields.') + self._fields = fields + @property def root_node(self) -> Optional["DocNode"]: root = self.parent @@ -41,12 +45,17 @@ def root_node(self) -> Optional["DocNode"]: root = root.parent return root or self + @property + def fields(self) -> Dict[str, Any]: + return self.root_node._fields + @property def metadata(self) -> Dict: return self.root_node._metadata @metadata.setter def metadata(self, metadata: Dict) -> None: + self.is_saved = False self._metadata = metadata @property diff --git a/lazyllm/tools/rag/document.py b/lazyllm/tools/rag/document.py index 346b67d5..2030362c 100644 --- a/lazyllm/tools/rag/document.py +++ b/lazyllm/tools/rag/document.py @@ -10,6 +10,7 @@ from .doc_node import DocNode from .store_base import LAZY_ROOT_NAME, EMBED_DEFAULT_KEY from .utils import DocListManager +from .doc_field_info import DocFieldInfo import copy import functools @@ -22,7 +23,8 @@ class Document(ModuleBase): class _Impl(ModuleBase): def __init__(self, dataset_path: str, embed: Optional[Union[Callable, Dict[str, Callable]]] = None, manager: bool = False, server: bool = False, name: 
Optional[str] = None, - launcher: Optional[Launcher] = None, store_conf: Optional[Dict] = None): + launcher: Optional[Launcher] = None, store_conf: Optional[Dict] = None, + fields_info: Optional[Dict[str, DocFieldInfo]] = None): super().__init__() if not os.path.exists(dataset_path): defatult_path = os.path.join(lazyllm.config["data_path"], dataset_path) @@ -40,6 +42,7 @@ def __init__(self, dataset_path: str, embed: Optional[Union[Callable, Dict[str, DocImpl(embed=self._embed, dlm=self._dlm, store_conf=store_conf)}) if manager: self._manager = ServerModule(DocManager(self._dlm)) if server: self._kbs = ServerModule(self._kbs) + self._fields_info = fields_info def add_kb_group(self, name, store_conf: Optional[Dict] = None): if isinstance(self._kbs, ServerModule): @@ -61,11 +64,12 @@ def __call__(self, *args, **kw): def __init__(self, dataset_path: str, embed: Optional[Union[Callable, Dict[str, Callable]]] = None, create_ui: bool = False, manager: bool = False, server: bool = False, name: Optional[str] = None, launcher: Optional[Launcher] = None, - store_conf: Optional[Dict] = None): + fields_info: Dict[str, DocFieldInfo] = None, store_conf: Optional[Dict] = None): super().__init__() if create_ui: lazyllm.LOG.warning('`create_ui` for Document is deprecated, use `manager` instead') - self._impls = Document._Impl(dataset_path, embed, create_ui or manager, server, name, launcher, store_conf) + self._impls = Document._Impl(dataset_path, embed, create_ui or manager, server, name, + launcher, store_conf, fields_info) self._curr_group = DocListManager.DEDAULT_GROUP_NAME def create_kb_group(self, name: str, store_conf: Optional[Dict] = None) -> "Document": diff --git a/lazyllm/tools/rag/map_store.py b/lazyllm/tools/rag/map_store.py index 31fa36d8..36322ac6 100644 --- a/lazyllm/tools/rag/map_store.py +++ b/lazyllm/tools/rag/map_store.py @@ -16,11 +16,9 @@ def _remove_from_indices(name2index: Dict[str, IndexBase], uids: List[str], index.remove(uids, group_name) class MapStore(StoreBase): - def __init__(self, node_groups: List[str], embed: Dict[str, Callable], **kwargs): + def __init__(self, embed: Dict[str, Callable], **kwargs): # Dict[group_name, Dict[uuid, DocNode]] - self._group2docs: Dict[str, Dict[str, DocNode]] = { - group: {} for group in node_groups - } + self._group2docs: Dict[str, Dict[str, DocNode]] = {} self._name2index = { 'default': DefaultIndex(embed, self), @@ -74,6 +72,10 @@ def is_group_active(self, name: str) -> bool: def all_groups(self) -> List[str]: return self._group2docs.keys() + @override + def add_group(self, name: str, embed_keys: Optional[List[str]] = None) -> None: + self._group2docs.setdefault(name, {}) + @override def query(self, *args, **kwargs) -> List[DocNode]: return self.get_index('default').query(*args, **kwargs) diff --git a/lazyllm/tools/rag/milvus_store.py b/lazyllm/tools/rag/milvus_store.py index 86b32d81..5147140a 100644 --- a/lazyllm/tools/rag/milvus_store.py +++ b/lazyllm/tools/rag/milvus_store.py @@ -7,100 +7,49 @@ from .utils import parallel_do_embedding from .index_base import IndexBase from .store_base import StoreBase +from .doc_field_info import DocFieldInfo from lazyllm.common import override -class MilvusField: - DTYPE_VARCHAR = 0 - DTYPE_FLOAT_VECTOR = 1 - DTYPE_SPARSE_FLOAT_VECTOR = 2 - - def __init__(self, data_type: int, index_type: Optional[str] = None, - metric_type: Optional[str] = "", index_params: Dict = {}, - max_length: Optional[int] = None): - self.data_type = data_type - self.index_type = index_type - self.metric_type = metric_type - 
self.index_params = index_params - self.max_length = max_length - - class MilvusStore(StoreBase): + _primary_key = 'uid' + + _embedding_key_prefix = 'embedding_' + _field_key_prefix = 'field_' + + _builtin_fields = { + _primary_key: { + 'datatype': pymilvus.DataType.VARCHAR, + 'max_length': 256, + 'is_primary': True, + }, + 'text': { + 'datatype': pymilvus.DataType.VARCHAR, + 'max_length': True, + }, + 'parent': { + 'datatype': pymilvus.DataType.VARCHAR, + 'max_length': 256, + }, + } + _type2milvus = [ - pymilvus.DataType.VARCHAR, # DTYPE_VARCHAR - pymilvus.DataType.FLOAT_VECTOR, # DTYPE_FLOAT_VECTOR - pymilvus.DataType.SPARSE_FLOAT_VECTOR, # DTYPE_SPARSE_FLOAT_VECTOR + 0, + pymilvus.DataType.VARCHAR, ] - def __init__(self, uri: str, group_fields: Dict[str, Dict[str, Dict[str, MilvusField]]], - node_groups: List[str], embed: Dict[str, Callable], **kwargs): - new_copy = copy.copy(group_fields) - for g in node_groups: - if g not in new_copy: - new_copy[g] = {} - group_fields = new_copy + def __init__(self, embed: Dict[str, Callable], fields_info: Dict[str, DocFieldInfo], uri: str, + embedding_index_type: Optional[str] = None, embedding_metric_type: Optional[str] = None, + **kwargs): + self._embed = embed + self._fields_info = fields_info + self._embedding_index_type = embedding_index_type if embedding_index_type else 'HNSW' + self._embedding_metric_type = embedding_metric_type if embedding_metric_type else 'COSINE' - self._primary_key = 'uid' self._embedding_keys = embed.keys() - self._embed = embed + self._embed_dim = {k: len(e('a')) for k, e in embed.items()} self._client = MilvusClient(uri=uri) - embed_dim = {k: len(e('a')) for k, e in embed.items()} - builtin_fields = { - self._primary_key: { - 'datatype': pymilvus.DataType.VARCHAR, - 'max_length': 128, - 'is_primary': True, - }, - 'text': { - 'datatype': pymilvus.DataType.VARCHAR, - 'max_length': True, - }, - 'parent': { - 'datatype': pymilvus.DataType.VARCHAR, - 'max_length': 256, - }, - } - - for group_name, fields in group_fields.items(): - if group_name in self._client.list_collections(): - continue - - index_params = self._client.prepare_index_params() - schema = self._client.create_schema(auto_id=False, enable_dynamic_field=True) - - for name, field in builtin_fields.items(): - schema.add_field(field_name=name, **field) - - embedding_fields = fields.get('embedding', {}) - for name, field in embedding_fields.items(): - field_name = self._gen_embedding_key(name) - schema.add_field( - field_name=field_name, - datatype=self._type2milvus[field.data_type], - dim=embed_dim.get(name)) - if field.index_type is not None: - index_params.add_index(field_name=field_name, - index_type=field.index_type, - metric_type=field.metric_type, - params=field.index_params) - - metadata_fields = fields.get('metadata', {}) - for name, field in metadata_fields.items(): - field_name = self._gen_metadata_key(name) - schema.add_field( - field_name=field_name, - datatype=self._type2milvus[field.data_type], - max_length=field.max_length) - if field.index_type is not None: - index_params.add_index(field_name=field_name, - index_type=field.index_type, - metric_type=field.metric_type, - params=field.index_params) - - self._client.create_collection(collection_name=group_name, schema=schema, - index_params=index_params) - - self._map_store = MapStore(node_groups=list(group_fields.keys()), embed=embed) + self._map_store = MapStore(embed=embed) self._load_all_nodes_to(self._map_store) @override @@ -134,6 +83,30 @@ def is_group_active(self, name: str) -> bool: def 
all_groups(self) -> List[str]: return self._map_store.all_groups() + @override + def add_group(self, name: str, embed_keys: Optional[List[str]] = None) -> None: + if name in self._client.list_collections(): + return + + index_params = self._client.prepare_index_params() + schema = self._client.create_schema(auto_id=False, enable_dynamic_field=False) + + for key in embed_keys: + field_name = self._gen_embedding_key(key) + schema.add_field(field_name=field_name, datatype=pymilvus.DataType.FLOAT_VECTOR) + index_params.add_index(field_name=field_name, index_type=self._embedding_index_type, + metric_type=self._embedding_metric_type) + + if self._fields_info: + for key, info in self._fields_info.items(): + schema.add_field(field_name=self._gen_field_key(key), + datatype=self._type2milvus[info.data_type]) + + self._client.create_collection(collection_name=name, schema=schema, + index_params=index_params) + + self._map_store.add_group(name, embed_keys) + @override def register_index(self, type: str, index: IndexBase) -> None: self._map_store.register_index(type, index) @@ -173,16 +146,18 @@ def query(self, # ----- internal helper functions ----- # - @staticmethod - def _gen_embedding_key(k: str) -> str: - return 'embedding_' + k + @classmethod + def _gen_embedding_key(cls, k: str) -> str: + return cls._embedding_key_prefix + k - @staticmethod - def _gen_metadata_key(k: str) -> str: - return 'metadata_' + k + @classmethod + def _gen_field_key(cls, k: str) -> str: + return cls._field_key_prefix + k def _load_all_nodes_to(self, store: StoreBase): for group_name in self._client.list_collections(): + store.add_group(name=group_name, embed=self._embed) + results = self._client.query(collection_name=group_name, filter=f'{self._primary_key} != ""') for result in results: @@ -203,37 +178,32 @@ def _serialize_node_partial(self, node: DocNode) -> Dict: res = { 'uid': node.uid, 'text': node.text, + 'parent': node.parent.uid if node.parent else '', + 'metadata': node._metadata, } - if node.parent: - res['parent'] = node.parent.uid - else: - res['parent'] = '' - for k, v in node.embedding.items(): res[self._gen_embedding_key(k)] = v - for k, v in node.metadata.items(): - res[self._gen_metadata_key(k)] = v + for k, v in node.fields.items(): + res[self._gen_field_key(k)] = v return res def _deserialize_node_partial(self, result: Dict) -> DocNode: - ''' - without parent and children - ''' + record = copy.copy(result) + doc = DocNode( - uid=result.get('uid'), - text=result.get('text'), - parent=result.get('parent'), # this is the parent's uid + uid=record.pop('uid'), + text=record.pop('text'), + parent=record.pop('parent'), # this is the parent's uid + metadata=record.pop('metadata'), ) - for k in self._embedding_keys: - val = result.get(self._gen_embedding_key(k)) - if val: - doc.embedding[k] = val - for k in self._metadata_keys: - val = result.get(self._gen_metadata_key(k)) - if val: - doc._metadata[k] = val + for k, v in record.items(): + if k.startswith(self._embedding_key_prefix): + doc.embedding[k[len(self._embedding_key_prefix):]] = v + elif k.startswith(self._field_key_prefix): + if doc.parent: + doc._fields[k[len(self._field_key_prefix):]] = v return doc diff --git a/lazyllm/tools/rag/store_base.py b/lazyllm/tools/rag/store_base.py index 08992ad9..c40cb729 100644 --- a/lazyllm/tools/rag/store_base.py +++ b/lazyllm/tools/rag/store_base.py @@ -31,6 +31,10 @@ def is_group_active(self, name: str) -> bool: def all_groups(self) -> List[str]: pass + @abstractmethod + def add_group(self, name: str, embed_keys: 
Optional[List[str]] = None) -> None: + pass + @abstractmethod def query(self, *args, **kwargs) -> List[DocNode]: pass diff --git a/lazyllm/tools/rag/utils.py b/lazyllm/tools/rag/utils.py index 15c80727..e195c5ea 100644 --- a/lazyllm/tools/rag/utils.py +++ b/lazyllm/tools/rag/utils.py @@ -1,6 +1,7 @@ import os import shutil import hashlib +import concurrent from typing import List, Callable, Generator, Dict, Any, Optional, Union, Tuple from abc import ABC, abstractmethod from .index_base import IndexBase diff --git a/tests/basic_tests/test_store.py b/tests/basic_tests/test_store.py index 6136b81c..eb4b5780 100644 --- a/tests/basic_tests/test_store.py +++ b/tests/basic_tests/test_store.py @@ -6,8 +6,9 @@ from lazyllm.tools.rag.store_base import LAZY_ROOT_NAME from lazyllm.tools.rag.map_store import MapStore from lazyllm.tools.rag.chroma_store import ChromadbStore -from lazyllm.tools.rag.milvus_store import MilvusStore, MilvusField +from lazyllm.tools.rag.milvus_store import MilvusStore from lazyllm.tools.rag.doc_node import DocNode +from lazyllm.tools.rag.doc_field_info import DocFieldInfo def clear_directory(directory_path): @@ -28,13 +29,16 @@ def clear_directory(directory_path): class TestChromadbStore(unittest.TestCase): def setUp(self): self.node_groups = [LAZY_ROOT_NAME, "group1", "group2"] - self.embed_dim = {"default": 3} self.store_dir = tempfile.mkdtemp() self.mock_embed = { 'default': MagicMock(return_value=[1.0, 2.0, 3.0]), } - self.store = ChromadbStore(path=self.store_dir, node_groups=self.node_groups, - embed=self.mock_embed, embed_dim=self.embed_dim) + self.embed_dim = {"default": 3} + + self.store = ChromadbStore(dir=self.store_dir, embed=self.mock_embed, embed_dim=self.embed_dim) + for group in self.node_groups: + self.store.add_group(name=group, embed_keys=self.mock_embed.keys()) + self.store.update_nodes( [DocNode(uid="1", text="text1", group=LAZY_ROOT_NAME, parent=None)], ) @@ -176,32 +180,21 @@ def test_group_others(self): class TestMilvusStore(unittest.TestCase): def setUp(self): - fields = { - 'embedding': { - 'vec1': MilvusField(data_type=MilvusField.DTYPE_FLOAT_VECTOR, - index_type='HNSW', metric_type='COSINE'), - 'vec2': MilvusField(data_type=MilvusField.DTYPE_FLOAT_VECTOR, - index_type='HNSW', metric_type='COSINE'), - }, - 'metadata': { - 'comment': MilvusField(data_type=MilvusField.DTYPE_VARCHAR, max_length=128), - }, - } - group_fields = { - "group1": fields, - "group2": fields, - } - self.mock_embed = { 'vec1': MagicMock(return_value=[1.0, 2.0, 3.0]), 'vec2': MagicMock(return_value=[400.0, 500.0, 600.0, 700.0, 800.0]), } + self.fields_info = { + 'comment': DocFieldInfo(DocFieldInfo.DTYPE_VARCHAR), + } self.node_groups = [LAZY_ROOT_NAME, "group1", "group2"] _, self.store_file = tempfile.mkstemp(suffix=".db") - self.store = MilvusStore(uri=self.store_file, embed=self.mock_embed, - node_groups=self.node_groups, group_fields=group_fields) + self.store = MilvusStore(embed=self.mock_embed, fields_info=self.fields_info, + uri=self.store_file) + for group in self.node_groups: + self.store.add_group(name=group, embed_keys=self.mock_embed.keys()) self.node1 = DocNode(uid="1", text="text1", group="group1", parent=None, embedding={"vec1": [8.0, 9.0, 10.0], "vec2": [11.0, 12.0, 13.0, 14.0, 15.0]}, From 13a9197af25d1ffd67fa864f7941e4c7522414e3 Mon Sep 17 00:00:00 2001 From: ouguoyu Date: Tue, 5 Nov 2024 19:47:02 +0800 Subject: [PATCH 55/60] review2 --- lazyllm/tools/rag/chroma_store.py | 15 ++-- lazyllm/tools/rag/doc_field_desc.py | 8 ++ lazyllm/tools/rag/doc_field_info.py | 
6 -- lazyllm/tools/rag/doc_impl.py | 20 ++--- lazyllm/tools/rag/doc_node.py | 7 +- lazyllm/tools/rag/document.py | 31 ++++--- lazyllm/tools/rag/map_store.py | 10 +-- lazyllm/tools/rag/milvus_store.py | 105 +++++++++++++---------- lazyllm/tools/rag/retriever.py | 2 + lazyllm/tools/rag/store_base.py | 4 - lazyllm/tools/rag/web.py | 124 +++++++++------------------- tests/basic_tests/test_document.py | 11 +++ tests/basic_tests/test_store.py | 14 ++-- 13 files changed, 172 insertions(+), 185 deletions(-) create mode 100644 lazyllm/tools/rag/doc_field_desc.py delete mode 100644 lazyllm/tools/rag/doc_field_info.py diff --git a/lazyllm/tools/rag/chroma_store.py b/lazyllm/tools/rag/chroma_store.py index c6d57c5a..9422890f 100644 --- a/lazyllm/tools/rag/chroma_store.py +++ b/lazyllm/tools/rag/chroma_store.py @@ -14,17 +14,21 @@ # ---------------------------------------------------------------------------- # class ChromadbStore(StoreBase): - def __init__(self, dir: str, embed: Dict[str, Callable], embed_dim: Dict[str, int], **kwargs) -> None: + def __init__(self, dir: str, node_groups: List[str], embed: Dict[str, Callable], + embed_dim: Dict[str, int], **kwargs) -> None: self._db_client = chromadb.PersistentClient(path=dir) LOG.success(f"Initialzed chromadb in path: {dir}") - self._collections: Dict[str, Collection] = {} + self._collections: Dict[str, Collection] = { + group: self._db_client.get_or_create_collection(group) + for group in node_groups + } self._name2index = { 'default': DefaultIndex(embed, self._map_store), 'file_node_map': _FileNodeIndex(), } - self._map_store = MapStore(embed=embed) + self._map_store = MapStore(node_groups=node_groups, embed=embed) self._load_store(embed_dim) @override @@ -52,11 +56,6 @@ def is_group_active(self, name: str) -> bool: def all_groups(self) -> List[str]: return self._map_store.all_groups() - @override - def add_group(self, name: str, embed_keys: Optional[List[str]] = None) -> None: - self._collections[name] = self._db_client.get_or_create_collection(name) - self._map_store.add_group(name, embed_keys) - @override def query(self, *args, **kwargs) -> List[DocNode]: return self.get_index('default').query(*args, **kwargs) diff --git a/lazyllm/tools/rag/doc_field_desc.py b/lazyllm/tools/rag/doc_field_desc.py new file mode 100644 index 00000000..ffb416db --- /dev/null +++ b/lazyllm/tools/rag/doc_field_desc.py @@ -0,0 +1,8 @@ +from typing import Optional + +class DocFieldDesc: + DTYPE_VARCHAR = 0 + + def __init__(self, data_type: int = DTYPE_VARCHAR, max_length: Optional[int] = 65535): + self.data_type = data_type + self.max_length = max_length diff --git a/lazyllm/tools/rag/doc_field_info.py b/lazyllm/tools/rag/doc_field_info.py deleted file mode 100644 index 3d233000..00000000 --- a/lazyllm/tools/rag/doc_field_info.py +++ /dev/null @@ -1,6 +0,0 @@ -class DocFieldInfo: - DTYPE_UNKNOWN = 0 - DTYPE_VARCHAR = 1 - - def __init__(self, data_type: DTYPE_UNKNOWN): - self.data_type = data_type diff --git a/lazyllm/tools/rag/doc_impl.py b/lazyllm/tools/rag/doc_impl.py index 6b00ebeb..5252ea8c 100644 --- a/lazyllm/tools/rag/doc_impl.py +++ b/lazyllm/tools/rag/doc_impl.py @@ -13,7 +13,7 @@ from .doc_node import DocNode from .data_loaders import DirectoryReader from .utils import DocListManager -from .doc_field_info import DocFieldInfo +from .doc_field_desc import DocFieldDesc import threading import time @@ -38,7 +38,7 @@ class DocImpl: def __init__(self, embed: Dict[str, Callable], dlm: Optional[DocListManager] = None, doc_files: Optional[str] = None, kb_group_name: 
Optional[str] = None, - fields_info: Dict[str, DocFieldInfo] = None, store_conf: Optional[Dict] = None): + fields_desc: Dict[str, DocFieldDesc] = None, store_conf: Optional[Dict] = None): super().__init__() assert (dlm is None) ^ (doc_files is None), 'Only one of dataset_path or doc_files should be provided' self._local_file_reader: Dict[str, Callable] = {} @@ -48,8 +48,9 @@ def __init__(self, embed: Dict[str, Callable], dlm: Optional[DocListManager] = N self.node_groups: Dict[str, Dict] = {LAZY_ROOT_NAME: {}} self.embed = {k: embed_wrapper(e) for k, e in embed.items()} self._embed_dim = None - self._fields_info = fields_info + self._fields_desc = fields_desc self.store = store_conf # NOTE: will be initialized in _lazy_init() + self._activated_embeddings = set() @once_wrapper(reset_on_pickle=True) def _lazy_init(self) -> None: @@ -71,7 +72,6 @@ def _lazy_init(self) -> None: raise ValueError(f'store type [{type(self.store)}] is not a dict.') if not self.store.is_group_active(LAZY_ROOT_NAME): - self.store.add_group(name=LAZY_ROOT_NAME, fields_info=self._fields_info, embed_keys={}) ids, pathes = self._list_files() root_nodes = self._reader.load_data(pathes) self.store.update_nodes(root_nodes) @@ -94,11 +94,14 @@ def _create_store(self, store_conf: Optional[Dict]) -> StoreBase: raise ValueError('`kwargs` in store conf is not a dict.') if store_type == "map": - store = MapStore(embed=self.embed, **kwargs) + store = MapStore(node_groups=self.node_groups, embed=self.embed, **kwargs) elif store_type == "chroma": - store = ChromadbStore(embed_dim=self.embed_dim, embed=self.embed, **kwargs) + store = ChromadbStore(node_groups=self.node_groups, embed=self.embed, + embed_dim=self.embed_dim, **kwargs) elif store_type == "milvus": - store = MilvusStore(embed=self.embed, fields_info=self._fields_info, **kwargs) + store = MilvusStore(node_groups=self.node_groups, embed=self.embed, + embed_keys=self._activated_embeddings, + fields_desc=self._fields_desc, **kwargs) else: raise NotImplementedError( f"Not implemented store type for {store_type}" @@ -277,9 +280,6 @@ def retrieve(self, query: str, group_name: str, similarity: str, similarity_cut_ index: str, topk: int, similarity_kws: dict, embed_keys: Optional[List[str]] = None) -> List[DocNode]: self._lazy_init() - if not self.store.is_group_active(group_name): - self.store.add_group(group_name, embed_keys=embed_keys) - if type is None or type == 'default': return self.store.query(query=query, group_name=group_name, similarity_name=similarity, similarity_cut_off=similarity_cut_off, topk=topk, diff --git a/lazyllm/tools/rag/doc_node.py b/lazyllm/tools/rag/doc_node.py index 2510722f..6516a3c1 100644 --- a/lazyllm/tools/rag/doc_node.py +++ b/lazyllm/tools/rag/doc_node.py @@ -30,13 +30,12 @@ def __init__(self, uid: Optional[str] = None, text: Optional[str] = None, group: self.parent: Optional["DocNode"] = parent self.children: Dict[str, List["DocNode"]] = defaultdict(list) self.is_saved: bool = False - self._docpath = None self._lock = threading.Lock() self._embedding_state = set() if fields and parent: raise ValueError('only ROOT node can set fields.') - self._fields = fields + self._fields = fields if fields else {} @property def root_node(self) -> Optional["DocNode"]: @@ -76,12 +75,12 @@ def excluded_llm_metadata_keys(self, excluded_llm_metadata_keys: List) -> None: @property def docpath(self) -> str: - return self.root_node._docpath or '' + return self.root_node._fields.get('lazyllm_doc_path', '') @docpath.setter def docpath(self, path): assert not 
self.parent, 'Only root node can set docpath' - self._docpath = str(path) + self._fields['lazyllm_doc_path'] = str(path) def get_children_str(self) -> str: return str( diff --git a/lazyllm/tools/rag/document.py b/lazyllm/tools/rag/document.py index 2030362c..6f55ff7a 100644 --- a/lazyllm/tools/rag/document.py +++ b/lazyllm/tools/rag/document.py @@ -10,7 +10,8 @@ from .doc_node import DocNode from .store_base import LAZY_ROOT_NAME, EMBED_DEFAULT_KEY from .utils import DocListManager -from .doc_field_info import DocFieldInfo +from .doc_field_desc import DocFieldDesc +from .web import DocWebModule import copy import functools @@ -22,9 +23,9 @@ def __call__(self, cls, *args, **kw): class Document(ModuleBase): class _Impl(ModuleBase): def __init__(self, dataset_path: str, embed: Optional[Union[Callable, Dict[str, Callable]]] = None, - manager: bool = False, server: bool = False, name: Optional[str] = None, + manager: Union[bool, str] = False, server: bool = False, name: Optional[str] = None, launcher: Optional[Launcher] = None, store_conf: Optional[Dict] = None, - fields_info: Optional[Dict[str, DocFieldInfo]] = None): + fields_desc: Optional[Dict[str, DocFieldDesc]] = None): super().__init__() if not os.path.exists(dataset_path): defatult_path = os.path.join(lazyllm.config["data_path"], dataset_path) @@ -39,37 +40,43 @@ def __init__(self, dataset_path: str, embed: Optional[Union[Callable, Dict[str, self._submodules.append(embed) self._dlm = DocListManager(dataset_path, name).init_tables() self._kbs = CallableDict({DocListManager.DEDAULT_GROUP_NAME: - DocImpl(embed=self._embed, dlm=self._dlm, store_conf=store_conf)}) + DocImpl(embed=self._embed, dlm=self._dlm, fields_desc=fields_desc, + store_conf=store_conf)}) if manager: self._manager = ServerModule(DocManager(self._dlm)) + if manager == 'ui': self._docweb = DocWebModule(doc_server=self._manager) if server: self._kbs = ServerModule(self._kbs) - self._fields_info = fields_info + self._fields_desc = fields_desc - def add_kb_group(self, name, store_conf: Optional[Dict] = None): + def add_kb_group(self, name, fields_desc: Optional[Dict[str, DocFieldDesc]] = None, + store_conf: Optional[Dict] = None): if isinstance(self._kbs, ServerModule): self._kbs._impl._m[name] = DocImpl(dlm=self._dlm, embed=self._embed, kb_group_name=name, - store_conf=store_conf) + fields_desc=fields_desc, store_conf=store_conf) else: self._kbs[name] = DocImpl(dlm=self._dlm, embed=self._embed, kb_group_name=name, - store_conf=store_conf) + fields_desc=fields_desc, store_conf=store_conf) self._dlm.add_kb_group(name) def get_doc_by_kb_group(self, name): return self._kbs._impl._m[name] if isinstance(self._kbs, ServerModule) else self._kbs[name] - def stop(self): self._launcher.cleanup() + def stop(self): + if hasattr(self, '_docweb'): + self._docweb.stop() + self._launcher.cleanup() def __call__(self, *args, **kw): return self._kbs(*args, **kw) def __init__(self, dataset_path: str, embed: Optional[Union[Callable, Dict[str, Callable]]] = None, - create_ui: bool = False, manager: bool = False, server: bool = False, + create_ui: bool = False, manager: Union[bool, str] = False, server: bool = False, name: Optional[str] = None, launcher: Optional[Launcher] = None, - fields_info: Dict[str, DocFieldInfo] = None, store_conf: Optional[Dict] = None): + fields_desc: Dict[str, DocFieldDesc] = None, store_conf: Optional[Dict] = None): super().__init__() if create_ui: lazyllm.LOG.warning('`create_ui` for Document is deprecated, use `manager` instead') self._impls = 
Document._Impl(dataset_path, embed, create_ui or manager, server, name, - launcher, store_conf, fields_info) + launcher, store_conf, fields_desc) self._curr_group = DocListManager.DEDAULT_GROUP_NAME def create_kb_group(self, name: str, store_conf: Optional[Dict] = None) -> "Document": diff --git a/lazyllm/tools/rag/map_store.py b/lazyllm/tools/rag/map_store.py index 36322ac6..31fa36d8 100644 --- a/lazyllm/tools/rag/map_store.py +++ b/lazyllm/tools/rag/map_store.py @@ -16,9 +16,11 @@ def _remove_from_indices(name2index: Dict[str, IndexBase], uids: List[str], index.remove(uids, group_name) class MapStore(StoreBase): - def __init__(self, embed: Dict[str, Callable], **kwargs): + def __init__(self, node_groups: List[str], embed: Dict[str, Callable], **kwargs): # Dict[group_name, Dict[uuid, DocNode]] - self._group2docs: Dict[str, Dict[str, DocNode]] = {} + self._group2docs: Dict[str, Dict[str, DocNode]] = { + group: {} for group in node_groups + } self._name2index = { 'default': DefaultIndex(embed, self), @@ -72,10 +74,6 @@ def is_group_active(self, name: str) -> bool: def all_groups(self) -> List[str]: return self._group2docs.keys() - @override - def add_group(self, name: str, embed_keys: Optional[List[str]] = None) -> None: - self._group2docs.setdefault(name, {}) - @override def query(self, *args, **kwargs) -> List[DocNode]: return self.get_index('default').query(*args, **kwargs) diff --git a/lazyllm/tools/rag/milvus_store.py b/lazyllm/tools/rag/milvus_store.py index 5147140a..10343e6a 100644 --- a/lazyllm/tools/rag/milvus_store.py +++ b/lazyllm/tools/rag/milvus_store.py @@ -7,8 +7,10 @@ from .utils import parallel_do_embedding from .index_base import IndexBase from .store_base import StoreBase -from .doc_field_info import DocFieldInfo +from .doc_field_desc import DocFieldDesc from lazyllm.common import override +import pickle +import base64 class MilvusStore(StoreBase): _primary_key = 'uid' @@ -22,34 +24,71 @@ class MilvusStore(StoreBase): 'max_length': 256, 'is_primary': True, }, + 'parent': { + 'datatype': pymilvus.DataType.VARCHAR, + 'max_length': 256, + }, 'text': { 'datatype': pymilvus.DataType.VARCHAR, - 'max_length': True, + 'max_length': 65535, }, - 'parent': { + 'metadata': { 'datatype': pymilvus.DataType.VARCHAR, - 'max_length': 256, + 'max_length': 65535, }, } _type2milvus = [ - 0, pymilvus.DataType.VARCHAR, ] - def __init__(self, embed: Dict[str, Callable], fields_info: Dict[str, DocFieldInfo], uri: str, - embedding_index_type: Optional[str] = None, embedding_metric_type: Optional[str] = None, - **kwargs): + def __init__(self, node_groups: List[str], embed: Dict[str, Callable], embed_keys: List[str], + fields_desc: Dict[str, DocFieldDesc], uri: str, embedding_index_type: Optional[str] = None, + embedding_metric_type: Optional[str] = None, **kwargs): self._embed = embed - self._fields_info = fields_info - self._embedding_index_type = embedding_index_type if embedding_index_type else 'HNSW' - self._embedding_metric_type = embedding_metric_type if embedding_metric_type else 'COSINE' - - self._embedding_keys = embed.keys() - self._embed_dim = {k: len(e('a')) for k, e in embed.items()} self._client = MilvusClient(uri=uri) - self._map_store = MapStore(embed=embed) + if not embedding_index_type: + embedding_index_type = 'HNSW' + + if not embedding_metric_type: + embedding_metric_type = 'COSINE' + + embed_dims = {} + for k in embed_keys: + e = embed.get(k) + if not e: + raise ValueError(f'cannot find embed callable [{k}]') + embed_dims[k] = len(e('a')) + + for group in node_groups: + 
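            # (reader aid) Each node group gets its own Milvus collection:
            # the builtin uid/parent/text/metadata fields declared above, one
            # FLOAT_VECTOR field per embed key (indexed with
            # embedding_index_type / embedding_metric_type, which default to
            # 'HNSW' / 'COSINE'), and one scalar field per entry in
            # fields_desc.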
index_params = self._client.prepare_index_params() + schema = self._client.create_schema(auto_id=False, enable_dynamic_field=False) + + for key, info in self._builtin_fields.items(): + schema.add_field(field_name=key, **info) + + for key in embed_keys: + dim = embed_dims.get(key) + if not dim: + raise ValueError(f'cannot find embedding dim of embed [{key}]') + + field_name = self._gen_embedding_key(key) + schema.add_field(field_name=field_name, datatype=pymilvus.DataType.FLOAT_VECTOR, + dim=dim) + index_params.add_index(field_name=field_name, index_type=embedding_index_type, + metric_type=embedding_metric_type) + + if fields_desc: + for key, info in fields_desc.items(): + schema.add_field(field_name=self._gen_field_key(key), + datatype=self._type2milvus[info.data_type], + max_length=info.max_length) + + self._client.create_collection(collection_name=group, schema=schema, + index_params=index_params) + + self._map_store = MapStore(node_groups=node_groups, embed=embed) self._load_all_nodes_to(self._map_store) @override @@ -83,30 +122,6 @@ def is_group_active(self, name: str) -> bool: def all_groups(self) -> List[str]: return self._map_store.all_groups() - @override - def add_group(self, name: str, embed_keys: Optional[List[str]] = None) -> None: - if name in self._client.list_collections(): - return - - index_params = self._client.prepare_index_params() - schema = self._client.create_schema(auto_id=False, enable_dynamic_field=False) - - for key in embed_keys: - field_name = self._gen_embedding_key(key) - schema.add_field(field_name=field_name, datatype=pymilvus.DataType.FLOAT_VECTOR) - index_params.add_index(field_name=field_name, index_type=self._embedding_index_type, - metric_type=self._embedding_metric_type) - - if self._fields_info: - for key, info in self._fields_info.items(): - schema.add_field(field_name=self._gen_field_key(key), - datatype=self._type2milvus[info.data_type]) - - self._client.create_collection(collection_name=name, schema=schema, - index_params=index_params) - - self._map_store.add_group(name, embed_keys) - @override def register_index(self, type: str, index: IndexBase) -> None: self._map_store.register_index(type, index) @@ -156,7 +171,7 @@ def _gen_field_key(cls, k: str) -> str: def _load_all_nodes_to(self, store: StoreBase): for group_name in self._client.list_collections(): - store.add_group(name=group_name, embed=self._embed) + store.activate_group(name=group_name, embed=self._embed) results = self._client.query(collection_name=group_name, filter=f'{self._primary_key} != ""') @@ -179,13 +194,15 @@ def _serialize_node_partial(self, node: DocNode) -> Dict: 'uid': node.uid, 'text': node.text, 'parent': node.parent.uid if node.parent else '', - 'metadata': node._metadata, + 'metadata': base64.b64encode(pickle.dumps(node._metadata)).decode('utf-8'), } for k, v in node.embedding.items(): res[self._gen_embedding_key(k)] = v - for k, v in node.fields.items(): - res[self._gen_field_key(k)] = v + + if node.parent and node.fields: + for k, v in node.fields.items(): + res[self._gen_field_key(k)] = v return res @@ -196,7 +213,7 @@ def _deserialize_node_partial(self, result: Dict) -> DocNode: uid=record.pop('uid'), text=record.pop('text'), parent=record.pop('parent'), # this is the parent's uid - metadata=record.pop('metadata'), + metadata=pickle.loads(base64.b64decode(record.pop('metadata').encode('utf-8'))), ) for k, v in record.items(): diff --git a/lazyllm/tools/rag/retriever.py b/lazyllm/tools/rag/retriever.py index 0bfee49a..d3568d5d 100644 --- 
a/lazyllm/tools/rag/retriever.py +++ b/lazyllm/tools/rag/retriever.py @@ -49,6 +49,8 @@ def __init__( for doc in self._docs: assert isinstance(doc, Document), 'Only Document or List[Document] are supported' self._submodules.append(doc) + if embed_keys: + doc._activated_embeddings.insert(embed_keys) self._group_name = group_name self._similarity = similarity # similarity function str diff --git a/lazyllm/tools/rag/store_base.py b/lazyllm/tools/rag/store_base.py index c40cb729..08992ad9 100644 --- a/lazyllm/tools/rag/store_base.py +++ b/lazyllm/tools/rag/store_base.py @@ -31,10 +31,6 @@ def is_group_active(self, name: str) -> bool: def all_groups(self) -> List[str]: pass - @abstractmethod - def add_group(self, name: str, embed_keys: Optional[List[str]] = None) -> None: - pass - @abstractmethod def query(self, *args, **kwargs) -> List[DocNode]: pass diff --git a/lazyllm/tools/rag/web.py b/lazyllm/tools/rag/web.py index aabbd126..fb4b75cc 100644 --- a/lazyllm/tools/rag/web.py +++ b/lazyllm/tools/rag/web.py @@ -1,10 +1,11 @@ import os import socket import requests -import multiprocessing import json +from typing import Union import lazyllm +from lazyllm import LOG from lazyllm import ModuleBase, ServerModule import gradio as gr from lazyllm.flow import Pipeline @@ -51,7 +52,7 @@ def delete_group(self, group_name: str): def list_groups(self): response = requests.get( - f"{self.base_url}/list_groups", headers=self.basic_headers(False) + f"{self.base_url}/list_kb_groups", headers=self.basic_headers(False) ) return response.json()["data"] @@ -62,38 +63,34 @@ def upload_files(self, group_name: str, override: bool = True): ) return response.json()["data"] - def list_files(self, group_name: str): + def list_files_in_group(self, group_name: str): response = requests.get( - f"{self.base_url}/list_files?group_name={group_name}", + f"{self.base_url}/list_files_in_group?group_name={group_name}&alive=True", headers=self.basic_headers(False), ) return response.json()["data"] - def delete_file(self, group_name: str, file_name: str): + def delete_file(self, group_name: str, file_ids: list[str]): response = requests.post( - f"{self.base_url}/delete_file?group_name={group_name}&file_name={file_name}", + f"{self.base_url}/delete_files_from_group", headers=self.basic_headers(True), + json={"group_name": group_name, "file_ids": file_ids} ) return response.json()["msg"] - def gr_show_list(self, str_list: list[str], list_name: str): - return gr.DataFrame( - headers=["index", list_name], - value=[[index, str_list[index]] for index in range(len(str_list))], - ) + def gr_show_list(self, str_list: list, list_name: Union[str, list]): + if isinstance(list_name, str): + headers = ["index", list_name] + value = [[index, str_list[index]] for index in range(len(str_list))] + else: + headers = ["index"] + list_name + value = [[index] + str_list[index:index + len(list_name)] for index in range(len(str_list))] + return gr.DataFrame(headers=headers, value=value) def create_ui(self): - with gr.Blocks() as demo: + with gr.Blocks(analytics_enabled=False) as demo: with gr.Tabs(): select_group_list = [] - with gr.TabItem("创建分组"): - create_group_text = gr.Textbox(label="分组名称:") - create_group_btn = gr.Button("创建") - - with gr.TabItem("删除分组"): - del_select_group = gr.Dropdown(self.list_groups(), label="选择分组") - delete_group_btn = gr.Button("删除") - select_group_list.append(del_select_group) with gr.TabItem("分组列表"): select_group = self.gr_show_list( @@ -110,13 +107,13 @@ def _upload_files(group_name, files): for file in files ] - url = 
f"{self.base_url}/upload_files?group_name={group_name}&override=true" + url = f"{self.base_url}/add_files_to_group?group_name={group_name}&override=true" response = requests.post( url, files=files_to_upload, headers=self.muti_headers() ) response.raise_for_status() response_data = response.json() - gr.Info(str(response_data["data"])) + gr.Info(str(response_data["msg"])) for _, (_, file_obj) in files_to_upload: file_obj.close() @@ -134,19 +131,16 @@ def _upload_files(group_name, files): select_group_list.append(select_group) - with gr.TabItem("文件列表"): - + with gr.TabItem("分组文件列表"): def _list_group_files(group_name): - file_list = self.list_files(group_name) + file_list = self.list_files_in_group(group_name) + values = [[i] + file_list[i][:2] for i in range(len(file_list))] return gr.update( - value=[ - [index, file_list[index]] - for index in range(len(file_list)) - ] + value=values ) select_group = gr.Dropdown(self.list_groups(), label="选择分组") - show_list = self.gr_show_list([], list_name="file_name") + show_list = self.gr_show_list([], list_name=["file_id", "file_name"]) select_group.change( fn=_list_group_files, inputs=select_group, outputs=show_list ) @@ -155,7 +149,8 @@ def _list_group_files(group_name): with gr.TabItem("删除文件"): def _list_group_files(group_name): - file_list = self.list_files(group_name) + file_list = self.list_files_in_group(group_name) + file_list = [','.join(file[:2]) for file in file_list] return gr.update(choices=file_list) select_group = gr.Dropdown(self.list_groups(), label="选择分组") @@ -165,8 +160,9 @@ def _list_group_files(group_name): ) delete_btn = gr.Button("删除") - def _delete_file(group_name, file_name): - gr.Info(self.delete_file(group_name, file_name)) + def _delete_file(group_name, select_file): + file_ids = [select_file.split(',')[0]] + gr.Info(self.delete_file(group_name, file_ids)) return _list_group_files(group_name) delete_btn.click( @@ -176,50 +172,6 @@ def _delete_file(group_name, file_name): ) select_group_list.append(select_group) - def _create_group(group_name): - gr.Info(self.new_group(group_name)) - curt_groups = self.list_groups() - return [ - ( - gr.update(choices=curt_groups) - if isinstance(select, gr.Dropdown) - else gr.update( - value=[ - [index, curt_groups[index]] - for index in range(len(curt_groups)) - ] - ) - ) - for select in select_group_list - ] - - def _del_group(group_name): - gr.Info(self.delete_group(group_name)) - curt_groups = self.list_groups() - return [ - ( - gr.update(choices=curt_groups) - if isinstance(select, gr.Dropdown) - else gr.update( - value=[ - [index, curt_groups[index]] - for index in range(len(curt_groups)) - ] - ) - ) - for select in select_group_list - ] - - create_group_btn.click( - fn=_create_group, - inputs=create_group_text, - outputs=select_group_list, - ) - - delete_group_btn.click( - fn=_del_group, inputs=del_select_group, outputs=select_group_list - ) - return demo @@ -264,15 +216,13 @@ def _work(self): port = self.port assert self._verify_port_access(port), f"port {port} is occupied" - def _impl(): - self.demo.queue().launch(server_name="0.0.0.0", server_port=port) - self.api_url = self.doc_server._url.rsplit("/", 1)[0] self.web_ui = WebUi(self.api_url) self.demo = self.web_ui.create_ui() - self.p = multiprocessing.Process(target=_impl) - self.p.start() - self.url = f"http://0.0.0.0:{port}" + self.url = f'http://0.0.0.0:{port}' + + self.demo.queue().launch(server_name="0.0.0.0", server_port=port, prevent_thread_lock=True) + LOG.success(f'LazyLLM docwebmodule launched successfully: Running on 
local URL: {self.url}', flush=True) def _get_deploy_tasks(self): return Pipeline(self._work) @@ -281,7 +231,13 @@ def _get_post_process_tasks(self): return Pipeline(self._print_url) def wait(self): - return self.p.join() + self.demo.block_thread() + + def stop(self): + if self.demo: + self.demo.close() + del self.demo + self.demo, self.url = None, '' def _find_can_use_network_port(self): for port in self.port: diff --git a/tests/basic_tests/test_document.py b/tests/basic_tests/test_document.py index 0ba38050..11a33585 100644 --- a/tests/basic_tests/test_document.py +++ b/tests/basic_tests/test_document.py @@ -153,6 +153,17 @@ def test_multi_embedding_with_document(self): nodes3 = retriever3("何为天道?") assert len(nodes3) == 3 + def test_doc_web_module(self): + import time + import requests + doc = Document('rag_master', manager='ui') + doc.create_kb_group(name="test_group") + doc.start() + time.sleep(4) + url = doc._impls._docweb.url + response = requests.get(url) + assert response.status_code == 200 + doc.stop() class TestFileNodeIndex(unittest.TestCase): def setUp(self): diff --git a/tests/basic_tests/test_store.py b/tests/basic_tests/test_store.py index eb4b5780..47c78943 100644 --- a/tests/basic_tests/test_store.py +++ b/tests/basic_tests/test_store.py @@ -8,7 +8,7 @@ from lazyllm.tools.rag.chroma_store import ChromadbStore from lazyllm.tools.rag.milvus_store import MilvusStore from lazyllm.tools.rag.doc_node import DocNode -from lazyllm.tools.rag.doc_field_info import DocFieldInfo +from lazyllm.tools.rag.doc_field_desc import DocFieldDesc def clear_directory(directory_path): @@ -37,7 +37,7 @@ def setUp(self): self.store = ChromadbStore(dir=self.store_dir, embed=self.mock_embed, embed_dim=self.embed_dim) for group in self.node_groups: - self.store.add_group(name=group, embed_keys=self.mock_embed.keys()) + self.store.activate_group(name=group, embed_keys=self.mock_embed.keys()) self.store.update_nodes( [DocNode(uid="1", text="text1", group=LAZY_ROOT_NAME, parent=None)], @@ -184,21 +184,21 @@ def setUp(self): 'vec1': MagicMock(return_value=[1.0, 2.0, 3.0]), 'vec2': MagicMock(return_value=[400.0, 500.0, 600.0, 700.0, 800.0]), } - self.fields_info = { - 'comment': DocFieldInfo(DocFieldInfo.DTYPE_VARCHAR), + self.fields_desc = { + 'comment': DocFieldDesc(data_type=DocFieldDesc.DTYPE_VARCHAR), } self.node_groups = [LAZY_ROOT_NAME, "group1", "group2"] _, self.store_file = tempfile.mkstemp(suffix=".db") - self.store = MilvusStore(embed=self.mock_embed, fields_info=self.fields_info, + self.store = MilvusStore(embed=self.mock_embed, fields_desc=self.fields_desc, uri=self.store_file) for group in self.node_groups: - self.store.add_group(name=group, embed_keys=self.mock_embed.keys()) + self.store.activate_group(name=group, embed_keys=self.mock_embed.keys()) self.node1 = DocNode(uid="1", text="text1", group="group1", parent=None, embedding={"vec1": [8.0, 9.0, 10.0], "vec2": [11.0, 12.0, 13.0, 14.0, 15.0]}, - metadata={'comment': 'comment1'}) + metadata={'comment': 'comment1'}, fields={'comment': 'comment3'}) self.node2 = DocNode(uid="2", text="text2", group="group1", parent=self.node1, embedding={"vec1": [100.0, 200.0, 300.0], "vec2": [400.0, 500.0, 600.0, 700.0, 800.0]}, metadata={'comment': 'comment2'}) From f0cdd57cecde1c53c06082a7c0b1964b2e88ad88 Mon Sep 17 00:00:00 2001 From: ouguoyu Date: Wed, 6 Nov 2024 09:57:57 +0800 Subject: [PATCH 56/60] s --- lazyllm/tools/rag/milvus_store.py | 7 ++----- tests/basic_tests/test_store.py | 7 +++---- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git 
a/lazyllm/tools/rag/milvus_store.py b/lazyllm/tools/rag/milvus_store.py index 10343e6a..451f78bd 100644 --- a/lazyllm/tools/rag/milvus_store.py +++ b/lazyllm/tools/rag/milvus_store.py @@ -171,8 +171,6 @@ def _gen_field_key(cls, k: str) -> str: def _load_all_nodes_to(self, store: StoreBase): for group_name in self._client.list_collections(): - store.activate_group(name=group_name, embed=self._embed) - results = self._client.query(collection_name=group_name, filter=f'{self._primary_key} != ""') for result in results: @@ -200,9 +198,8 @@ def _serialize_node_partial(self, node: DocNode) -> Dict: for k, v in node.embedding.items(): res[self._gen_embedding_key(k)] = v - if node.parent and node.fields: - for k, v in node.fields.items(): - res[self._gen_field_key(k)] = v + for k, v in node.fields.items(): + res[self._gen_field_key(k)] = v return res diff --git a/tests/basic_tests/test_store.py b/tests/basic_tests/test_store.py index 47c78943..67d032ea 100644 --- a/tests/basic_tests/test_store.py +++ b/tests/basic_tests/test_store.py @@ -191,10 +191,9 @@ def setUp(self): self.node_groups = [LAZY_ROOT_NAME, "group1", "group2"] _, self.store_file = tempfile.mkstemp(suffix=".db") - self.store = MilvusStore(embed=self.mock_embed, fields_desc=self.fields_desc, - uri=self.store_file) - for group in self.node_groups: - self.store.activate_group(name=group, embed_keys=self.mock_embed.keys()) + self.store = MilvusStore(node_groups=self.node_groups, embed=self.mock_embed, + embed_keys=self.mock_embed.keys(), + fields_desc=self.fields_desc, uri=self.store_file) self.node1 = DocNode(uid="1", text="text1", group="group1", parent=None, embedding={"vec1": [8.0, 9.0, 10.0], "vec2": [11.0, 12.0, 13.0, 14.0, 15.0]}, From 581d30b216037202d398b09b0f2eca4b51ba557c Mon Sep 17 00:00:00 2001 From: ouguoyu Date: Wed, 6 Nov 2024 10:42:39 +0800 Subject: [PATCH 57/60] review3 --- lazyllm/tools/rag/chroma_store.py | 16 +++++++++------- lazyllm/tools/rag/doc_impl.py | 2 +- tests/basic_tests/test_document.py | 1 + tests/basic_tests/test_store.py | 9 ++++----- 4 files changed, 15 insertions(+), 13 deletions(-) diff --git a/lazyllm/tools/rag/chroma_store.py b/lazyllm/tools/rag/chroma_store.py index 9422890f..77770e86 100644 --- a/lazyllm/tools/rag/chroma_store.py +++ b/lazyllm/tools/rag/chroma_store.py @@ -10,6 +10,7 @@ from .default_index import DefaultIndex from .map_store import MapStore import pickle +import base64 # ---------------------------------------------------------------------------- # @@ -23,14 +24,14 @@ def __init__(self, dir: str, node_groups: List[str], embed: Dict[str, Callable], for group in node_groups } + self._map_store = MapStore(node_groups=node_groups, embed=embed) + self._load_store(embed_dim) + self._name2index = { 'default': DefaultIndex(embed, self._map_store), 'file_node_map': _FileNodeIndex(), } - self._map_store = MapStore(node_groups=node_groups, embed=embed) - self._load_store(embed_dim) - @override def update_nodes(self, nodes: List[DocNode]) -> None: self._map_store.update_nodes(nodes) @@ -132,13 +133,14 @@ def _build_nodes_from_chroma(self, results: Dict[str, List], embed_dim: Dict[str chroma_metadata = results['metadatas'][i] parent = chroma_metadata['parent'] - fields = pickle.loads(chroma_metadata['fields']) if parent else None + fields = pickle.loads(base64.b64decode(chroma_metadata['fields'].encode('utf-8')))\ + if parent else None node = DocNode( uid=uid, text=results["documents"][i], group=chroma_metadata["group"], - embedding=pickle.loads(chroma_metadata['embedding']), + 
embedding=pickle.loads(base64.b64decode(chroma_metadata['embedding'].encode('utf-8'))), parent=parent, fields=fields, ) @@ -167,11 +169,11 @@ def _make_chroma_metadata(self, node: DocNode) -> Dict[str, Any]: metadata = { "group": node.group, "parent": node.parent.uid if node.parent else "", - "embedding": pickle.dumps(node.embedding), + "embedding": base64.b64encode(pickle.dumps(node.embedding)).decode('utf-8'), } if node.parent: - metadata["fields"] = pickle.dumps(node.fields) + metadata["fields"] = base64.b64encode(pickle.dumps(node.fields)).decode('utf-8') return metadata diff --git a/lazyllm/tools/rag/doc_impl.py b/lazyllm/tools/rag/doc_impl.py index 5252ea8c..30015d1b 100644 --- a/lazyllm/tools/rag/doc_impl.py +++ b/lazyllm/tools/rag/doc_impl.py @@ -219,7 +219,7 @@ def _add_files(self, input_files: List[str]): if len(input_files) == 0: return root_nodes = self._reader.load_data(input_files) - temp_store = self._create_store("map") + temp_store = self._create_store({"type": "map"}) temp_store.update_nodes(root_nodes) all_groups = self.store.all_groups() LOG.info(f"add_files: Trying to merge store with {all_groups}") diff --git a/tests/basic_tests/test_document.py b/tests/basic_tests/test_document.py index 11a33585..141df678 100644 --- a/tests/basic_tests/test_document.py +++ b/tests/basic_tests/test_document.py @@ -69,6 +69,7 @@ def test_add_files(self): assert len(self.doc_impl.store.get_nodes(LAZY_ROOT_NAME)) == 2 def test_delete_files(self): + self.doc_impl._lazy_init() self.doc_impl._delete_files(["dummy_file.txt"]) assert len(self.doc_impl.store.get_nodes(LAZY_ROOT_NAME)) == 0 diff --git a/tests/basic_tests/test_store.py b/tests/basic_tests/test_store.py index 67d032ea..ab78bfe6 100644 --- a/tests/basic_tests/test_store.py +++ b/tests/basic_tests/test_store.py @@ -35,9 +35,8 @@ def setUp(self): } self.embed_dim = {"default": 3} - self.store = ChromadbStore(dir=self.store_dir, embed=self.mock_embed, embed_dim=self.embed_dim) - for group in self.node_groups: - self.store.activate_group(name=group, embed_keys=self.mock_embed.keys()) + self.store = ChromadbStore(dir=self.store_dir, node_groups=self.node_groups, + embed=self.mock_embed, embed_dim=self.embed_dim) self.store.update_nodes( [DocNode(uid="1", text="text1", group=LAZY_ROOT_NAME, parent=None)], @@ -75,7 +74,7 @@ def test_load_store(self): # Reset store and load from "persistent" storage self.store._map_store._group2docs = {group: {} for group in self.node_groups} - self.store._load_store() + self.store._load_store(self.embed_dim) nodes = self.store.get_nodes("group1") self.assertEqual(len(nodes), 2) @@ -93,7 +92,7 @@ def test_insert_dict_as_sparse_embedding(self): self.store.update_nodes([node1, node2]) results = self.store._peek_all_documents('group1') - nodes = self.store._build_nodes_from_chroma(results) + nodes = self.store._build_nodes_from_chroma(results, self.embed_dim) nodes_dict = { node.uid: node for node in nodes } From 3f3257fa1dc1bf4aae472802b2e87fb70bcc0d2a Mon Sep 17 00:00:00 2001 From: ouguoyu Date: Wed, 6 Nov 2024 15:14:36 +0800 Subject: [PATCH 58/60] review4 --- lazyllm/tools/rag/chroma_store.py | 17 +++++++------- lazyllm/tools/rag/doc_field_desc.py | 2 +- lazyllm/tools/rag/doc_impl.py | 21 +++++++++-------- lazyllm/tools/rag/map_store.py | 4 ++-- lazyllm/tools/rag/milvus_store.py | 18 +++++---------- lazyllm/tools/rag/retriever.py | 2 +- tests/basic_tests/test_store.py | 35 ++++++++++++++++++++--------- 7 files changed, 56 insertions(+), 43 deletions(-) diff --git a/lazyllm/tools/rag/chroma_store.py 
b/lazyllm/tools/rag/chroma_store.py index 77770e86..0002f870 100644 --- a/lazyllm/tools/rag/chroma_store.py +++ b/lazyllm/tools/rag/chroma_store.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Callable +from typing import Any, Dict, List, Optional, Callable, Set import chromadb from lazyllm import LOG from lazyllm.common import override @@ -15,17 +15,18 @@ # ---------------------------------------------------------------------------- # class ChromadbStore(StoreBase): - def __init__(self, dir: str, node_groups: List[str], embed: Dict[str, Callable], - embed_dim: Dict[str, int], **kwargs) -> None: + def __init__(self, group_embed_keys: Dict[str, Set[str]], embed: Dict[str, Callable], + embed_dims: Dict[str, int], dir: str, **kwargs) -> None: self._db_client = chromadb.PersistentClient(path=dir) LOG.success(f"Initialzed chromadb in path: {dir}") + node_groups = list(group_embed_keys.keys()) self._collections: Dict[str, Collection] = { group: self._db_client.get_or_create_collection(group) for group in node_groups } self._map_store = MapStore(node_groups=node_groups, embed=embed) - self._load_store(embed_dim) + self._load_store(embed_dims) self._name2index = { 'default': DefaultIndex(embed, self._map_store), @@ -71,7 +72,7 @@ def get_index(self, type: Optional[str] = None) -> Optional[IndexBase]: type = 'default' return self._name2index.get(type) - def _load_store(self, embed_dim: Dict[str, int]) -> None: + def _load_store(self, embed_dims: Dict[str, int]) -> None: if not self._collections[LAZY_ROOT_NAME].peek(1)["ids"]: LOG.info("No persistent data found, skip the rebuilding phrase.") return @@ -79,7 +80,7 @@ def _load_store(self, embed_dim: Dict[str, int]) -> None: # Restore all nodes for group in self._collections.keys(): results = self._peek_all_documents(group) - nodes = self._build_nodes_from_chroma(results, embed_dim) + nodes = self._build_nodes_from_chroma(results, embed_dims) self._map_store.update_nodes(nodes) # Rebuild relationships @@ -127,7 +128,7 @@ def _delete_group_nodes(self, group_name: str, uids: List[str]) -> None: if collection: collection.delete(ids=uids) - def _build_nodes_from_chroma(self, results: Dict[str, List], embed_dim: Dict[str, int]) -> List[DocNode]: + def _build_nodes_from_chroma(self, results: Dict[str, List], embed_dims: Dict[str, int]) -> List[DocNode]: nodes: List[DocNode] = [] for i, uid in enumerate(results['ids']): chroma_metadata = results['metadatas'][i] @@ -150,7 +151,7 @@ def _build_nodes_from_chroma(self, results: Dict[str, List], embed_dim: Dict[str new_embedding_dict = {} for key, embedding in node.embedding.items(): if isinstance(embedding, dict): - dim = embed_dim.get(key) + dim = embed_dims.get(key) if not dim: raise ValueError(f'dim of embed [{key}] is not determined.') new_embedding = [0] * dim diff --git a/lazyllm/tools/rag/doc_field_desc.py b/lazyllm/tools/rag/doc_field_desc.py index ffb416db..fc71d9a1 100644 --- a/lazyllm/tools/rag/doc_field_desc.py +++ b/lazyllm/tools/rag/doc_field_desc.py @@ -3,6 +3,6 @@ class DocFieldDesc: DTYPE_VARCHAR = 0 - def __init__(self, data_type: int = DTYPE_VARCHAR, max_length: Optional[int] = 65535): + def __init__(self, data_type: int, max_length: Optional[int] = None): self.data_type = data_type self.max_length = max_length diff --git a/lazyllm/tools/rag/doc_impl.py b/lazyllm/tools/rag/doc_impl.py index 30015d1b..8a8d6de7 100644 --- a/lazyllm/tools/rag/doc_impl.py +++ b/lazyllm/tools/rag/doc_impl.py @@ -47,10 +47,10 @@ def __init__(self, embed: Dict[str, Callable], dlm: 
Optional[DocListManager] = N self._reader = DirectoryReader(None, self._local_file_reader, DocImpl._registered_file_reader) self.node_groups: Dict[str, Dict] = {LAZY_ROOT_NAME: {}} self.embed = {k: embed_wrapper(e) for k, e in embed.items()} - self._embed_dim = None + self._embed_dims = None self._fields_desc = fields_desc self.store = store_conf # NOTE: will be initialized in _lazy_init() - self._activated_embeddings = set() + self._activated_embeddings = {} @once_wrapper(reset_on_pickle=True) def _lazy_init(self) -> None: @@ -59,7 +59,11 @@ def _lazy_init(self) -> None: node_groups.update(self.node_groups) self.node_groups = node_groups - self._embed_dim = {k: len(e('a')) for k, e in self.embed.items()} + # set empty embed keys for groups that are not visited by Retriever + for group in node_groups.keys(): + self._activated_embeddings.setdefault(group, set()) + + self._embed_dims = {k: len(e('a')) for k, e in self.embed.items()} if self.store is None: self.store = { @@ -94,14 +98,13 @@ def _create_store(self, store_conf: Optional[Dict]) -> StoreBase: raise ValueError('`kwargs` in store conf is not a dict.') if store_type == "map": - store = MapStore(node_groups=self.node_groups, embed=self.embed, **kwargs) + store = MapStore(node_groups=self._activated_embeddings.keys(), embed=self.embed, **kwargs) elif store_type == "chroma": - store = ChromadbStore(node_groups=self.node_groups, embed=self.embed, - embed_dim=self.embed_dim, **kwargs) + store = ChromadbStore(group_embed_keys=self._activated_embeddings, embed=self.embed, + embed_dims=self._embed_dims, **kwargs) elif store_type == "milvus": - store = MilvusStore(node_groups=self.node_groups, embed=self.embed, - embed_keys=self._activated_embeddings, - fields_desc=self._fields_desc, **kwargs) + store = MilvusStore(group_embed_keys=self._activated_embeddings, embed=self.embed, + embed_dims=self.embed_dims, fields_desc=self._fields_desc, **kwargs) else: raise NotImplementedError( f"Not implemented store type for {store_type}" diff --git a/lazyllm/tools/rag/map_store.py b/lazyllm/tools/rag/map_store.py index 31fa36d8..f208ae25 100644 --- a/lazyllm/tools/rag/map_store.py +++ b/lazyllm/tools/rag/map_store.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, Callable +from typing import Dict, List, Optional, Callable, Union, Set from .index_base import IndexBase from .store_base import StoreBase from .doc_node import DocNode @@ -16,7 +16,7 @@ def _remove_from_indices(name2index: Dict[str, IndexBase], uids: List[str], index.remove(uids, group_name) class MapStore(StoreBase): - def __init__(self, node_groups: List[str], embed: Dict[str, Callable], **kwargs): + def __init__(self, node_groups: Union[List[str], Set[str]], embed: Dict[str, Callable], **kwargs): # Dict[group_name, Dict[uuid, DocNode]] self._group2docs: Dict[str, Dict[str, DocNode]] = { group: {} for group in node_groups diff --git a/lazyllm/tools/rag/milvus_store.py b/lazyllm/tools/rag/milvus_store.py index 451f78bd..b9e9fbaf 100644 --- a/lazyllm/tools/rag/milvus_store.py +++ b/lazyllm/tools/rag/milvus_store.py @@ -1,5 +1,5 @@ import copy -from typing import Dict, List, Optional, Union, Callable +from typing import Dict, List, Optional, Union, Callable, Set import pymilvus from pymilvus import MilvusClient from .doc_node import DocNode @@ -42,8 +42,9 @@ class MilvusStore(StoreBase): pymilvus.DataType.VARCHAR, ] - def __init__(self, node_groups: List[str], embed: Dict[str, Callable], embed_keys: List[str], - fields_desc: Dict[str, DocFieldDesc], uri: str, embedding_index_type: 
Optional[str] = None, + def __init__(self, group_embed_keys: Dict[str, Set[str]], embed: Dict[str, Callable], + embed_dims: Dict[str, int], fields_desc: Dict[str, DocFieldDesc], + uri: str, embedding_index_type: Optional[str] = None, embedding_metric_type: Optional[str] = None, **kwargs): self._embed = embed self._client = MilvusClient(uri=uri) @@ -54,14 +55,7 @@ def __init__(self, node_groups: List[str], embed: Dict[str, Callable], embed_key if not embedding_metric_type: embedding_metric_type = 'COSINE' - embed_dims = {} - for k in embed_keys: - e = embed.get(k) - if not e: - raise ValueError(f'cannot find embed callable [{k}]') - embed_dims[k] = len(e('a')) - - for group in node_groups: + for group, embed_keys in group_embed_keys.items(): index_params = self._client.prepare_index_params() schema = self._client.create_schema(auto_id=False, enable_dynamic_field=False) @@ -88,7 +82,7 @@ def __init__(self, node_groups: List[str], embed: Dict[str, Callable], embed_key self._client.create_collection(collection_name=group, schema=schema, index_params=index_params) - self._map_store = MapStore(node_groups=node_groups, embed=embed) + self._map_store = MapStore(node_groups=list(group_embed_keys.keys()), embed=embed) self._load_all_nodes_to(self._map_store) @override diff --git a/lazyllm/tools/rag/retriever.py b/lazyllm/tools/rag/retriever.py index d3568d5d..d919866d 100644 --- a/lazyllm/tools/rag/retriever.py +++ b/lazyllm/tools/rag/retriever.py @@ -50,7 +50,7 @@ def __init__( assert isinstance(doc, Document), 'Only Document or List[Document] are supported' self._submodules.append(doc) if embed_keys: - doc._activated_embeddings.insert(embed_keys) + doc._activated_embeddings.setdefault(group_name, set()).insert(embed_keys) self._group_name = group_name self._similarity = similarity # similarity function str diff --git a/tests/basic_tests/test_store.py b/tests/basic_tests/test_store.py index ab78bfe6..a7ef1055 100644 --- a/tests/basic_tests/test_store.py +++ b/tests/basic_tests/test_store.py @@ -33,10 +33,16 @@ def setUp(self): self.mock_embed = { 'default': MagicMock(return_value=[1.0, 2.0, 3.0]), } - self.embed_dim = {"default": 3} + self.embed_dims = {"default": 3} - self.store = ChromadbStore(dir=self.store_dir, node_groups=self.node_groups, - embed=self.mock_embed, embed_dim=self.embed_dim) + embed_keys = set(['default']) + group_embed_keys = { + LAZY_ROOT_NAME: embed_keys, + 'group1': embed_keys, + 'group2': embed_keys, + } + self.store = ChromadbStore(group_embed_keys=group_embed_keys, embed=self.mock_embed, + embed_dims=self.embed_dims, dir=self.store_dir) self.store.update_nodes( [DocNode(uid="1", text="text1", group=LAZY_ROOT_NAME, parent=None)], @@ -74,7 +80,7 @@ def test_load_store(self): # Reset store and load from "persistent" storage self.store._map_store._group2docs = {group: {} for group in self.node_groups} - self.store._load_store(self.embed_dim) + self.store._load_store(self.embed_dims) nodes = self.store.get_nodes("group1") self.assertEqual(len(nodes), 2) @@ -92,7 +98,7 @@ def test_insert_dict_as_sparse_embedding(self): self.store.update_nodes([node1, node2]) results = self.store._peek_all_documents('group1') - nodes = self.store._build_nodes_from_chroma(results, self.embed_dim) + nodes = self.store._build_nodes_from_chroma(results, self.embed_dims) nodes_dict = { node.uid: node for node in nodes } @@ -124,7 +130,6 @@ def setUp(self): self.mock_embed = { 'default': MagicMock(return_value=[1.0, 2.0, 3.0]), } - self.embed_dim = {"default": 3} self.node_groups = [LAZY_ROOT_NAME, 
"group1", "group2"] self.store = MapStore(node_groups=self.node_groups, embed=self.mock_embed) self.node1 = DocNode(uid="1", text="text1", group="group1", parent=None) @@ -184,15 +189,25 @@ def setUp(self): 'vec2': MagicMock(return_value=[400.0, 500.0, 600.0, 700.0, 800.0]), } self.fields_desc = { - 'comment': DocFieldDesc(data_type=DocFieldDesc.DTYPE_VARCHAR), + 'comment': DocFieldDesc(data_type=DocFieldDesc.DTYPE_VARCHAR, max_length=65535), } self.node_groups = [LAZY_ROOT_NAME, "group1", "group2"] _, self.store_file = tempfile.mkstemp(suffix=".db") - self.store = MilvusStore(node_groups=self.node_groups, embed=self.mock_embed, - embed_keys=self.mock_embed.keys(), - fields_desc=self.fields_desc, uri=self.store_file) + embed_keys = set(['vec1', 'vec2']) + group_embed_keys = { + LAZY_ROOT_NAME: embed_keys, + 'group1': embed_keys, + 'group2': embed_keys, + } + embed_dims = { + "vec1": 3, + "vec2": 5, + } + self.store = MilvusStore(group_embed_keys=group_embed_keys, embed=self.mock_embed, + embed_dims=embed_dims, fields_desc=self.fields_desc, + uri=self.store_file) self.node1 = DocNode(uid="1", text="text1", group="group1", parent=None, embedding={"vec1": [8.0, 9.0, 10.0], "vec2": [11.0, 12.0, 13.0, 14.0, 15.0]}, From 9cc8752b120a648b28c4efef3e6bacdf2f507050 Mon Sep 17 00:00:00 2001 From: ouguoyu Date: Thu, 7 Nov 2024 11:04:18 +0800 Subject: [PATCH 59/60] review5 --- examples/rag_map_store_with_milvus_index.py | 111 ++++++++++-------- examples/rag_milvus_store.py | 111 ++++++++++-------- lazyllm/tools/rag/__init__.py | 2 +- lazyllm/tools/rag/chroma_store.py | 4 - lazyllm/tools/rag/default_index.py | 78 ++---------- lazyllm/tools/rag/doc_field_desc.py | 6 +- lazyllm/tools/rag/doc_impl.py | 10 +- lazyllm/tools/rag/doc_node.py | 5 +- lazyllm/tools/rag/milvus_store.py | 55 +++++---- lazyllm/tools/rag/retriever.py | 11 +- lazyllm/tools/rag/similarity.py | 53 +++++++++ lazyllm/tools/rag/utils.py | 7 +- .../advanced_tests/full_test/test_example.py | 14 +++ tests/basic_tests/test_index.py | 9 +- 14 files changed, 259 insertions(+), 217 deletions(-) create mode 100644 lazyllm/tools/rag/similarity.py diff --git a/examples/rag_map_store_with_milvus_index.py b/examples/rag_map_store_with_milvus_index.py index b12be866..b0ec8d48 100644 --- a/examples/rag_map_store_with_milvus_index.py +++ b/examples/rag_map_store_with_milvus_index.py @@ -2,59 +2,66 @@ import os import lazyllm +from lazyllm import bind import tempfile -_, store_file = tempfile.mkstemp(suffix=".db") +def run(query): + _, store_file = tempfile.mkstemp(suffix=".db") -milvus_store_conf = { - 'type': 'map', - 'indices': { - 'milvus': { - 'uri': store_file, - 'embedding_index_type': 'HNSW', - 'embedding_metric_type': 'COSINE', + milvus_store_conf = { + 'type': 'map', + 'indices': { + 'milvus': { + 'uri': store_file, + 'embedding_index_type': 'HNSW', + 'embedding_metric_type': 'COSINE', + }, }, - }, -} - -documents = lazyllm.Document(dataset_path="rag_master", - embed=lazyllm.TrainableModule("bge-large-zh-v1.5"), - manager=False, - store_conf=milvus_store_conf) - -documents.create_node_group(name="sentences", - transform=lambda s: '。'.split(s)) - -prompt = 'You will play the role of an AI Q&A assistant and complete a dialogue task.'\ - ' In this task, you need to provide your answer based on the given context and question.' 
- -with lazyllm.pipeline() as ppl: - with lazyllm.parallel().sum as ppl.prl: # noqa F821 - prl.retriever1 = lazyllm.Retriever(doc=documents, # noqa F821 - group_name="CoarseChunk", - similarity="bm25_chinese", - topk=3) - prl.retriever2 = lazyllm.Retriever(doc=documents, # noqa F821 - group_name="sentences", - similarity="cosine", - topk=3) - - ppl.reranker = lazyllm.Reranker(name='ModuleReranker', - model="bge-reranker-large", - topk=1, - output_format='content', - join=True) | bind(query=ppl.input) # noqa F821 - - ppl.formatter = ( - lambda nodes, query: dict(context_str=nodes, query=query) - ) | bind(query=ppl.input) # noqa F821 - - ppl.llm = lazyllm.TrainableModule('internlm2-chat-7b').prompt( - lazyllm.ChatPrompter(instruction=prompt, extro_keys=['context_str'])) - -rag = lazyllm.ActionModule(ppl) -rag.start() - -print("answer: ", rag('who are you?')) - -os.remove(store_file) + } + + documents = lazyllm.Document(dataset_path="rag_master", + embed=lazyllm.TrainableModule("bge-large-zh-v1.5"), + manager=False, + store_conf=milvus_store_conf) + + documents.create_node_group(name="sentences", + transform=lambda s: '。'.split(s)) + + prompt = 'You will play the role of an AI Q&A assistant and complete a dialogue task.'\ + ' In this task, you need to provide your answer based on the given context and question.' + + with lazyllm.pipeline() as ppl: + with lazyllm.parallel().sum as ppl.prl: + ppl.prl.retriever1 = lazyllm.Retriever(doc=documents, + group_name="CoarseChunk", + similarity="bm25_chinese", + topk=3) + ppl.prl.retriever2 = lazyllm.Retriever(doc=documents, + group_name="sentences", + similarity="cosine", + topk=3) + + ppl.reranker = lazyllm.Reranker(name='ModuleReranker', + model="bge-reranker-large", + topk=1, + output_format='content', + join=True) | bind(query=ppl.input) + + ppl.formatter = ( + lambda nodes, query: dict(context_str=nodes, query=query) + ) | bind(query=ppl.input) + + ppl.llm = lazyllm.TrainableModule('internlm2-chat-7b').prompt( + lazyllm.ChatPrompter(instruction=prompt, extro_keys=['context_str'])) + + rag = lazyllm.ActionModule(ppl) + rag.start() + res = rag(query) + + os.remove(store_file) + + return res + +if __name__ == '__main__': + res = run('何为天道?') + print(f'answer: {res}') diff --git a/examples/rag_milvus_store.py b/examples/rag_milvus_store.py index c07488b3..a9f6f5f2 100644 --- a/examples/rag_milvus_store.py +++ b/examples/rag_milvus_store.py @@ -2,57 +2,64 @@ import os import lazyllm +from lazyllm import bind import tempfile -_, store_file = tempfile.mkstemp(suffix=".db") - -milvus_store_conf = { - 'type': 'milvus', - 'kwargs': { - 'uri': store_file, - 'embedding_index_type': 'HNSW', - 'embedding_metric_type': 'COSINE', - }, -} - -documents = lazyllm.Document(dataset_path="rag_master", - embed=lazyllm.TrainableModule("bge-large-zh-v1.5"), - manager=False, - store_conf=milvus_store_conf) - -documents.create_node_group(name="sentences", - transform=lambda s: '。'.split(s)) - -prompt = 'You will play the role of an AI Q&A assistant and complete a dialogue task.'\ - ' In this task, you need to provide your answer based on the given context and question.' 
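When milvus acts as the store itself, extra per-node scalar fields can be declared with DocFieldDesc and become scalar columns in each collection (see _gen_field_key in milvus_store.py). A hedged sketch of such a declaration, mirroring tests/basic_tests/test_store.py; exactly how the mapping is handed to Document is an assumption here:

    from lazyllm.tools.rag.doc_field_desc import DocFieldDesc

    fields_desc = {
        # a VARCHAR scalar column named 'comment', capped at 65535 characters
        'comment': DocFieldDesc(data_type=DocFieldDesc.DTYPE_VARCHAR, max_length=65535),
    }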
- -with lazyllm.pipeline() as ppl: - with lazyllm.parallel().sum as ppl.prl: - prl.retriever1 = lazyllm.Retriever(doc=documents, # noqa F821 - group_name="CoarseChunk", - similarity="bm25_chinese", - topk=3) - prl.retriever2 = lazyllm.Retriever(doc=documents, # noqa F821 - group_name="sentences", - similarity="cosine", - topk=3) - - ppl.reranker = lazyllm.Reranker(name='ModuleReranker', - model="bge-reranker-large", - topk=1, - output_format='content', - join=True) | bind(query=ppl.input) # noqa F821 - - ppl.formatter = ( - lambda nodes, query: dict(context_str=nodes, query=query) - ) | bind(query=ppl.input) # noqa F821 - - ppl.llm = lazyllm.TrainableModule('internlm2-chat-7b').prompt( - lazyllm.ChatPrompter(instruction=prompt, extro_keys=['context_str'])) - -rag = lazyllm.ActionModule(ppl) -rag.start() - -print("answer: ", rag('who are you?')) - -os.remove(store_file) +def run(query): + _, store_file = tempfile.mkstemp(suffix=".db") + + milvus_store_conf = { + 'type': 'milvus', + 'kwargs': { + 'uri': store_file, + 'embedding_index_type': 'HNSW', + 'embedding_metric_type': 'COSINE', + }, + } + + documents = lazyllm.Document(dataset_path="rag_master", + embed=lazyllm.TrainableModule("bge-large-zh-v1.5"), + manager=False, + store_conf=milvus_store_conf) + + documents.create_node_group(name="sentences", + transform=lambda s: '。'.split(s)) + + prompt = 'You will play the role of an AI Q&A assistant and complete a dialogue task.'\ + ' In this task, you need to provide your answer based on the given context and question.' + + with lazyllm.pipeline() as ppl: + with lazyllm.parallel().sum as ppl.prl: + ppl.prl.retriever1 = lazyllm.Retriever(doc=documents, + group_name="CoarseChunk", + similarity="bm25_chinese", + topk=3) + ppl.prl.retriever2 = lazyllm.Retriever(doc=documents, + group_name="sentences", + similarity="cosine", + topk=3) + + ppl.reranker = lazyllm.Reranker(name='ModuleReranker', + model="bge-reranker-large", + topk=1, + output_format='content', + join=True) | bind(query=ppl.input) + + ppl.formatter = ( + lambda nodes, query: dict(context_str=nodes, query=query) + ) | bind(query=ppl.input) + + ppl.llm = lazyllm.TrainableModule('internlm2-chat-7b').prompt( + lazyllm.ChatPrompter(instruction=prompt, extro_keys=['context_str'])) + + rag = lazyllm.ActionModule(ppl) + rag.start() + res = rag(query) + + os.remove(store_file) + + return res + +if __name__ == '__main__': + res = run('何为天道?') + print(f'answer: {res}') diff --git a/lazyllm/tools/rag/__init__.py b/lazyllm/tools/rag/__init__.py index 1c4c39ae..f72150f4 100644 --- a/lazyllm/tools/rag/__init__.py +++ b/lazyllm/tools/rag/__init__.py @@ -2,7 +2,7 @@ from .retriever import Retriever from .rerank import Reranker, register_reranker from .transform import SentenceSplitter, LLMParser, NodeTransform, TransformArgs, AdaptiveTransform -from .default_index import register_similarity +from .similarity import register_similarity from .doc_node import DocNode from .readers import (PDFReader, DocxReader, HWPReader, PPTXReader, ImageReader, IPYNBReader, EpubReader, MarkdownReader, MboxReader, PandasCSVReader, PandasExcelReader, VideoAudioReader) diff --git a/lazyllm/tools/rag/chroma_store.py b/lazyllm/tools/rag/chroma_store.py index 0002f870..9c598d2e 100644 --- a/lazyllm/tools/rag/chroma_store.py +++ b/lazyllm/tools/rag/chroma_store.py @@ -106,14 +106,11 @@ def _save_nodes(self, nodes: List[DocNode]) -> None: collection ), f"Group {group} is not found in collections {self._collections}" for node in nodes: - if node.is_saved: - continue metadata = 
self._make_chroma_metadata(node) ids.append(node.uid) embeddings.append([0]) # we don't use chroma for retrieving metadatas.append(metadata) documents.append(node.get_text()) - node.is_saved = True if ids: collection.upsert( embeddings=embeddings, @@ -162,7 +159,6 @@ def _build_nodes_from_chroma(self, results: Dict[str, List], embed_dims: Dict[st new_embedding_dict[key] = embedding node.embedding = new_embedding_dict - node.is_saved = True nodes.append(node) return nodes diff --git a/lazyllm/tools/rag/default_index.py b/lazyllm/tools/rag/default_index.py index fcd4fbfa..1652a41c 100644 --- a/lazyllm/tools/rag/default_index.py +++ b/lazyllm/tools/rag/default_index.py @@ -2,54 +2,18 @@ from .doc_node import DocNode from .store_base import StoreBase from .index_base import IndexBase -import numpy as np -from .component.bm25 import BM25 from lazyllm import LOG from lazyllm.common import override from .utils import parallel_do_embedding +from .similarity import registered_similarities # ---------------------------------------------------------------------------- # class DefaultIndex(IndexBase): - """Default Index, registered for similarity functions""" - - registered_similarity = dict() - def __init__(self, embed: Dict[str, Callable], store: StoreBase, **kwargs): self.embed = embed self.store = store - @classmethod - def register_similarity( - cls: "DefaultIndex", - func: Optional[Callable] = None, - mode: str = "", - descend: bool = True, - batch: bool = False, - ) -> Callable: - def decorator(f): - def wrapper(query, nodes, **kwargs): - if mode != "embedding": - if batch: - return f(query, nodes, **kwargs) - else: - return [(node, f(query, node, **kwargs)) for node in nodes] - else: - assert isinstance(query, dict), "query must be of dict type, used for similarity calculation." - similarity = {} - if batch: - for key, val in query.items(): - nodes_embed = [node.embedding[key] for node in nodes] - similarity[key] = f(val, nodes_embed, **kwargs) - else: - for key, val in query.items(): - similarity[key] = [(node, f(val, node.embedding[key], **kwargs)) for node in nodes] - return similarity - cls.registered_similarity[f.__name__] = (wrapper, mode, descend) - return wrapper - - return decorator(func) if func else decorator - @override def update(self, nodes: List[DocNode]) -> None: pass @@ -69,19 +33,21 @@ def query( embed_keys: Optional[List[str]] = None, **kwargs, ) -> List[DocNode]: - if similarity_name not in self.registered_similarity: + if similarity_name not in registered_similarities: raise ValueError( f"{similarity_name} not registered, please check your input." - f"Available options now: {self.registered_similarity.keys()}" + f"Available options now: {registered_similarities.keys()}" ) - similarity_func, mode, descend = self.registered_similarity[similarity_name] + similarity_func, mode, descend = registered_similarities[similarity_name] nodes = self.store.get_nodes(group_name) if mode == "embedding": assert self.embed, "Chosen similarity needs embed model." assert len(query) > 0, "Query should not be empty." 
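Custom metrics reach this lookup through the same registry; a minimal sketch of a user-defined embedding similarity, where the name neg_l2 and the metric itself are invented for illustration:

    from lazyllm.tools.rag.similarity import register_similarity
    import numpy as np

    @register_similarity(mode='embedding', descend=True)
    def neg_l2(query, node_embedding, **kwargs):
        # smaller L2 distance should rank higher, so negate it and keep descend=True
        return -float(np.linalg.norm(np.array(query) - np.array(node_embedding)))

Once registered, Retriever(similarity='neg_l2', ...) resolves it via registered_similarities exactly like the built-in cosine.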
- query_embedding = {k: self.embed[k](query) for k in (embed_keys or self.embed.keys())} - modified_nodes = parallel_do_embedding(self.embed, nodes) + if not embed_keys: + embed_keys = list(self.embed.keys()) + query_embedding = {k: self.embed[k](query) for k in embed_keys} + modified_nodes = parallel_do_embedding(self.embed, embed_keys, nodes) self.store.update_nodes(modified_nodes) similarities = similarity_func(query_embedding, nodes, topk=topk, **kwargs) elif mode == "text": @@ -108,31 +74,3 @@ def _filter_nodes_by_score(self, similarities: List[Tuple[DocNode, float]], topk similarities = similarities[:topk] return [node for node, score in similarities if score > similarity_cut_off] - -@DefaultIndex.register_similarity(mode="text", batch=True) -def bm25(query: str, nodes: List[DocNode], **kwargs) -> List: - bm25_retriever = BM25(nodes, language="en", **kwargs) - return bm25_retriever.retrieve(query) - - -@DefaultIndex.register_similarity(mode="text", batch=True) -def bm25_chinese(query: str, nodes: List[DocNode], **kwargs) -> List: - bm25_retriever = BM25(nodes, language="zh", **kwargs) - return bm25_retriever.retrieve(query) - - -@DefaultIndex.register_similarity(mode="embedding") -def cosine(query: List[float], node: List[float], **kwargs) -> float: - product = np.dot(query, node) - norm = np.linalg.norm(query) * np.linalg.norm(node) - return product / norm - - -# User-defined similarity decorator -def register_similarity( - func: Optional[Callable] = None, - mode: str = "", - descend: bool = True, - batch: bool = False, -) -> Callable: - return DefaultIndex.register_similarity(func, mode, descend, batch) diff --git a/lazyllm/tools/rag/doc_field_desc.py b/lazyllm/tools/rag/doc_field_desc.py index fc71d9a1..bb51d48d 100644 --- a/lazyllm/tools/rag/doc_field_desc.py +++ b/lazyllm/tools/rag/doc_field_desc.py @@ -1,8 +1,10 @@ -from typing import Optional +from typing import Optional, Any class DocFieldDesc: DTYPE_VARCHAR = 0 - def __init__(self, data_type: int, max_length: Optional[int] = None): + def __init__(self, data_type: int, default_value: Optional[Any] = None, + max_length: Optional[int] = None): self.data_type = data_type + self.default_value = default_value self.max_length = max_length diff --git a/lazyllm/tools/rag/doc_impl.py b/lazyllm/tools/rag/doc_impl.py index 8a8d6de7..8088a1c6 100644 --- a/lazyllm/tools/rag/doc_impl.py +++ b/lazyllm/tools/rag/doc_impl.py @@ -104,7 +104,7 @@ def _create_store(self, store_conf: Optional[Dict]) -> StoreBase: embed_dims=self._embed_dims, **kwargs) elif store_type == "milvus": store = MilvusStore(group_embed_keys=self._activated_embeddings, embed=self.embed, - embed_dims=self.embed_dims, fields_desc=self._fields_desc, **kwargs) + embed_dims=self._embed_dims, fields_desc=self._fields_desc, **kwargs) else: raise NotImplementedError( f"Not implemented store type for {store_type}" @@ -115,8 +115,12 @@ def _create_store(self, store_conf: Optional[Dict]) -> StoreBase: raise ValueError(f"`indices`'s type [{type(indices_conf)}] is not a dict") for backend_type, kwargs in indices_conf.items(): - index = SmartEmbeddingIndex(backend_type=backend_type, embed=self.embed, - node_groups=self.node_groups, **kwargs) + index = SmartEmbeddingIndex(backend_type=backend_type, + group_embed_keys=self._activated_embeddings, + embed=self.embed, + embed_dims=self._embed_dims, + fields_desc=self._fields_desc, + **kwargs) store.register_index(type=backend_type, index=index) return store diff --git a/lazyllm/tools/rag/doc_node.py b/lazyllm/tools/rag/doc_node.py index 
6516a3c1..269c205a 100644 --- a/lazyllm/tools/rag/doc_node.py +++ b/lazyllm/tools/rag/doc_node.py @@ -21,7 +21,7 @@ def __init__(self, uid: Optional[str] = None, text: Optional[str] = None, group: self.uid: str = uid if uid else str(uuid.uuid4()) self.text: Optional[str] = text self.group: Optional[str] = group - self.embedding: Optional[Dict[str, List[float]]] = embedding or None + self.embedding: Optional[Dict[str, List[float]]] = embedding or {} self._metadata: Dict[str, Any] = metadata or {} # Metadata keys that are excluded from text for the embed model. self._excluded_embed_metadata_keys: List[str] = [] @@ -29,7 +29,6 @@ def __init__(self, uid: Optional[str] = None, text: Optional[str] = None, group: self._excluded_llm_metadata_keys: List[str] = [] self.parent: Optional["DocNode"] = parent self.children: Dict[str, List["DocNode"]] = defaultdict(list) - self.is_saved: bool = False self._lock = threading.Lock() self._embedding_state = set() @@ -54,7 +53,6 @@ def metadata(self) -> Dict: @metadata.setter def metadata(self, metadata: Dict) -> None: - self.is_saved = False self._metadata = metadata @property @@ -118,7 +116,6 @@ def do_embedding(self, embed: Dict[str, Callable]) -> None: with self._lock: self.embedding = self.embedding or {} self.embedding = {**self.embedding, **generate_embed} - self.is_saved = False def check_embedding_state(self, embed_key: str) -> None: while True: diff --git a/lazyllm/tools/rag/milvus_store.py b/lazyllm/tools/rag/milvus_store.py index b9e9fbaf..7baa31fb 100644 --- a/lazyllm/tools/rag/milvus_store.py +++ b/lazyllm/tools/rag/milvus_store.py @@ -1,7 +1,7 @@ import copy from typing import Dict, List, Optional, Union, Callable, Set import pymilvus -from pymilvus import MilvusClient +from pymilvus import MilvusClient, CollectionSchema, FieldSchema from .doc_node import DocNode from .map_store import MapStore from .utils import parallel_do_embedding @@ -18,26 +18,31 @@ class MilvusStore(StoreBase): _embedding_key_prefix = 'embedding_' _field_key_prefix = 'field_' - _builtin_fields = { + _builtin_keys = { _primary_key: { - 'datatype': pymilvus.DataType.VARCHAR, + 'dtype': pymilvus.DataType.VARCHAR, 'max_length': 256, 'is_primary': True, }, 'parent': { - 'datatype': pymilvus.DataType.VARCHAR, + 'dtype': pymilvus.DataType.VARCHAR, 'max_length': 256, }, 'text': { - 'datatype': pymilvus.DataType.VARCHAR, + 'dtype': pymilvus.DataType.VARCHAR, 'max_length': 65535, }, 'metadata': { - 'datatype': pymilvus.DataType.VARCHAR, + 'dtype': pymilvus.DataType.VARCHAR, 'max_length': 65535, }, } + _builtin_fields_desc = { + 'lazyllm_doc_path': DocFieldDesc(data_type=DocFieldDesc.DTYPE_VARCHAR, + default_value=' ', max_length=65535), + } + _type2milvus = [ pymilvus.DataType.VARCHAR, ] @@ -46,9 +51,14 @@ def __init__(self, group_embed_keys: Dict[str, Set[str]], embed: Dict[str, Calla embed_dims: Dict[str, int], fields_desc: Dict[str, DocFieldDesc], uri: str, embedding_index_type: Optional[str] = None, embedding_metric_type: Optional[str] = None, **kwargs): + self._group_embed_keys = group_embed_keys self._embed = embed self._client = MilvusClient(uri=uri) + # XXX milvus 2.4.x doesn't support `default_value` + # https://milvus.io/docs/product_faq.md#Does-Milvus-support-specifying-default-values-for-scalar-or-vector-fields + self._fields_desc = fields_desc | self._builtin_fields_desc + if not embedding_index_type: embedding_index_type = 'HNSW' @@ -56,29 +66,30 @@ def __init__(self, group_embed_keys: Dict[str, Set[str]], embed: Dict[str, Calla embedding_metric_type = 'COSINE' for 
group, embed_keys in group_embed_keys.items(): + field_list = [] index_params = self._client.prepare_index_params() - schema = self._client.create_schema(auto_id=False, enable_dynamic_field=False) - for key, info in self._builtin_fields.items(): - schema.add_field(field_name=key, **info) + for key, info in self._builtin_keys.items(): + field_list.append(FieldSchema(name=key, **info)) for key in embed_keys: dim = embed_dims.get(key) if not dim: - raise ValueError(f'cannot find embedding dim of embed [{key}]') + raise ValueError(f'cannot find embedding dim of embed [{key}] in [{embed_dims}]') field_name = self._gen_embedding_key(key) - schema.add_field(field_name=field_name, datatype=pymilvus.DataType.FLOAT_VECTOR, - dim=dim) + field_list.append(FieldSchema(name=field_name, dtype=pymilvus.DataType.FLOAT_VECTOR, dim=dim)) index_params.add_index(field_name=field_name, index_type=embedding_index_type, metric_type=embedding_metric_type) - if fields_desc: - for key, info in fields_desc.items(): - schema.add_field(field_name=self._gen_field_key(key), - datatype=self._type2milvus[info.data_type], - max_length=info.max_length) + if self._fields_desc: + for key, desc in self._fields_desc.items(): + field_list.append(FieldSchema(name=self._gen_field_key(key), + dtype=self._type2milvus[desc.data_type], + max_length=desc.max_length, + default_value=desc.default_value)) + schema = CollectionSchema(fields=field_list, auto_id=False, enable_dynamic_fields=False) self._client.create_collection(collection_name=group, schema=schema, index_params=index_params) @@ -87,8 +98,10 @@ def __init__(self, group_embed_keys: Dict[str, Set[str]], embed: Dict[str, Calla @override def update_nodes(self, nodes: List[DocNode]) -> None: - parallel_do_embedding(self._embed, nodes) for node in nodes: + embed_keys = self._group_embed_keys.get(node.group) + if embed_keys: + parallel_do_embedding(self._embed, embed_keys, [node]) data = self._serialize_node_partial(node) self._client.upsert(collection_name=node.group, data=[data]) @@ -192,8 +205,10 @@ def _serialize_node_partial(self, node: DocNode) -> Dict: for k, v in node.embedding.items(): res[self._gen_embedding_key(k)] = v - for k, v in node.fields.items(): - res[self._gen_field_key(k)] = v + for name, desc in self._fields_desc.items(): + val = node.fields.get(name, desc.default_value) + if val: + res[self._gen_field_key(name)] = val return res diff --git a/lazyllm/tools/rag/retriever.py b/lazyllm/tools/rag/retriever.py index d919866d..b9bff7a4 100644 --- a/lazyllm/tools/rag/retriever.py +++ b/lazyllm/tools/rag/retriever.py @@ -2,6 +2,7 @@ from .doc_node import DocNode from .document import Document, DocImpl from typing import List, Optional, Union, Dict +from .similarity import registered_similarities class _PostProcess(object): def __init__(self, target: Optional[str] = None, @@ -45,12 +46,18 @@ def __init__( ): super().__init__() + _, mode, _ = registered_similarities[similarity] + self._docs: List[Document] = [doc] if isinstance(doc, Document) else doc for doc in self._docs: assert isinstance(doc, Document), 'Only Document or List[Document] are supported' self._submodules.append(doc) - if embed_keys: - doc._activated_embeddings.setdefault(group_name, set()).insert(embed_keys) + if mode == 'embedding' and not embed_keys: + real_embed_keys = list(doc._impl.embed.keys()) + else: + real_embed_keys = embed_keys + if real_embed_keys: + doc._impl._activated_embeddings.setdefault(group_name, set()).update(real_embed_keys) self._group_name = group_name self._similarity = similarity 
# similarity function str diff --git a/lazyllm/tools/rag/similarity.py b/lazyllm/tools/rag/similarity.py new file mode 100644 index 00000000..89b9be0e --- /dev/null +++ b/lazyllm/tools/rag/similarity.py @@ -0,0 +1,53 @@ +from typing import Optional, Callable, Literal, List +from .component.bm25 import BM25 +import numpy as np +from .doc_node import DocNode + +registered_similarities = dict() + +def register_similarity( + func: Optional[Callable] = None, + mode: Optional[Literal['text', 'embedding']] = None, + descend: bool = True, + batch: bool = False, +) -> Callable: + def decorator(f): + def wrapper(query, nodes, **kwargs): + if mode != "embedding": + if batch: + return f(query, nodes, **kwargs) + else: + return [(node, f(query, node, **kwargs)) for node in nodes] + else: + assert isinstance(query, dict), "query must be of dict type, used for similarity calculation." + similarity = {} + if batch: + for key, val in query.items(): + nodes_embed = [node.embedding[key] for node in nodes] + similarity[key] = f(val, nodes_embed, **kwargs) + else: + for key, val in query.items(): + similarity[key] = [(node, f(val, node.embedding[key], **kwargs)) for node in nodes] + return similarity + registered_similarities[f.__name__] = (wrapper, mode, descend) + return wrapper + + return decorator(func) if func else decorator + +@register_similarity(mode="text", batch=True) +def bm25(query: str, nodes: List[DocNode], **kwargs) -> List: + bm25_retriever = BM25(nodes, language="en", **kwargs) + return bm25_retriever.retrieve(query) + + +@register_similarity(mode="text", batch=True) +def bm25_chinese(query: str, nodes: List[DocNode], **kwargs) -> List: + bm25_retriever = BM25(nodes, language="zh", **kwargs) + return bm25_retriever.retrieve(query) + + +@register_similarity(mode="embedding") +def cosine(query: List[float], node: List[float], **kwargs) -> float: + product = np.dot(query, node) + norm = np.linalg.norm(query) * np.linalg.norm(node) + return product / norm diff --git a/lazyllm/tools/rag/utils.py b/lazyllm/tools/rag/utils.py index e195c5ea..deb60a29 100644 --- a/lazyllm/tools/rag/utils.py +++ b/lazyllm/tools/rag/utils.py @@ -2,7 +2,7 @@ import shutil import hashlib import concurrent -from typing import List, Callable, Generator, Dict, Any, Optional, Union, Tuple +from typing import List, Callable, Generator, Dict, Any, Optional, Union, Tuple, Set from abc import ABC, abstractmethod from .index_base import IndexBase from .store_base import LAZY_ROOT_NAME @@ -486,12 +486,13 @@ def save_files_in_threads( return (already_exist_files, new_add_files, overwritten_files) # returns a list of modified nodes -def parallel_do_embedding(embed: Dict[str, Callable], nodes: List[DocNode]) -> List[DocNode]: +def parallel_do_embedding(embed: Dict[str, Callable], embed_keys: Optional[Union[List[str], Set[str]]], + nodes: List[DocNode]) -> List[DocNode]: modified_nodes = [] with ThreadPoolExecutor(config["max_embedding_workers"]) as executor: futures = [] for node in nodes: - miss_keys = node.has_missing_embedding(embed.keys()) + miss_keys = node.has_missing_embedding(embed_keys) if not miss_keys: continue modified_nodes.append(node) diff --git a/tests/advanced_tests/full_test/test_example.py b/tests/advanced_tests/full_test/test_example.py index 1f000358..796172bc 100644 --- a/tests/advanced_tests/full_test/test_example.py +++ b/tests/advanced_tests/full_test/test_example.py @@ -139,3 +139,17 @@ def test_painting(self): api_name="/_respond_stream") image_path = ans[0][0][-1]['value'] assert os.path.isfile(image_path) 
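parallel_do_embedding now takes an explicit embed_keys argument (see the utils.py hunk above); a usage sketch mirroring the unit tests, with stand-in embed callables and nodes:

    from unittest.mock import MagicMock
    from lazyllm.tools.rag.doc_node import DocNode
    from lazyllm.tools.rag.utils import parallel_do_embedding

    embed = {'vec1': MagicMock(return_value=[1.0, 2.0, 3.0])}  # stand-in embed callable
    nodes = [DocNode(uid='1', text='text1', group='group1', parent=None)]
    # only 'vec1' vectors are computed; nodes that already have them are skipped
    modified = parallel_do_embedding(embed, ['vec1'], nodes)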
+ + def test_rag_milvus_store(self): + from examples.rag_milvus_store import run as rag_run + res = rag_run('何为天道?') + assert type(res) is str + assert "天道" in res + assert len(res) >= 16 + + def test_rag_map_store_with_milvus_index(self): + from examples.rag_map_store_with_milvus_index import run as rag_run + res = rag_run('何为天道?') + assert type(res) is str + assert "天道" in res + assert len(res) >= 16 diff --git a/tests/basic_tests/test_index.py b/tests/basic_tests/test_index.py index 107d1bc6..5b735e49 100644 --- a/tests/basic_tests/test_index.py +++ b/tests/basic_tests/test_index.py @@ -3,7 +3,8 @@ from unittest.mock import MagicMock from lazyllm.tools.rag.map_store import MapStore from lazyllm.tools.rag.doc_node import DocNode -from lazyllm.tools.rag.default_index import DefaultIndex, register_similarity +from lazyllm.tools.rag.default_index import DefaultIndex +from lazyllm.tools.rag.similarity import register_similarity, registered_similarities from lazyllm.tools.rag.utils import parallel_do_embedding class TestDefaultIndex(unittest.TestCase): @@ -38,9 +39,9 @@ def test_register_similarity(self): def custom_similarity(query, nodes, **kwargs): return [(node, 1.0) for node in nodes] - self.assertIn("custom_similarity", DefaultIndex.registered_similarity) + self.assertIn("custom_similarity", registered_similarities) self.assertEqual( - DefaultIndex.registered_similarity["custom_similarity"][1], "embedding" + registered_similarities["custom_similarity"][1], "embedding" ) def test_query_cosine_similarity(self): @@ -71,7 +72,7 @@ def test_parallel_do_embedding(self): for node in self.nodes: node.has_embedding = MagicMock(return_value=False) start_time = time.time() - parallel_do_embedding(self.index.embed, self.nodes) + parallel_do_embedding(self.index.embed, self.index.embed.keys(), self.nodes) assert time.time() - start_time < 4, "Parallel not used!" 
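For reference, each entry stored by register_similarity above is a (wrapper, mode, descend) triple keyed by the function's __name__; a quick introspection sketch:

    from lazyllm.tools.rag.similarity import registered_similarities

    wrapper, mode, descend = registered_similarities['cosine']
    assert mode == 'embedding' and descend is True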
def test_query_multi_embed_similarity(self): From b724bfc8382e02ee190156ed0d87bf3f3c48c4f0 Mon Sep 17 00:00:00 2001 From: ouguoyu Date: Fri, 8 Nov 2024 11:17:24 +0800 Subject: [PATCH 60/60] review6 --- examples/rag_map_store_with_milvus_index.py | 11 +++++---- lazyllm/tools/rag/doc_impl.py | 26 ++++++++++++++------- lazyllm/tools/rag/milvus_store.py | 5 +++- 3 files changed, 28 insertions(+), 14 deletions(-) diff --git a/examples/rag_map_store_with_milvus_index.py b/examples/rag_map_store_with_milvus_index.py index b0ec8d48..25ec98a2 100644 --- a/examples/rag_map_store_with_milvus_index.py +++ b/examples/rag_map_store_with_milvus_index.py @@ -11,10 +11,13 @@ def run(query): milvus_store_conf = { 'type': 'map', 'indices': { - 'milvus': { - 'uri': store_file, - 'embedding_index_type': 'HNSW', - 'embedding_metric_type': 'COSINE', + 'smart_embedding_index': { + 'backend': 'milvus', + 'kwargs': { + 'uri': store_file, + 'embedding_index_type': 'HNSW', + 'embedding_metric_type': 'COSINE', + }, }, }, } diff --git a/lazyllm/tools/rag/doc_impl.py b/lazyllm/tools/rag/doc_impl.py index 8088a1c6..aa347b26 100644 --- a/lazyllm/tools/rag/doc_impl.py +++ b/lazyllm/tools/rag/doc_impl.py @@ -98,7 +98,7 @@ def _create_store(self, store_conf: Optional[Dict]) -> StoreBase: raise ValueError('`kwargs` in store conf is not a dict.') if store_type == "map": - store = MapStore(node_groups=self._activated_embeddings.keys(), embed=self.embed, **kwargs) + store = MapStore(node_groups=list(self._activated_embeddings.keys()), embed=self.embed, **kwargs) elif store_type == "chroma": store = ChromadbStore(group_embed_keys=self._activated_embeddings, embed=self.embed, embed_dims=self._embed_dims, **kwargs) @@ -114,14 +114,22 @@ def _create_store(self, store_conf: Optional[Dict]) -> StoreBase: if not isinstance(indices_conf, Dict): raise ValueError(f"`indices`'s type [{type(indices_conf)}] is not a dict") - for backend_type, kwargs in indices_conf.items(): - index = SmartEmbeddingIndex(backend_type=backend_type, - group_embed_keys=self._activated_embeddings, - embed=self.embed, - embed_dims=self._embed_dims, - fields_desc=self._fields_desc, - **kwargs) - store.register_index(type=backend_type, index=index) + for index_type, conf in indices_conf.items(): + if index_type == 'smart_embedding_index': + backend_type = conf.get('backend') + if not backend_type: + raise ValueError('`backend` is not specified in `smart_embedding_index`.') + kwargs = conf.get('kwargs', {}) + index = SmartEmbeddingIndex(backend_type=backend_type, + group_embed_keys=self._activated_embeddings, + embed=self.embed, + embed_dims=self._embed_dims, + fields_desc=self._fields_desc, + **kwargs) + else: + raise ValueError(f'unsupported index type [{index_type}]') + + store.register_index(type=index_type, index=index) return store diff --git a/lazyllm/tools/rag/milvus_store.py b/lazyllm/tools/rag/milvus_store.py index 7baa31fb..faf5eec2 100644 --- a/lazyllm/tools/rag/milvus_store.py +++ b/lazyllm/tools/rag/milvus_store.py @@ -57,7 +57,10 @@ def __init__(self, group_embed_keys: Dict[str, Set[str]], embed: Dict[str, Calla # XXX milvus 2.4.x doesn't support `default_value` # https://milvus.io/docs/product_faq.md#Does-Milvus-support-specifying-default-values-for-scalar-or-vector-fields - self._fields_desc = fields_desc | self._builtin_fields_desc + if fields_desc: + self._fields_desc = fields_desc | self._builtin_fields_desc + else: + self._fields_desc = self._builtin_fields_desc if not embedding_index_type: embedding_index_type = 'HNSW'
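With this final patch applied, the two milvus integration modes exercised by the examples reduce to the following conf shapes; a condensed sketch in which the uri values are hypothetical temporary paths:

    # milvus acting as the store itself
    store_conf = {
        'type': 'milvus',
        'kwargs': {
            'uri': '/tmp/store.db',
            'embedding_index_type': 'HNSW',
            'embedding_metric_type': 'COSINE',
        },
    }

    # map store backing a milvus-based smart embedding index
    store_conf = {
        'type': 'map',
        'indices': {
            'smart_embedding_index': {
                'backend': 'milvus',
                'kwargs': {
                    'uri': '/tmp/index.db',
                    'embedding_index_type': 'HNSW',
                    'embedding_metric_type': 'COSINE',
                },
            },
        },
    }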