Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bugfixes: variable assignments fix and milvus throws an error when text similarity is used #339

Merged
merged 4 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 2 additions & 9 deletions examples/rag_map_store_with_milvus_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,8 @@ def run(query):
' In this task, you need to provide your answer based on the given context and question.'

with lazyllm.pipeline() as ppl:
with lazyllm.parallel().sum as ppl.prl:
ppl.prl.retriever1 = lazyllm.Retriever(doc=documents,
group_name="CoarseChunk",
similarity="bm25_chinese",
topk=3)
ppl.prl.retriever2 = lazyllm.Retriever(doc=documents,
group_name="sentences",
similarity="cosine",
topk=3)
ppl.retriever = lazyllm.Retriever(doc=documents, group_name="sentences", topk=3,
index='smart_embedding_index')

ppl.reranker = lazyllm.Reranker(name='ModuleReranker',
model="bge-reranker-large",
Expand Down
10 changes: 1 addition & 9 deletions examples/rag_milvus_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,7 @@ def run(query):
' In this task, you need to provide your answer based on the given context and question.'

with lazyllm.pipeline() as ppl:
with lazyllm.parallel().sum as ppl.prl:
ppl.prl.retriever1 = lazyllm.Retriever(doc=documents,
group_name="CoarseChunk",
similarity="bm25_chinese",
topk=3)
ppl.prl.retriever2 = lazyllm.Retriever(doc=documents,
group_name="sentences",
similarity="cosine",
topk=3)
ppl.retriever = lazyllm.Retriever(doc=documents, group_name="sentences", topk=3)

ppl.reranker = lazyllm.Reranker(name='ModuleReranker',
model="bge-reranker-large",
Expand Down
2 changes: 1 addition & 1 deletion lazyllm/thirdparty/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,6 @@ def __getattribute__(self, __name):

modules = ['redis', 'huggingface_hub', 'jieba', 'modelscope', 'pandas', 'jwt', 'rank_bm25', 'redisvl', 'datasets',
'deepspeed', 'fire', 'numpy', 'peft', 'torch', 'transformers', 'collie', 'faiss', 'flash_attn', 'google',
'lightllm', 'vllm', 'ChatTTS', 'wandb', 'funasr', 'sklearn', 'torchvision', 'scipy']
'lightllm', 'vllm', 'ChatTTS', 'wandb', 'funasr', 'sklearn', 'torchvision', 'scipy', 'pymilvus']
for m in modules:
vars()[m] = PackageWrapper(m)
5 changes: 3 additions & 2 deletions lazyllm/tools/rag/doc_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,9 @@ def retrieve(self, query: str, group_name: str, similarity: str, similarity_cut_
index: str, topk: int, similarity_kws: dict, embed_keys: Optional[List[str]] = None) -> List[DocNode]:
self._lazy_init()

if type is None or type == 'default':
self._dynamic_create_nodes(group_name, self.store)

if index is None or index == 'default':
return self.store.query(query=query, group_name=group_name, similarity_name=similarity,
similarity_cut_off=similarity_cut_off, topk=topk,
embed_keys=embed_keys, **similarity_kws)
Expand All @@ -304,7 +306,6 @@ def retrieve(self, query: str, group_name: str, similarity: str, similarity_cut_
if not index_instance:
raise NotImplementedError(f"index type '{index}' is not supported currently.")

self._dynamic_create_nodes(group_name, self.store)
return index_instance.query(query=query, group_name=group_name, similarity_name=similarity,
similarity_cut_off=similarity_cut_off, topk=topk,
embed_keys=embed_keys, **similarity_kws)
Expand Down
28 changes: 15 additions & 13 deletions lazyllm/tools/rag/milvus_store.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import copy
from typing import Dict, List, Optional, Union, Callable, Set
import pymilvus
from pymilvus import MilvusClient, CollectionSchema, FieldSchema
from lazyllm.thirdparty import pymilvus
from .doc_node import DocNode
from .map_store import MapStore
from .utils import parallel_do_embedding
Expand Down Expand Up @@ -53,7 +52,7 @@ def __init__(self, group_embed_keys: Dict[str, Set[str]], embed: Dict[str, Calla
embedding_metric_type: Optional[str] = None, **kwargs):
self._group_embed_keys = group_embed_keys
self._embed = embed
self._client = MilvusClient(uri=uri)
self._client = pymilvus.MilvusClient(uri=uri)

# XXX milvus 2.4.x doesn't support `default_value`
# https://milvus.io/docs/product_faq.md#Does-Milvus-support-specifying-default-values-for-scalar-or-vector-fields
Expand All @@ -73,26 +72,26 @@ def __init__(self, group_embed_keys: Dict[str, Set[str]], embed: Dict[str, Calla
index_params = self._client.prepare_index_params()

for key, info in self._builtin_keys.items():
field_list.append(FieldSchema(name=key, **info))
field_list.append(pymilvus.FieldSchema(name=key, **info))

for key in embed_keys:
dim = embed_dims.get(key)
if not dim:
raise ValueError(f'cannot find embedding dim of embed [{key}] in [{embed_dims}]')

field_name = self._gen_embedding_key(key)
field_list.append(FieldSchema(name=field_name, dtype=pymilvus.DataType.FLOAT_VECTOR, dim=dim))
field_list.append(pymilvus.FieldSchema(name=field_name, dtype=pymilvus.DataType.FLOAT_VECTOR, dim=dim))
index_params.add_index(field_name=field_name, index_type=embedding_index_type,
metric_type=embedding_metric_type)

if self._fields_desc:
for key, desc in self._fields_desc.items():
field_list.append(FieldSchema(name=self._gen_field_key(key),
dtype=self._type2milvus[desc.data_type],
max_length=desc.max_length,
default_value=desc.default_value))
field_list.append(pymilvus.FieldSchema(name=self._gen_field_key(key),
dtype=self._type2milvus[desc.data_type],
max_length=desc.max_length,
default_value=desc.default_value))

schema = CollectionSchema(fields=field_list, auto_id=False, enable_dynamic_fields=False)
schema = pymilvus.CollectionSchema(fields=field_list, auto_id=False, enable_dynamic_fields=False)
self._client.create_collection(collection_name=group, schema=schema,
index_params=index_params)

Expand Down Expand Up @@ -146,14 +145,17 @@ def get_index(self, type: Optional[str] = None) -> Optional[IndexBase]:
def query(self,
query: str,
group_name: str,
similarity: Optional[str] = None,
similarity_name: Optional[str] = None,
similarity_cut_off: Optional[Union[float, Dict[str, float]]] = None,
topk: int = 10,
embed_keys: Optional[List[str]] = None,
**kwargs) -> List[DocNode]:
if similarity is not None:
if similarity_name is not None:
raise ValueError('`similarity` MUST be None when Milvus backend is used.')

if not embed_keys:
raise ValueError('empty or None `embed_keys` is not supported.')

uidset = set()
for key in embed_keys:
embed_func = self._embed.get(key)
Expand All @@ -165,7 +167,7 @@ def query(self,
raise ValueError(f'number of results [{len(results)}] != expected [1]')

for result in results[0]:
uidset.update(result['id'])
uidset.add(result['id'])

return self._map_store.get_nodes(group_name, list(uidset))

Expand Down
13 changes: 7 additions & 6 deletions lazyllm/tools/rag/retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,18 +46,19 @@ def __init__(
):
super().__init__()

_, mode, _ = registered_similarities[similarity]
if similarity:
_, mode, _ = registered_similarities[similarity]
else:
mode = 'embedding' # TODO FIXME XXX should be removed after similarity args refactor

self._docs: List[Document] = [doc] if isinstance(doc, Document) else doc
for doc in self._docs:
assert isinstance(doc, Document), 'Only Document or List[Document] are supported'
self._submodules.append(doc)
if mode == 'embedding' and not embed_keys:
real_embed_keys = list(doc._impl.embed.keys())
else:
real_embed_keys = embed_keys
if real_embed_keys:
doc._impl._activated_embeddings.setdefault(group_name, set()).update(real_embed_keys)
embed_keys = list(doc._impl.embed.keys())
if embed_keys:
doc._impl._activated_embeddings.setdefault(group_name, set()).update(embed_keys)

self._group_name = group_name
self._similarity = similarity # similarity function str
Expand Down
Loading