Merge pull request langchain-ai#2 from TileDB-Inc/npapa/fix-build
Fix build failures
NikolaosPapailiou authored Sep 14, 2023
2 parents 8705aed + 3d8be1e commit bd6c714
Showing 3 changed files with 194 additions and 169 deletions.
2 changes: 0 additions & 2 deletions libs/langchain/langchain/vectorstores/__init__.py
@@ -68,7 +68,6 @@
 from langchain.vectorstores.supabase import SupabaseVectorStore
 from langchain.vectorstores.tair import Tair
 from langchain.vectorstores.tencentvectordb import TencentVectorDB
-from langchain.vectorstores.tiledb import TileDB
 from langchain.vectorstores.tigris import Tigris
 from langchain.vectorstores.typesense import Typesense
 from langchain.vectorstores.usearch import USearch
@@ -131,7 +130,6 @@
     "StarRocks",
     "SupabaseVectorStore",
     "Tair",
-    "TileDB",
     "Tigris",
     "Typesense",
     "USearch",
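
With the re-export removed, TileDB is no longer importable from the package root; callers import it from its own module. A minimal sketch of the direct import, using the module path shown in this diff:

    # The package-level re-export was removed; import from the module directly.
    from langchain.vectorstores.tiledb import TileDB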
66 changes: 25 additions & 41 deletions libs/langchain/langchain/vectorstores/tiledb.py
@@ -1,11 +1,9 @@
 """Wrapper around TileDB vector database."""
 from __future__ import annotations
 
 import os
 import pickle
 import random
-from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
-from ratelimiter import RateLimiter
+from typing import Any, Dict, Iterable, List, Optional, Tuple
 
 import numpy as np
 import tiledb
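
The third-party ratelimiter import is dropped along with the now-unused Callable and Tuple typing names. If embedding calls still need throttling, that can live outside the vectorstore. One way to do it, as a sketch (throttled is a hypothetical helper, not part of this diff):

    import time

    def throttled(batches, min_interval_s=1.0):
        """Yield items no faster than one per min_interval_s."""
        last = 0.0
        for batch in batches:
            wait = min_interval_s - (time.monotonic() - last)
            if wait > 0:
                time.sleep(wait)
            last = time.monotonic()
            yield batch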
@@ -111,18 +109,17 @@ def process_index_results(
                     key: [value] if not isinstance(value, list) else value
                     for key, value in filter.items()
                 }
-                if all(result_doc.metadata.get(key) in value for key, value in filter.items()):
+                if all(
+                    result_doc.metadata.get(key) in value
+                    for key, value in filter.items()
+                ):
                     docs.append((result_doc, score))
             else:
                 docs.append((result_doc, score))
         docs_array.close()
         score_threshold = kwargs.get("score_threshold")
         if score_threshold is not None:
-            docs = [
-                (doc, score)
-                for doc, score in docs
-                if score <= score_threshold
-            ]
+            docs = [(doc, score) for doc, score in docs if score <= score_threshold]
         return docs[:k]
 
     def similarity_search_with_score_by_vector(
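
process_index_results normalizes scalar filter values into single-element lists before matching, and the distances returned by the index are smaller-is-closer, which is why the optional score_threshold keeps results with score <= score_threshold. A standalone sketch of the same filter semantics (matches is illustrative, not code from the diff):

    def matches(metadata: dict, filter: dict) -> bool:
        # {"color": "red"} is treated as {"color": ["red"]}.
        filter = {
            key: [value] if not isinstance(value, list) else value
            for key, value in filter.items()
        }
        return all(metadata.get(key) in value for key, value in filter.items())

    assert matches({"color": "red"}, {"color": ["red", "blue"]})
    assert not matches({"color": "green"}, {"color": "red"})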
@@ -151,14 +148,10 @@ def similarity_search_with_score_by_vector(
         """
         d, i = self.index.query(
             np.array([np.array(embedding).astype(np.float32)]).astype(np.float32),
-            k=k if filter is None else fetch_k
+            k=k if filter is None else fetch_k,
         )
         return self.process_index_results(
-            ids=i[0],
-            scores=d[0],
-            filter=filter,
-            k=k,
-            **kwargs
+            ids=i[0], scores=d[0], filter=filter, k=k, **kwargs
         )
 
     def similarity_search_with_score(
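
When a metadata filter is supplied, the method over-fetches fetch_k candidates from the index and trims them down to k after filtering. A hypothetical call (store and query_vec are assumptions, not from the diff):

    results = store.similarity_search_with_score_by_vector(
        embedding=query_vec,
        k=4,
        filter={"source": ["docs"]},
    )
    for doc, score in results:
        print(round(score, 3), doc.page_content[:60])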
@@ -276,15 +269,17 @@ def max_marginal_relevance_search_with_score_by_vector(
         """
         scores, indices = self.index.query(
             np.array([np.array(embedding).astype(np.float32)]).astype(np.float32),
-            k=fetch_k if filter is None else fetch_k * 2
+            k=fetch_k if filter is None else fetch_k * 2,
         )
         results = self.process_index_results(
             ids=indices[0],
             scores=scores[0],
             filter=filter,
-            k=fetch_k if filter is None else fetch_k * 2
+            k=fetch_k if filter is None else fetch_k * 2,
         )
-        embeddings = [self.embedding.embed_documents([doc.page_content])[0] for doc, _ in results]
+        embeddings = [
+            self.embedding.embed_documents([doc.page_content])[0] for doc, _ in results
+        ]
         mmr_selected = maximal_marginal_relevance(
             np.array([embedding], dtype=np.float32),
             embeddings,
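
For MMR the candidate pool is widened further: fetch_k without a filter, fetch_k * 2 with one, so that enough documents survive metadata filtering for the diversity re-ranking. The selection logic in miniature (illustrative values):

    k, fetch_k = 4, 20
    filter = {"source": ["docs"]}
    n_candidates = fetch_k if filter is None else fetch_k * 2  # 40 here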
@@ -394,7 +389,7 @@ def __from(
         vector_array_uri = f"{group.uri}/{VECTOR_ARRAY_NAME}"
         docs_uri = f"{group.uri}/{DOCUMENTS_ARRAY_NAME}"
         if ids is None:
-            ids = [random.randint(0, MAX_UINT64-1) for _ in texts]
+            ids = [str(random.randint(0, MAX_UINT64 - 1)) for _ in texts]
         external_ids = np.array(ids).astype(np.uint64)
 
         input_vectors = np.array(embeddings).astype(np.float32)
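
Auto-generated ids are now produced as strings, matching the declared List[str] type for ids, and cast back to uint64 before reaching the index. A sketch of the round-trip (MAX_UINT64 is assumed to be 2**64 - 1):

    import random
    import numpy as np

    MAX_UINT64 = np.iinfo(np.uint64).max  # 2**64 - 1
    ids = [str(random.randint(0, MAX_UINT64 - 1)) for _ in range(3)]
    external_ids = np.array(ids).astype(np.uint64)  # numeric strings cast cleanly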
@@ -403,13 +398,13 @@ def __from(
             index_uri=vector_array_uri,
             input_vectors=input_vectors,
             external_ids=external_ids,
-            partitions=1
+            partitions=1,
         )
         group.add(vector_array_uri, name=VECTOR_ARRAY_NAME)
 
         dim = tiledb.Dim(
             name="id",
-            domain=(0, MAX_UINT64-1),
+            domain=(0, MAX_UINT64 - 1),
             dtype=np.dtype(np.uint64),
         )
         dom = tiledb.Domain(dim)
@@ -437,7 +432,9 @@ def __from(
             metadata_attr = np.empty([len(metadatas)], dtype=object)
             i = 0
             for metadata in metadatas:
-                metadata_attr[i] = np.frombuffer(pickle.dumps(metadata), dtype=np.uint8)
+                metadata_attr[i] = np.frombuffer(
+                    pickle.dumps(metadata), dtype=np.uint8
+                )
                 i += 1
             data["metadata"] = metadata_attr
 
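Metadata dictionaries are pickled into raw uint8 buffers so they can be stored as a TileDB attribute. The round-trip looks like this (a self-contained sketch, not code from the diff):

    import pickle
    import numpy as np

    metadata = {"source": "docs", "page": 3}
    buf = np.frombuffer(pickle.dumps(metadata), dtype=np.uint8)  # write side
    assert pickle.loads(buf.tobytes()) == metadata  # read side recovers the dict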
@@ -460,13 +457,13 @@ def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
 
         external_ids = np.array(ids).astype(np.uint64)
         self.index.delete_batch(external_ids=external_ids)
+        return True
 
     def add_texts(
         self,
         texts: Iterable[str],
         metadatas: Optional[List[dict]] = None,
         ids: Optional[List[str]] = None,
-        rate_limiter: RateLimiter = None,
         **kwargs: Any,
     ) -> List[str]:
         """Run more texts through the embeddings and add to the vectorstore.
@@ -479,18 +476,12 @@ def add_texts(
         Returns:
             List of ids from adding the texts into the vectorstore.
         """
-        embeddings = []
-        if rate_limiter is None:
-            embeddings = self.embedding.embed_documents(texts)
-        else:
-            for i in range(len(texts)):
-                with rate_limiter:
-                    embeddings.append(self.embedding.embed_documents(texts[i])[0])
+        embeddings = self.embedding.embed_documents(list(texts))
         if ids is None:
-            ids = [random.randint(0, 100000) for _ in texts]
+            ids = [str(random.randint(0, MAX_UINT64 - 1)) for _ in texts]
 
         external_ids = np.array(ids).astype(np.uint64)
-        vectors = np.empty((len(embeddings)), dtype='O')
+        vectors = np.empty((len(embeddings)), dtype="O")
         for i in range(len(embeddings)):
             vectors[i] = np.array(embeddings[i], dtype=np.float32)
         self.index.update_batch(vectors=vectors, external_ids=external_ids)
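
add_texts now embeds everything in one embed_documents call in place of the removed per-text rate-limited loop, and widens its fallback ids from 0-100000 to the full uint64 range. The vectors are staged in an object-dtype array so each row holds its own float32 vector for update_batch; in miniature (stand-in embedding values):

    import numpy as np

    embeddings = [[0.1, 0.2], [0.3, 0.4]]
    vectors = np.empty(len(embeddings), dtype="O")
    for i, emb in enumerate(embeddings):
        vectors[i] = np.array(emb, dtype=np.float32)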
@@ -519,7 +510,6 @@ def from_texts(
         ids: Optional[List[str]] = None,
         metric: str = DEFAULT_METRIC,
         array_uri: str = "/tmp/tiledb_array",
-        rate_limiter: RateLimiter = None,
         **kwargs: Any,
     ) -> TileDB:
         """Construct a TileDB index from raw documents.
@@ -530,7 +520,6 @@ def from_texts(
             metadatas: List of metadata dictionaries to associate with documents.
             metric: Metric to use for indexing. Defaults to "euclidean".
             array_uri: The URI to write the TileDB arrays
-            rate_limiter: RateLimiter for embeddings generation
         Example:
             .. code-block:: python
@@ -541,12 +530,7 @@
                 index = TileDB.from_texts(texts, embeddings)
         """
-        embeddings = []
-        if rate_limiter is None:
-            embeddings = embedding.embed_documents(texts)
-        else:
-            for i in range(len(texts)):
-                with rate_limiter:
-                    embeddings.append(embedding.embed_documents(texts[i])[0])
+        embeddings = embedding.embed_documents(texts)
         return cls.__from(
             texts, embeddings, embedding, metadatas, ids, metric, array_uri
         )
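
from_texts likewise embeds in a single embed_documents call before delegating to __from. A minimal end-to-end sketch, assuming OpenAIEmbeddings is configured with credentials (any Embeddings implementation works):

    from langchain.embeddings import OpenAIEmbeddings
    from langchain.vectorstores.tiledb import TileDB

    embeddings = OpenAIEmbeddings()
    db = TileDB.from_texts(
        ["hello world", "goodbye world"],
        embeddings,
        array_uri="/tmp/example_tiledb",
    )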
@@ -610,5 +594,5 @@ def load(
 
         return cls(embeddings, index, DEFAULT_METRIC, documents_array_uri)
 
-    def consolidate_updates(self):
+    def consolidate_updates(self) -> None:
         self.index = self.index.consolidate_updates()
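
consolidate_updates folds the update fragments accumulated by add_texts and delete back into the main vector index and rebinds self.index to the consolidated result; the annotation fix makes the None return explicit. A hypothetical maintenance call after a batch of writes:

    store.add_texts(["new text"])
    store.delete(ids=["42"])
    store.consolidate_updates()  # merge pending updates into the index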