[vectorstores] Implement BaseStore interface for document storage #118

Merged · 3 commits · Apr 9, 2024
langchain_google_vertexai/vectorstores/__init__.py
@@ -1,3 +1,7 @@
from langchain_google_vertexai.vectorstores.document_storage import (
DataStoreDocumentStorage,
GCSDocumentStorage,
)
from langchain_google_vertexai.vectorstores.vectorstores import (
VectorSearchVectorStore,
VectorSearchVectorStoreDatastore,
@@ -8,4 +12,6 @@
"VectorSearchVectorStore",
"VectorSearchVectorStoreDatastore",
"VectorSearchVectorStoreGCS",
"DataStoreDocumentStorage",
"GCSDocumentStorage",
]
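
With the expanded __all__, both storages are importable from the subpackage and usable anywhere a langchain_core BaseStore[str, Document] is expected. A minimal usage sketch of the new interface; the bucket name is a placeholder, not from this PR:

    from google.cloud import storage
    from langchain_core.documents import Document
    from langchain_google_vertexai.vectorstores import GCSDocumentStorage

    # "my-bucket" is illustrative; any existing GCS bucket works.
    bucket = storage.Client().bucket("my-bucket")
    store = GCSDocumentStorage(bucket, prefix="documents")

    # The standard BaseStore surface implemented by this PR:
    store.mset([("doc-1", Document(page_content="hello", metadata={"source": "demo"}))])
    print(store.mget(["doc-1", "missing"]))  # [Document(...), None]
    print(list(store.yield_keys()))          # ["doc-1"]
    store.mdelete(["doc-1"])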
langchain_google_vertexai/vectorstores/document_storage.py
@@ -1,51 +1,48 @@
from __future__ import annotations

import json
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Sequence, Tuple

from google.cloud import storage # type: ignore[attr-defined, unused-ignore]
from langchain_core.documents import Document
from langchain_core.stores import BaseStore

if TYPE_CHECKING:
from google.cloud import datastore # type: ignore[attr-defined, unused-ignore]


class DocumentStorage(ABC):
class DocumentStorage(BaseStore[str, Document]):
"""Abstract interface of a key, text storage for retrieving documents."""

@abstractmethod
def get_by_id(self, document_id: str) -> Document | None:
"""Gets a document by its id. If not found, returns None.
Args:
document_id: Id of the document to get from the storage.
Returns:
Document if found, otherwise None.
"""
raise NotImplementedError()

@abstractmethod
def store_by_id(self, document_id: str, document: Document):
"""Stores a document associated to a document_id.
class GCSDocumentStorage(DocumentStorage):
"""Stores documents in Google Cloud Storage.
For each (id, document) pair, the blob is named {prefix}/{id} and stored
in plain-text format.
"""

def __init__(
self, bucket: storage.Bucket, prefix: Optional[str] = "documents"
) -> None:
"""Constructor.
Args:
document_id: Id of the document to be stored.
document: Document to be stored.
bucket: Bucket where the documents will be stored.
prefix: Prefix that is prepended to all document names.
"""
raise NotImplementedError()
super().__init__()
self._bucket = bucket
self._prefix = prefix

def mset(self, key_value_pairs: Sequence[Tuple[str, Document]]) -> None:
"""Stores a series of documents using each keys

def batch_store_by_id(self, ids: List[str], documents: List[Document]) -> None:
"""Stores a list of ids and documents in batch.
The default implementation only loops to the individual `store_by_id`.
Subclasses that have faster ways to store data via batch uploading should
implement the proper way.
Args:
ids: List of ids for the text.
documents: List of documents.
key_value_pairs (Sequence[Tuple[K, V]]): A sequence of key-value pairs.
"""
for id_, document in zip(ids, documents):
self.store_by_id(id_, document)
for key, value in key_value_pairs:
self._set_one(key, value)

def batch_get_by_id(self, ids: List[str]) -> List[Document | None]:
def mget(self, keys: Sequence[str]) -> List[Optional[Document]]:
"""Gets a batch of documents by id.
The default implementation only loops `get_by_id`.
Subclasses that have faster ways to retrieve data by batch should implement
@@ -56,36 +53,35 @@ def batch_get_by_id(self, ids: List[str]) -> List[Document | None]:
List of documents. If a key is not found, None is returned in its
place.
"""
return [self.get_by_id(id_) for id_ in ids]
return [self._get_one(key) for key in keys]

def mdelete(self, keys: Sequence[str]) -> None:
"""Deletes a batch of documents by id.

class GCSDocumentStorage(DocumentStorage):
"""Stores documents in Google Cloud Storage.
For each pair id, document_text the name of the blob will be {prefix}/{id} stored
in plain text format.
"""
Args:
keys: List of document ids to delete.
"""
for key in keys:
self._delete_one(key)

def yield_keys(self, *, prefix: str | None = None) -> Iterator[str]:
"""Yields the keys present in the storage.

def __init__(
self, bucket: storage.Bucket, prefix: Optional[str] = "documents"
) -> None:
"""Constructor.
Args:
bucket: Bucket where the documents will be stored.
prefix: Prefix that is prepended to all document names.
prefix: Ignored. Uses the prefix provided in the constructor.
"""
super().__init__()
self._bucket = bucket
self._prefix = prefix
for blob in self._bucket.list_blobs(prefix=self._prefix):
yield blob.name.split("/")[-1]

def get_by_id(self, document_id: str) -> Document | None:
def _get_one(self, key: str) -> Document | None:
"""Gets the text of a document by its id. If not found, returns None.
Args:
document_id: Id of the document to get from the storage.
key: Id of the document to get from the storage.
Returns:
Document if found, otherwise None.
"""

blob_name = self._get_blob_name(document_id)
blob_name = self._get_blob_name(key)
existing_blob = self._bucket.get_blob(blob_name)

if existing_blob is None:
@@ -95,19 +91,29 @@ def get_by_id(self, document_id: str) -> Document | None:
document_json: Dict[str, Any] = json.loads(document_str)
return Document(**document_json)

def store_by_id(self, document_id: str, document: Document) -> None:
def _set_one(self, key: str, value: Document) -> None:
"""Stores a document text associated to a document_id.
Args:
document_id: Id of the document to be stored.
key: Id of the document to be stored.
value: Document to be stored.
"""
blob_name = self._get_blob_name(document_id)
blob_name = self._get_blob_name(key)
new_blob = self._bucket.blob(blob_name)

document_json = document.dict()
document_json = value.dict()
document_text = json.dumps(document_json)
new_blob.upload_from_string(document_text)

def _delete_one(self, key: str) -> None:
"""Deletes one document by its key.

Args:
key (str): Id of the document to delete.
"""
blob_name = self._get_blob_name(key)
blob = self._bucket.blob(blob_name)
blob.delete()

def _get_blob_name(self, document_id: str) -> str:
"""Builds a blob name using the prefix and the document_id.
Args:
@@ -139,56 +145,22 @@ def __init__(
self._metadata_property_name = metadata_property_name
self._kind = kind

def get_by_id(self, document_id: str) -> Document | None:
"""Gets the text of a document by its id. If not found, returns None.
Args:
document_id: Id of the document to get from the storage.
Returns:
Text of the document if found, otherwise None.
"""
key = self._client.key(self._kind, document_id)
entity = self._client.get(key)

if entity is None:
return None

return Document(
page_content=entity[self._text_property_name],
metadata=self._convert_entity_to_dict(entity[self._metadata_property_name]),
)

def store_by_id(self, document_id: str, document: Document) -> None:
"""Stores a document text associated to a document_id.
Args:
document_id: Id of the document to be stored.
text: Text of the document to be stored.
"""
with self._client.transaction():
key = self._client.key(self._kind, document_id)

entity = self._client.entity(key=key)
entity[self._text_property_name] = document.page_content
entity[self._metadata_property_name] = document.metadata

self._client.put(entity)

def batch_get_by_id(self, ids: List[str]) -> List[Document | None]:
def mget(self, keys: Sequence[str]) -> List[Optional[Document]]:
"""Gets a batch of documents by id.
Args:
keys: List of document ids.
Returns:
List of documents. If a key is not found, None is returned in its
place.
"""
keys = [self._client.key(self._kind, id_) for id_ in ids]
ds_keys = [self._client.key(self._kind, id_) for id_ in keys]

# TODO: Handle when a key is not present
entities = self._client.get_multi(keys)
entities = self._client.get_multi(ds_keys)

# get_multi does not return entities in key order, so reorder the
# results to match the requested keys.
entity_id_lookup = {entity.key.id_or_name: entity for entity in entities}
entities = [entity_id_lookup[id_] for id_ in ids]
entities = [entity_id_lookup.get(id_) for id_ in keys]

return [
Document(
@@ -197,15 +169,19 @@ def batch_get_by_id(self, ids: List[str]) -> List[Document | None]:
entity[self._metadata_property_name]
),
)
if entity is not None
else None
for entity in entities
]

def batch_store_by_id(self, ids: List[str], documents: List[Document]) -> None:
"""Stores a list of ids and documents in batch.
def mset(self, key_value_pairs: Sequence[Tuple[str, Document]]) -> None:
"""Stores a series of documents using each keys

Args:
ids: List of ids for the text.
texts: List of texts.
key_value_pairs (Sequence[Tuple[K, V]]): A sequence of key-value pairs.
"""
ids = [key for key, _ in key_value_pairs]
documents = [document for _, document in key_value_pairs]

with self._client.transaction():
keys = [self._client.key(self._kind, id_) for id_ in ids]
@@ -219,6 +195,27 @@

self._client.put_multi(entities)

def mdelete(self, keys: Sequence[str]) -> None:
"""Deletes a sequence of documents by key.

Args:
keys (Sequence[str]): A sequence of keys to delete.
"""
with self._client.transaction():
ds_keys = [self._client.key(self._kind, id_) for id_ in keys]
self._client.delete_multi(ds_keys)

def yield_keys(self, *, prefix: str | None = None) -> Iterator[str]:
"""Yields the keys of all documents in the storage.

Args:
prefix: Ignored.
"""
query = self._client.query(kind=self._kind)
query.keys_only()
for entity in query.fetch():
yield str(entity.key.id_or_name)

def _convert_entity_to_dict(self, entity: datastore.Entity) -> Dict[str, Any]:
"""Recursively transform an entity into a plain dictionary."""
from google.cloud import datastore # type: ignore[attr-defined, unused-ignore]
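
The Datastore-backed storage implements the same surface. A similar sketch; the constructor is collapsed in the diff above, so the arguments shown here are assumptions for illustration:

    from google.cloud import datastore
    from langchain_core.documents import Document
    from langchain_google_vertexai.vectorstores import DataStoreDocumentStorage

    # The constructor signature is collapsed in the diff above; passing only
    # a client is an assumption for illustration.
    store = DataStoreDocumentStorage(datastore.Client())

    store.mset([("doc-1", Document(page_content="hello", metadata={"source": "demo"}))])
    print(store.mget(["doc-1", "missing"]))  # [Document(...), None]
    store.mdelete(["doc-1"])
    print(list(store.yield_keys()))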
langchain_google_vertexai/vectorstores/vectorstores.py
@@ -10,16 +10,16 @@
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore

from langchain_google_vertexai.vectorstores._document_storage import (
DataStoreDocumentStorage,
DocumentStorage,
GCSDocumentStorage,
)
from langchain_google_vertexai.vectorstores._sdk_manager import VectorSearchSDKManager
from langchain_google_vertexai.vectorstores._searcher import (
Searcher,
VectorSearchSearcher,
)
from langchain_google_vertexai.vectorstores.document_storage import (
DataStoreDocumentStorage,
DocumentStorage,
GCSDocumentStorage,
)


class _BaseVertexAIVectorStore(VectorStore):
@@ -115,19 +115,19 @@ def similarity_search_by_vector_with_score(
embeddings=[embedding], k=k, filter_=filter, numeric_filter=numeric_filter
)

results = []

for neighbor_id, distance in neighbors_list[0]:
document = self._document_storage.get_by_id(neighbor_id)
keys = [key for key, _ in neighbors_list[0]]
distances = [distance for _, distance in neighbors_list[0]]
documents = self._document_storage.mget(keys)

if document is None:
raise ValueError(
f"Document with id {neighbor_id} not found in document" "storage."
)

results.append((document, distance))

return results
if all(document is not None for document in documents):
# Ignore typing: mypy cannot infer that the check above rules out
# None values in documents.
return list(zip(documents, distances)) # type: ignore
else:
missing_docs = [key for key, doc in zip(keys, documents) if doc is None]
message = f"Documents with ids: {missing_docs} not found in the storage"
raise ValueError(message)

def similarity_search(
self,
@@ -197,7 +197,7 @@ def add_texts(
for text, metadata in zip(texts, metadatas)
]

self._document_storage.batch_store_by_id(ids=ids, documents=documents)
self._document_storage.mset(list(zip(ids, documents)))

embeddings = self._embeddings.embed_documents(texts)

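
The vectorstores change swaps the per-neighbor get_by_id loop for a single batched mget and fails loudly when any neighbor id is missing. A standalone sketch of the same pattern, with illustrative names:

    from typing import List, Tuple

    from langchain_core.documents import Document
    from langchain_core.stores import BaseStore


    def resolve_neighbors(
        storage: BaseStore[str, Document],
        neighbors: List[Tuple[str, float]],
    ) -> List[Tuple[Document, float]]:
        """Resolve (id, distance) pairs to (Document, distance) with one batched get."""
        keys = [key for key, _ in neighbors]
        distances = [distance for _, distance in neighbors]
        documents = storage.mget(keys)  # one batched call instead of len(keys) single gets

        missing = [key for key, doc in zip(keys, documents) if doc is None]
        if missing:
            raise ValueError(f"Documents with ids: {missing} not found in the storage")

        # mypy cannot see that the check above rules out None values.
        return list(zip(documents, distances))  # type: ignore

Batching keeps document lookup at one storage round trip per query instead of one per neighbor.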