-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Make Milvus2DocumentStore compatible with pymilvus>=2.0.0 #2126
Changes from 6 commits
5ea553e
29aa16e
9c569fe
df62d25
f80544d
ae47bc0
ccc8370
f3cb065
478fd08
31f7819
171a69c
b86449a
8449aef
ad343ec
f9c4a13
c6d57d3
766d78c
e4df42c
c897ca8
fed7198
0fabd18
9af14ae
07469e6
879ba17
6122902
47a5f46
a628ccc
6514d0c
93d2fc7
2517bd1
4fa8434
dbcc2ca
64a61aa
e5af6a9
221bc9c
4415c45
b79ca80
246f236
a9c3aa3
e612b26
a137e13
0d3ccda
253ac5b
2011f34
a1da6f8
fa6ae18
d4836c1
88ed855
6bce535
4d49822
54cb14e
3a9a063
5fdd83b
838d715
7732186
fa93ba4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,7 +12,7 @@ | |
from tqdm import tqdm | ||
|
||
try: | ||
from pymilvus import FieldSchema, CollectionSchema, Collection, connections | ||
from pymilvus import FieldSchema, CollectionSchema, Collection, connections, utility | ||
from pymilvus.client.abstract import QueryResult | ||
from pymilvus.client.types import DataType | ||
except (ImportError, ModuleNotFoundError) as ie: | ||
|
@@ -45,7 +45,7 @@ class Milvus2DocumentStore(SQLDocumentStore): | |
|
||
Usage: | ||
1. Start a Milvus service via docker (see https://milvus.io/docs/v2.0.0/install_standalone-docker.md) | ||
2. Run pip install pymilvus===2.0.0rc6 | ||
2. Run pip install pymilvus>=2.0.0 | ||
3. Init a Milvus2DocumentStore() in Haystack | ||
|
||
Overview: | ||
|
@@ -210,8 +210,7 @@ def _create_collection_and_index_if_not_exist( | |
index_param = index_param or self.index_param | ||
custom_fields = self.custom_fields or [] | ||
|
||
connection = connections.get_connection() | ||
has_collection = connection.has_collection(collection_name=index) | ||
has_collection = utility.has_collection(collection_name=index) | ||
if not has_collection: | ||
fields = [ | ||
FieldSchema(name=self.id_field, dtype=DataType.INT64, is_primary=True, auto_id=True), | ||
|
@@ -226,8 +225,7 @@ def _create_collection_and_index_if_not_exist( | |
|
||
collection_schema = CollectionSchema(fields=fields) | ||
else: | ||
resp = connection.describe_collection(index) | ||
collection_schema = CollectionSchema.construct_from_dict(resp) | ||
collection_schema = None | ||
|
||
collection = Collection(name=index, schema=collection_schema) | ||
|
||
|
@@ -279,7 +277,6 @@ def write_documents( | |
assert ( | ||
duplicate_documents in self.duplicate_documents_options | ||
), f"duplicate_documents parameter must be {', '.join(self.duplicate_documents_options)}" | ||
self._create_collection_and_index_if_not_exist(index=index, index_param=index_param) | ||
field_map = self._create_document_field_map() | ||
|
||
if len(documents) == 0: | ||
|
@@ -294,19 +291,6 @@ def write_documents( | |
with tqdm(total=len(document_objects), disable=not self.progress_bar) as progress_bar: | ||
mutation_result: Any = None | ||
|
||
if add_vectors: | ||
|
||
connection = connections.get_connection() | ||
field_to_idx, field_to_type = self._get_field_to_idx(connection, index) | ||
|
||
records: List[Dict[str, Any]] = [ | ||
{ | ||
"name": field_name, | ||
"type": dtype, | ||
"values": [], | ||
} | ||
for field_name, dtype in field_to_type.items() | ||
] | ||
for document_batch in batched_documents: | ||
if add_vectors: | ||
doc_ids = [] | ||
|
@@ -322,21 +306,11 @@ def write_documents( | |
f"Format of supplied document embedding {type(doc.embedding)} is not " | ||
f"supported. Please use list or numpy.ndarray" | ||
) | ||
records[field_to_idx[self.embedding_field]]["values"] = embeddings | ||
for k, v in field_to_idx.items(): | ||
if k == self.embedding_field: | ||
continue | ||
if k in doc.meta: | ||
records[v]["values"].append(doc.meta[k]) | ||
else: | ||
# TODO: check whether to throw error or not? | ||
pass | ||
|
||
if duplicate_documents == "overwrite": | ||
existing_docs = super().get_documents_by_id(ids=doc_ids, index=index) | ||
self._delete_vector_ids_from_milvus(documents=existing_docs, index=index) | ||
|
||
mutation_result = connection.insert(index, records) | ||
mutation_result = self.collection.insert([embeddings]) | ||
|
||
docs_to_write_in_sql = [] | ||
|
||
|
@@ -355,20 +329,6 @@ def write_documents( | |
# if duplicate_documents == 'overwrite': | ||
# connection.compact(collection_name=index) | ||
|
||
@staticmethod | ||
def _get_field_to_idx(connection, index): | ||
resp = connection.describe_collection(index) | ||
collection_schema = CollectionSchema.construct_from_dict(resp) | ||
field_to_idx: Dict[str, int] = {} | ||
field_to_type: Dict[str, DataType] = {} | ||
count = 0 | ||
for idx, field in enumerate(collection_schema.fields): | ||
if not field.is_primary: | ||
field_to_idx[field.name] = count | ||
field_to_type[field.name] = field.dtype | ||
count = count + 1 | ||
return field_to_idx, field_to_type | ||
|
||
def update_embeddings( | ||
self, | ||
retriever: "BaseRetriever", | ||
|
@@ -393,16 +353,13 @@ def update_embeddings( | |
:return: None | ||
""" | ||
index = index or self.index | ||
self._create_collection_and_index_if_not_exist(index) | ||
|
||
document_count = self.get_document_count(index=index) | ||
if document_count == 0: | ||
logger.warning("Calling DocumentStore.update_embeddings() on an empty index") | ||
return | ||
|
||
logger.info(f"Updating embeddings for {document_count} docs...") | ||
connection = connections.get_connection() | ||
field_to_idx, field_to_type = self._get_field_to_idx(connection, index) | ||
|
||
result = self._query( | ||
index=index, | ||
|
@@ -416,32 +373,13 @@ def update_embeddings( | |
total=document_count, disable=not self.progress_bar, position=0, unit=" docs", desc="Updating Embedding" | ||
) as progress_bar: | ||
for document_batch in batched_documents: | ||
records: List[Dict[str, Any]] = [ | ||
{ | ||
"name": field_name, | ||
"type": dtype, | ||
"values": [], | ||
} | ||
for field_name, dtype in field_to_type.items() | ||
] | ||
self._delete_vector_ids_from_milvus(documents=document_batch, index=index) | ||
|
||
embeddings = retriever.embed_documents(document_batch) # type: ignore | ||
embeddings_list = [embedding.tolist() for embedding in embeddings] | ||
assert len(document_batch) == len(embeddings_list) | ||
|
||
records[field_to_idx[self.embedding_field]]["values"] = embeddings_list | ||
for doc in document_batch: | ||
for k, v in field_to_idx.items(): | ||
if k == self.embedding_field: | ||
continue | ||
if k in doc.meta: | ||
records[v]["values"].append(doc.meta[k]) | ||
else: | ||
# TODO: check whether to throw error or not? | ||
pass | ||
|
||
mutation_result = connection.insert(index, records) | ||
mutation_result = self.collection.insert([embeddings_list]) | ||
|
||
vector_id_map = {} | ||
for vector_id, doc in zip(mutation_result.primary_keys, document_batch): | ||
|
@@ -478,43 +416,22 @@ def query_by_embedding( | |
raise NotImplementedError("Milvus2DocumentStore does not support headers.") | ||
|
||
index = index or self.index | ||
connection = connections.get_connection() | ||
has_collection = connection.has_collection(collection_name=index) | ||
has_collection = utility.has_collection(collection_name=index) | ||
if not has_collection: | ||
raise Exception("No index exists. Use 'update_embeddings()` to create an index.") | ||
|
||
if return_embedding is None: | ||
return_embedding = self.return_embedding | ||
|
||
connection.load_collection(index) | ||
|
||
query_emb = query_emb.reshape(1, -1).astype(np.float32) | ||
|
||
dsl: Dict[str, Any] = { | ||
"bool": { | ||
"must": [ | ||
{ | ||
"vector": { | ||
self.embedding_field: { | ||
"metric_type": self.metric_type, | ||
"params": self.search_param, | ||
"query": query_emb.tolist(), | ||
"topk": top_k, | ||
} | ||
} | ||
} | ||
] | ||
} | ||
} | ||
self.collection.load() | ||
|
||
if filters is not None: | ||
for k, v in filters.items(): | ||
dsl["bool"]["must"].append({"term": {k: v}}) | ||
query_emb = query_emb.reshape(-1).astype(np.float32) | ||
|
||
search_result: QueryResult = connection.search( | ||
collection_name=index, | ||
dsl=dsl, | ||
fields=[self.id_field], | ||
search_result: QueryResult = self.collection.search( | ||
data=[query_emb.tolist()], | ||
anns_field=self.embedding_field, | ||
param={"metric_type": self.metric_type, **self.search_param}, | ||
limit=top_k, | ||
) | ||
|
||
vector_ids_for_query = [] | ||
|
@@ -554,19 +471,12 @@ def delete_documents( | |
index = index or self.index | ||
super().delete_documents(index=index, filters=filters) | ||
|
||
connection = connections.get_connection() | ||
has_collection = connection.has_collection(collection_name=index) | ||
if not has_collection: | ||
logger.warning("No index exists. Use 'update_embeddings()` to create an index.") | ||
if filters: | ||
existing_docs = super().get_all_documents(filters=filters, index=index) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is it possible to use a Generator? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. That sounds like a good idea. The feature is now added. |
||
self._delete_vector_ids_from_milvus(documents=existing_docs, index=index) | ||
else: | ||
if filters: | ||
existing_docs = super().get_all_documents(filters=filters, index=index) | ||
self._delete_vector_ids_from_milvus(documents=existing_docs, index=index) | ||
else: | ||
connection.drop_collection(collection_name=index) | ||
|
||
# TODO: Equivalent in 2.0? | ||
# self.milvus_server.compact(collection_name=index) | ||
self.collection.drop() | ||
self.collection = self._create_collection_and_index_if_not_exist(self.index) | ||
|
||
def get_all_documents_generator( | ||
self, | ||
|
@@ -681,13 +591,9 @@ def _populate_embeddings_to_docs(self, docs: List[Document], index: Optional[str | |
if len(docs_with_vector_ids) == 0: | ||
return | ||
|
||
connection = connections.get_connection() | ||
connection.load_collection(index) | ||
|
||
ids = [str(doc.meta.get("vector_id")) for doc in docs_with_vector_ids] # type: ignore | ||
|
||
search_result: QueryResult = connection.query( | ||
collection_name=index, | ||
search_result: QueryResult = self.collection.search( | ||
expr=f'{self.id_field} in [ {",".join(ids)} ]', | ||
output_fields=[self.embedding_field], | ||
) | ||
|
@@ -703,25 +609,12 @@ def _delete_vector_ids_from_milvus(self, documents: List[Document], index: Optio | |
if "vector_id" in doc.meta: | ||
existing_vector_ids.append(str(doc.meta["vector_id"])) | ||
|
||
if len(existing_vector_ids) > 0: | ||
# TODO: adjust when Milvus 2.0 is released and supports deletion of vectors again | ||
# (https://github.com/milvus-io/milvus/issues/7130) | ||
raise NotImplementedError("Milvus 2.0rc is not yet supporting the deletion of vectors.") | ||
# expression = f'{self.id_field} in [ {",".join(existing_vector_ids)} ]' | ||
# res = self.collection.delete(expression) | ||
# assert len(res) == len(existing_vector_ids) | ||
self.collection.delete(f"{self.id_field} in [ {','.join(existing_vector_ids)} ]") | ||
|
||
def get_embedding_count(self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None) -> int: | ||
""" | ||
Return the count of embeddings in the document store. | ||
""" | ||
if filters: | ||
raise Exception("filters are not supported for get_embedding_count in MilvusDocumentStore.") | ||
index = index or self.index | ||
|
||
connection = connections.get_connection() | ||
stats = connection.get_collection_stats(index) | ||
embedding_count = stats["row_count"] | ||
if embedding_count is None: | ||
embedding_count = 0 | ||
return embedding_count | ||
return self.collection.num_entities |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should change the name of this class too.
In addition, the line above could also change from
`2. Run pip install pymilvus>=2.0.0`
to `2. Run pip install farm-haystack[milvus]`,
and I would put a similar docstring into the `Milvus1DocumentStore` class as well.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay, done