Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add sparse embedding support for milvus #379

Merged
merged 19 commits into from
Dec 18, 2024
6 changes: 3 additions & 3 deletions examples/rag_milvus_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import lazyllm
from lazyllm import bind, config
from lazyllm.tools.rag import DocField
from lazyllm.tools.rag import DocField, DataType
import shutil

class TmpDir:
Expand All @@ -28,8 +28,8 @@ def __del__(self):
}

doc_fields = {
'comment': DocField(data_type=DocField.DTYPE_VARCHAR, max_size=65535, default_value=' '),
'signature': DocField(data_type=DocField.DTYPE_VARCHAR, max_size=32, default_value=' '),
'comment': DocField(data_type=DataType.VARCHAR, max_size=65535, default_value=' '),
'signature': DocField(data_type=DataType.VARCHAR, max_size=32, default_value=' '),
}

prompt = 'You will play the role of an AI Q&A assistant and complete a dialogue task.'\
Expand Down
2 changes: 2 additions & 0 deletions lazyllm/tools/rag/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .dataReader import SimpleDirectoryReader
from .doc_manager import DocManager, DocListManager
from .global_metadata import GlobalMetadataDesc as DocField
from .data_type import DataType


__all__ = [
Expand Down Expand Up @@ -39,4 +40,5 @@
'DocManager',
'DocListManager',
'DocField',
'DataType',
]
7 changes: 2 additions & 5 deletions lazyllm/tools/rag/chroma_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .store_base import StoreBase, LAZY_ROOT_NAME
from .doc_node import DocNode
from .index_base import IndexBase
from .utils import _FileNodeIndex
from .utils import _FileNodeIndex, sparse2normal
from .default_index import DefaultIndex
from .map_store import MapStore

