diff --git a/examples/rag_map_store_with_milvus_index.py b/examples/rag_map_store_with_milvus_index.py index 25ec98a2..50834dd3 100644 --- a/examples/rag_map_store_with_milvus_index.py +++ b/examples/rag_map_store_with_milvus_index.py @@ -34,15 +34,8 @@ def run(query): ' In this task, you need to provide your answer based on the given context and question.' with lazyllm.pipeline() as ppl: - with lazyllm.parallel().sum as ppl.prl: - ppl.prl.retriever1 = lazyllm.Retriever(doc=documents, - group_name="CoarseChunk", - similarity="bm25_chinese", - topk=3) - ppl.prl.retriever2 = lazyllm.Retriever(doc=documents, - group_name="sentences", - similarity="cosine", - topk=3) + ppl.retriever = lazyllm.Retriever(doc=documents, group_name="sentences", topk=3, + index='smart_embedding_index') ppl.reranker = lazyllm.Reranker(name='ModuleReranker', model="bge-reranker-large", diff --git a/examples/rag_milvus_store.py b/examples/rag_milvus_store.py index a9f6f5f2..a7391746 100644 --- a/examples/rag_milvus_store.py +++ b/examples/rag_milvus_store.py @@ -29,15 +29,7 @@ def run(query): ' In this task, you need to provide your answer based on the given context and question.' with lazyllm.pipeline() as ppl: - with lazyllm.parallel().sum as ppl.prl: - ppl.prl.retriever1 = lazyllm.Retriever(doc=documents, - group_name="CoarseChunk", - similarity="bm25_chinese", - topk=3) - ppl.prl.retriever2 = lazyllm.Retriever(doc=documents, - group_name="sentences", - similarity="cosine", - topk=3) + ppl.retriever = lazyllm.Retriever(doc=documents, group_name="sentences", topk=3) ppl.reranker = lazyllm.Reranker(name='ModuleReranker', model="bge-reranker-large", diff --git a/lazyllm/thirdparty/__init__.py b/lazyllm/thirdparty/__init__.py index 08b90f6f..f92eb85b 100644 --- a/lazyllm/thirdparty/__init__.py +++ b/lazyllm/thirdparty/__init__.py @@ -79,6 +79,6 @@ def __getattribute__(self, __name): modules = ['redis', 'huggingface_hub', 'jieba', 'modelscope', 'pandas', 'jwt', 'rank_bm25', 'redisvl', 'datasets', 'deepspeed', 'fire', 'numpy', 'peft', 'torch', 'transformers', 'collie', 'faiss', 'flash_attn', 'google', - 'lightllm', 'vllm', 'ChatTTS', 'wandb', 'funasr', 'sklearn', 'torchvision', 'scipy'] + 'lightllm', 'vllm', 'ChatTTS', 'wandb', 'funasr', 'sklearn', 'torchvision', 'scipy', 'pymilvus'] for m in modules: vars()[m] = PackageWrapper(m) diff --git a/lazyllm/tools/rag/doc_impl.py b/lazyllm/tools/rag/doc_impl.py index aa347b26..3397baae 100644 --- a/lazyllm/tools/rag/doc_impl.py +++ b/lazyllm/tools/rag/doc_impl.py @@ -295,7 +295,9 @@ def retrieve(self, query: str, group_name: str, similarity: str, similarity_cut_ index: str, topk: int, similarity_kws: dict, embed_keys: Optional[List[str]] = None) -> List[DocNode]: self._lazy_init() - if type is None or type == 'default': + self._dynamic_create_nodes(group_name, self.store) + + if index is None or index == 'default': return self.store.query(query=query, group_name=group_name, similarity_name=similarity, similarity_cut_off=similarity_cut_off, topk=topk, embed_keys=embed_keys, **similarity_kws) @@ -304,7 +306,6 @@ def retrieve(self, query: str, group_name: str, similarity: str, similarity_cut_ if not index_instance: raise NotImplementedError(f"index type '{index}' is not supported currently.") - self._dynamic_create_nodes(group_name, self.store) return index_instance.query(query=query, group_name=group_name, similarity_name=similarity, similarity_cut_off=similarity_cut_off, topk=topk, embed_keys=embed_keys, **similarity_kws) diff --git a/lazyllm/tools/rag/milvus_store.py b/lazyllm/tools/rag/milvus_store.py index faf5eec2..42200a22 100644 --- a/lazyllm/tools/rag/milvus_store.py +++ b/lazyllm/tools/rag/milvus_store.py @@ -1,7 +1,6 @@ import copy from typing import Dict, List, Optional, Union, Callable, Set -import pymilvus -from pymilvus import MilvusClient, CollectionSchema, FieldSchema +from lazyllm.thirdparty import pymilvus from .doc_node import DocNode from .map_store import MapStore from .utils import parallel_do_embedding @@ -53,7 +52,7 @@ def __init__(self, group_embed_keys: Dict[str, Set[str]], embed: Dict[str, Calla embedding_metric_type: Optional[str] = None, **kwargs): self._group_embed_keys = group_embed_keys self._embed = embed - self._client = MilvusClient(uri=uri) + self._client = pymilvus.MilvusClient(uri=uri) # XXX milvus 2.4.x doesn't support `default_value` # https://milvus.io/docs/product_faq.md#Does-Milvus-support-specifying-default-values-for-scalar-or-vector-fields @@ -73,7 +72,7 @@ def __init__(self, group_embed_keys: Dict[str, Set[str]], embed: Dict[str, Calla index_params = self._client.prepare_index_params() for key, info in self._builtin_keys.items(): - field_list.append(FieldSchema(name=key, **info)) + field_list.append(pymilvus.FieldSchema(name=key, **info)) for key in embed_keys: dim = embed_dims.get(key) @@ -81,18 +80,18 @@ def __init__(self, group_embed_keys: Dict[str, Set[str]], embed: Dict[str, Calla raise ValueError(f'cannot find embedding dim of embed [{key}] in [{embed_dims}]') field_name = self._gen_embedding_key(key) - field_list.append(FieldSchema(name=field_name, dtype=pymilvus.DataType.FLOAT_VECTOR, dim=dim)) + field_list.append(pymilvus.FieldSchema(name=field_name, dtype=pymilvus.DataType.FLOAT_VECTOR, dim=dim)) index_params.add_index(field_name=field_name, index_type=embedding_index_type, metric_type=embedding_metric_type) if self._fields_desc: for key, desc in self._fields_desc.items(): - field_list.append(FieldSchema(name=self._gen_field_key(key), - dtype=self._type2milvus[desc.data_type], - max_length=desc.max_length, - default_value=desc.default_value)) + field_list.append(pymilvus.FieldSchema(name=self._gen_field_key(key), + dtype=self._type2milvus[desc.data_type], + max_length=desc.max_length, + default_value=desc.default_value)) - schema = CollectionSchema(fields=field_list, auto_id=False, enable_dynamic_fields=False) + schema = pymilvus.CollectionSchema(fields=field_list, auto_id=False, enable_dynamic_fields=False) self._client.create_collection(collection_name=group, schema=schema, index_params=index_params) @@ -146,14 +145,17 @@ def get_index(self, type: Optional[str] = None) -> Optional[IndexBase]: def query(self, query: str, group_name: str, - similarity: Optional[str] = None, + similarity_name: Optional[str] = None, similarity_cut_off: Optional[Union[float, Dict[str, float]]] = None, topk: int = 10, embed_keys: Optional[List[str]] = None, **kwargs) -> List[DocNode]: - if similarity is not None: + if similarity_name is not None: raise ValueError('`similarity` MUST be None when Milvus backend is used.') + if not embed_keys: + raise ValueError('empty or None `embed_keys` is not supported.') + uidset = set() for key in embed_keys: embed_func = self._embed.get(key) @@ -165,7 +167,7 @@ def query(self, raise ValueError(f'number of results [{len(results)}] != expected [1]') for result in results[0]: - uidset.update(result['id']) + uidset.add(result['id']) return self._map_store.get_nodes(group_name, list(uidset)) diff --git a/lazyllm/tools/rag/retriever.py b/lazyllm/tools/rag/retriever.py index b9bff7a4..1dc4f496 100644 --- a/lazyllm/tools/rag/retriever.py +++ b/lazyllm/tools/rag/retriever.py @@ -46,18 +46,19 @@ def __init__( ): super().__init__() - _, mode, _ = registered_similarities[similarity] + if similarity: + _, mode, _ = registered_similarities[similarity] + else: + mode = 'embedding' # TODO FIXME XXX should be removed after similarity args refactor self._docs: List[Document] = [doc] if isinstance(doc, Document) else doc for doc in self._docs: assert isinstance(doc, Document), 'Only Document or List[Document] are supported' self._submodules.append(doc) if mode == 'embedding' and not embed_keys: - real_embed_keys = list(doc._impl.embed.keys()) - else: - real_embed_keys = embed_keys - if real_embed_keys: - doc._impl._activated_embeddings.setdefault(group_name, set()).update(real_embed_keys) + embed_keys = list(doc._impl.embed.keys()) + if embed_keys: + doc._impl._activated_embeddings.setdefault(group_name, set()).update(embed_keys) self._group_name = group_name self._similarity = similarity # similarity function str