Expand Down Expand Up @@ -151,10 +151,7 @@ def _build_nodes_from_chroma(self, results: Dict[str, List], embed_dims: Dict[st
dim = embed_dims.get(key)
if not dim:
raise ValueError(f'dim of embed [{key}] is not determined.')
new_embedding = [0] * dim
for idx, val in embedding.items():
new_embedding[int(idx)] = val
new_embedding_dict[key] = new_embedding
new_embedding_dict[key] = sparse2normal(embedding, dim)
else:
new_embedding_dict[key] = embedding
node.embedding = new_embedding_dict
Expand Down
8 changes: 8 additions & 0 deletions lazyllm/tools/rag/data_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from enum import IntEnum

class DataType(IntEnum):
    """Field / embedding data types shared by the RAG document stores.

    NOTE: the integer values are significant — store backends map them by
    position (e.g. MilvusStore keeps a list indexed by these values to get
    the corresponding pymilvus DataType), so existing members must not be
    reordered or renumbered; append new members at the end.
    """
    VARCHAR = 0
    ARRAY = 1
    INT32 = 2
    FLOAT_VECTOR = 3
    SPARSE_FLOAT_VECTOR = 4
11 changes: 9 additions & 2 deletions lazyllm/tools/rag/default_index.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from typing import List, Callable, Optional, Dict, Union, Tuple
from typing import List, Callable, Optional, Dict, Union, Tuple, Any
from .doc_node import DocNode
from .store_base import StoreBase
from .index_base import IndexBase
from lazyllm import LOG
from lazyllm.common import override
from .utils import parallel_do_embedding, generic_process_filters
from .utils import parallel_do_embedding, generic_process_filters, is_sparse
from .similarity import registered_similarities

# ---------------------------------------------------------------------------- #
Expand Down Expand Up @@ -51,6 +51,7 @@ def query(
if not embed_keys:
embed_keys = list(self.embed.keys())
query_embedding = {k: self.embed[k](query) for k in embed_keys}
self._check_supported(similarity_name, query_embedding)
modified_nodes = parallel_do_embedding(self.embed, embed_keys, nodes)
self.store.update_nodes(modified_nodes)
similarities = similarity_func(query_embedding, nodes, topk=topk, **kwargs)
Expand Down Expand Up @@ -78,3 +79,9 @@ def _filter_nodes_by_score(self, similarities: List[Tuple[DocNode, float]], topk
similarities = similarities[:topk]

return [node for node, score in similarities if score > similarity_cut_off]

def _check_supported(self, similarity_name: str, query_embedding: Dict[str, Any]) -> None:
    """Reject query embeddings the selected similarity cannot handle.

    The cosine similarity implemented here has no sparse-vector variant,
    so any sparse query embedding raises NotImplementedError. Other
    similarities pass through unchecked.
    """
    if similarity_name.lower() != 'cosine':
        return
    for key, emb in query_embedding.items():
        if is_sparse(emb):
            raise NotImplementedError(f'embed `{key}` which is sparse is not supported.')
ouonline marked this conversation as resolved.
Show resolved Hide resolved
29 changes: 20 additions & 9 deletions lazyllm/tools/rag/doc_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@
from .smart_embedding_index import SmartEmbeddingIndex
from .doc_node import DocNode
from .data_loaders import DirectoryReader
from .utils import DocListManager, gen_docid
from .utils import DocListManager, gen_docid, is_sparse
from .global_metadata import GlobalMetadataDesc, RAG_DOC_ID, RAG_DOC_PATH
from .data_type import DataType
import threading
import time

Expand Down Expand Up @@ -49,7 +50,6 @@ def __init__(self, embed: Dict[str, Callable], dlm: Optional[DocListManager] = N
self._reader = DirectoryReader(None, self._local_file_reader, DocImpl._registered_file_reader)
self.node_groups: Dict[str, Dict] = {LAZY_ROOT_NAME: {}}
self.embed = {k: embed_wrapper(e) for k, e in embed.items()}
self._embed_dims = None
self._global_metadata_desc = global_metadata_desc
self.store = store_conf # NOTE: will be initialized in _lazy_init()
self._activated_embeddings = {}
Expand All @@ -65,15 +65,24 @@ def _lazy_init(self) -> None:
for group in node_groups.keys():
self._activated_embeddings.setdefault(group, set())

self._embed_dims = {k: len(e('a')) for k, e in self.embed.items()}
embed_dims = {}
embed_datatypes = {}
for k, e in self.embed.items():
embedding = e('a')
if is_sparse(embedding):
embed_datatypes[k] = DataType.SPARSE_FLOAT_VECTOR
else:
embed_dims[k] = len(embedding)
embed_datatypes[k] = DataType.FLOAT_VECTOR

if self.store is None:
self.store = {
'type': 'map',
}

if isinstance(self.store, Dict):
self.store = self._create_store(self.store)
self.store = self._create_store(store_conf=self.store, embed_dims=embed_dims,
embed_datatypes=embed_datatypes)
else:
raise ValueError(f'store type [{type(self.store)}] is not a dict.')

Expand All @@ -95,7 +104,8 @@ def _lazy_init(self) -> None:
self._daemon.daemon = True
self._daemon.start()

def _create_store(self, store_conf: Optional[Dict]) -> StoreBase:
def _create_store(self, store_conf: Optional[Dict], embed_dims: Optional[Dict[str, int]] = None,
embed_datatypes: Optional[Dict[str, DataType]] = None) -> StoreBase:
store_type = store_conf.get('type')
if not store_type:
raise ValueError('store type is not specified.')
Expand All @@ -108,11 +118,11 @@ def _create_store(self, store_conf: Optional[Dict]) -> StoreBase:
store = MapStore(node_groups=list(self._activated_embeddings.keys()), embed=self.embed, **kwargs)
elif store_type == "chroma":
store = ChromadbStore(group_embed_keys=self._activated_embeddings, embed=self.embed,
embed_dims=self._embed_dims, **kwargs)
embed_dims=embed_dims, **kwargs)
elif store_type == "milvus":
store = MilvusStore(group_embed_keys=self._activated_embeddings, embed=self.embed,
embed_dims=self._embed_dims, global_metadata_desc=self._global_metadata_desc,
**kwargs)
embed_dims=embed_dims, embed_datatypes=embed_datatypes,
global_metadata_desc=self._global_metadata_desc, **kwargs)
else:
raise NotImplementedError(
f"Not implemented store type for {store_type}"
Expand All @@ -131,7 +141,8 @@ def _create_store(self, store_conf: Optional[Dict]) -> StoreBase:
index = SmartEmbeddingIndex(backend_type=backend_type,
group_embed_keys=self._activated_embeddings,
embed=self.embed,
embed_dims=self._embed_dims,
embed_dims=embed_dims,
embed_datatypes=embed_datatypes,
global_metadata_desc=self._global_metadata_desc,
**kwargs)
else:
Expand Down
6 changes: 1 addition & 5 deletions lazyllm/tools/rag/global_metadata.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
from typing import Optional, Any

class GlobalMetadataDesc:
DTYPE_VARCHAR = 0
DTYPE_ARRAY = 1
DTYPE_INT32 = 2

# max_size MUST be set when data_type is DTYPE_VARCHAR or DTYPE_ARRAY
# max_size MUST be set when data_type is DataType.VARCHAR or DataType.ARRAY
def __init__(self, data_type: int, element_type: Optional[int] = None,
default_value: Optional[Any] = None, max_size: Optional[int] = None):
self.data_type = data_type
Expand Down
37 changes: 26 additions & 11 deletions lazyllm/tools/rag/milvus_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .index_base import IndexBase
from .store_base import StoreBase
from .global_metadata import GlobalMetadataDesc, RAG_DOC_PATH, RAG_DOC_ID
from .data_type import DataType
from lazyllm.common import override, obj2str, str2obj

class MilvusStore(StoreBase):
Expand Down Expand Up @@ -38,20 +39,23 @@ def _def_constants(self) -> None:
}

self._builtin_global_metadata_desc = {
RAG_DOC_ID: GlobalMetadataDesc(data_type=GlobalMetadataDesc.DTYPE_VARCHAR,
RAG_DOC_ID: GlobalMetadataDesc(data_type=DataType.VARCHAR,
default_value=' ', max_size=512),
RAG_DOC_PATH: GlobalMetadataDesc(data_type=GlobalMetadataDesc.DTYPE_VARCHAR,
RAG_DOC_PATH: GlobalMetadataDesc(data_type=DataType.VARCHAR,
default_value=' ', max_size=65535),
}

self._type2milvus = [
pymilvus.DataType.VARCHAR,
pymilvus.DataType.ARRAY,
pymilvus.DataType.INT32,
pymilvus.DataType.FLOAT_VECTOR,
pymilvus.DataType.SPARSE_FLOAT_VECTOR,
]

def __init__(self, group_embed_keys: Dict[str, Set[str]], embed: Dict[str, Callable], # noqa C901
embed_dims: Dict[str, int], global_metadata_desc: Dict[str, GlobalMetadataDesc],
embed_dims: Dict[str, int], embed_datatypes: Dict[str, DataType],
global_metadata_desc: Dict[str, GlobalMetadataDesc],
uri: str, embedding_index_type: Optional[str] = None,
embedding_metric_type: Optional[str] = None, **kwargs):
self._def_constants()
Expand All @@ -60,6 +64,11 @@ def __init__(self, group_embed_keys: Dict[str, Set[str]], embed: Dict[str, Calla
self._embed = embed
self._client = pymilvus.MilvusClient(uri=uri)

if embed_dims is None:
embed_dims = {}
if embed_datatypes is None:
embed_datatypes = {}

# XXX milvus 2.4.x doesn't support `default_value`
# https://milvus.io/docs/product_faq.md#Does-Milvus-support-specifying-default-values-for-scalar-or-vector-fields
if global_metadata_desc:
Expand All @@ -85,26 +94,32 @@ def __init__(self, group_embed_keys: Dict[str, Set[str]], embed: Dict[str, Calla
field_list.append(pymilvus.FieldSchema(name=key, **info))

for key in embed_keys:
dim = embed_dims.get(key)
if not dim:
raise ValueError(f'cannot find embedding dim of embed [{key}] in [{embed_dims}]')
datatype = embed_datatypes.get(key)
if not datatype:
raise ValueError(f'cannot find embedding datatype if embed [{key}] in [{embed_datatypes}]')

field_kwargs = {}
dim = embed_dims.get(key) # can be empty if embedding is sparse
if dim:
field_kwargs['dim'] = dim

field_name = self._gen_embedding_key(key)
field_list.append(pymilvus.FieldSchema(name=field_name, dtype=pymilvus.DataType.FLOAT_VECTOR, dim=dim))
field_list.append(pymilvus.FieldSchema(name=field_name, dtype=self._type2milvus[datatype],
**field_kwargs))
index_params.add_index(field_name=field_name, index_type=embedding_index_type,
metric_type=embedding_metric_type)

if self._global_metadata_desc:
for key, desc in self._global_metadata_desc.items():
if desc.data_type == GlobalMetadataDesc.DTYPE_ARRAY:
if desc.data_type == DataType.ARRAY:
if not desc.element_type:
raise ValueError(f'Milvus field [{key}]: `element_type` is required when '
'`data_type` is DTYPE_ARRAY.')
'`data_type` is ARRAY.')
field_args = {
'element_type': self._type2milvus[desc.element_type],
'max_capacity': desc.max_size,
}
elif desc.data_type == GlobalMetadataDesc.DTYPE_VARCHAR:
elif desc.data_type == DataType.VARCHAR:
field_args = {
'max_length': desc.max_size,
}
Expand Down Expand Up @@ -236,7 +251,7 @@ def _construct_filter_expr(self, filters: Dict[str, Union[str, int, List, Set]])
key = self._gen_field_key(name)
if (not isinstance(candidates, List)) and (not isinstance(candidates, Set)):
candidates = list(candidates)
if desc.data_type == GlobalMetadataDesc.DTYPE_ARRAY:
if desc.data_type == DataType.ARRAY:
# https://github.com/milvus-io/milvus/discussions/35279
# `array_contains_any` requires milvus >= 2.4.3 and is not supported in local(aka lite) mode.
ret_str += f'array_contains_any({key}, {candidates}) and '
Expand Down
34 changes: 34 additions & 0 deletions lazyllm/tools/rag/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,3 +570,37 @@ def generic_process_filters(nodes: List[DocNode], filters: Dict[str, Union[str,
else:
res.append(node)
return res

def sparse2normal(embedding: Union[Dict[int, float], List[Tuple[int, float]]], dim: int) -> List[float]:
    """Convert a sparse embedding into a dense list of length ``dim``.

    Supported sparse forms:
      * ``{index: value}`` dict
      * list of ``(index, value)`` tuples

    Args:
        embedding: the sparse embedding to densify. A falsy value yields ``[]``.
        dim: length of the dense output vector.

    Returns:
        A dense list where every unmentioned position is 0.

    Raises:
        TypeError: if ``embedding`` is not one of the supported sparse forms.
        IndexError: if a sparse index falls outside ``[0, dim)``.
    """
    if not embedding:
        return []

    dense = [0] * dim
    if isinstance(embedding, dict):
        for idx, val in embedding.items():
            dense[int(idx)] = val
    elif isinstance(embedding, list) and isinstance(embedding[0], tuple):
        for idx, val in embedding:
            dense[int(idx)] = val
    else:
        # Build the diagnostic without subscripting: the original formatted
        # `type(embedding[0])` here, which itself blew up with an unrelated
        # TypeError for non-indexable inputs (sets, generators, dict views).
        bad_type = type(embedding[0]) if isinstance(embedding, list) else type(embedding)
        raise TypeError(f'unsupported embedding datatype `{bad_type}`')

    return dense

def is_sparse(embedding: Union[Dict[int, float], List[Tuple[int, float]], List[float]]) -> bool:
    """Tell whether ``embedding`` uses a sparse representation.

    Sparse forms are an ``{index: value}`` dict or a list of
    ``(index, value)`` tuples; a non-empty list of numbers is dense.
    Any other shape raises ``TypeError``; an empty list raises
    ``ValueError`` because its kind cannot be determined.
    """
    if isinstance(embedding, dict):
        return True

    if not isinstance(embedding, list):
        raise TypeError(f'unsupported embedding type `{type(embedding)}`')

    if not embedding:
        raise ValueError('empty embedding type is not determined.')

    first = embedding[0]
    if isinstance(first, tuple):
        return True
    if isinstance(first, (float, int)):
        return False

    raise TypeError(f'unsupported embedding type `{type(first)}`')
25 changes: 24 additions & 1 deletion tests/basic_tests/test_rag_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from lazyllm.tools.rag.utils import generic_process_filters
from lazyllm.tools.rag.doc_node import DocNode
from lazyllm.tools.rag.utils import _FileNodeIndex
from lazyllm.tools.rag.utils import _FileNodeIndex, sparse2normal, is_sparse
from lazyllm.tools.rag.store_base import LAZY_ROOT_NAME
from lazyllm.tools.rag.global_metadata import RAG_DOC_PATH
import unittest
Expand All @@ -24,6 +24,29 @@ def test_generic_process_filters(self):
res = generic_process_filters(nodes, {'k2': 'v6'})
assert len(res) == 0

def test_sparse2normal(self):
    # dict-form sparse embedding expands into a dense vector of length `dim`
    dense = sparse2normal({1: 3, 5: 12}, 6)
    assert dense == [0, 3, 0, 0, 0, 12]
    assert len(dense) == 6

    # list-of-(index, value)-pairs form behaves identically
    dense = sparse2normal([(0, 9), (2, 14), (4, 28)], 8)
    assert dense == [9, 0, 14, 0, 28, 0, 0, 0]
    assert len(dense) == 8

def test_is_sparse(self):
    # both sparse representations are recognized as sparse
    assert is_sparse({1: 3, 5: 12})
    assert is_sparse([(0, 9), (2, 14), (4, 28)])

    # a plain list of numbers is dense
    assert not is_sparse([9, 0, 14, 0, 28, 0, 0, 0])

class TestFileNodeIndex(unittest.TestCase):
def setUp(self):
self.index = _FileNodeIndex()
Expand Down
Loading