Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adjust for mypy and tests #3

Merged
merged 2 commits into from
Aug 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 16 additions & 17 deletions libs/langchain/langchain/vectorstores/sqlitevss.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import sqlite3
import warnings
from typing import (
TYPE_CHECKING,
Any,
Iterable,
List,
Expand All @@ -18,9 +17,6 @@
from langchain.embeddings.base import Embeddings
from langchain.vectorstores.base import VectorStore

if TYPE_CHECKING:
import sqlite_vss # noqa # pylint: disable=unused-import

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -67,10 +63,10 @@ def create_table_if_not_exists(self) -> None:
f"""
CREATE TABLE IF NOT EXISTS {self._table}
(
text_id INT PRIMARY KEY AUTOINCREMENT,
text text,
metadata blob,
text_embedding blob
rowid INTEGER PRIMARY KEY AUTOINCREMENT,
text TEXT,
metadata BLOB,
text_embedding BLOB
)
;
"""
Expand Down Expand Up @@ -108,8 +104,11 @@ def add_texts(
kwargs: vectorstore specific parameters
"""
max_id = self._connection.execute(
f"SELECT max(text_id) as text_id FROM {self._table}"
).fetchone()["text_id"]
f"SELECT max(rowid) as rowid FROM {self._table}"
).fetchone()["rowid"]
if max_id is None: # no text added yet
max_id = 0

embeds = self._embedding.embed_documents(list(texts))
if not metadatas:
metadatas = [{} for _ in texts]
Expand All @@ -123,12 +122,11 @@ def add_texts(
data_input,
)
self._connection.commit()

# fetch the id of every row we just inserted
results = self._connection.execute(
f"SELECT text_id FROM {self._table} WHERE text_id > {max_id}"
f"SELECT rowid FROM {self._table} WHERE rowid > {max_id}"
)
return [row["text_id"] for row in results]
return [row["rowid"] for row in results]

def similarity_search_with_score_by_vector(
self, embedding: List[float], k: int = 4, **kwargs: Any
Expand All @@ -139,7 +137,7 @@ def similarity_search_with_score_by_vector(
metadata,
distance
FROM {self._table} e
INNER JOIN vss_{self._table} v on v.text_id = e.text_id
INNER JOIN vss_{self._table} v on v.rowid = e.rowid
WHERE vss_search(
v.text_embedding,
vss_search_params('{json.dumps(embedding)}', {k})
Expand All @@ -151,9 +149,8 @@ def similarity_search_with_score_by_vector(

documents = []
for row in results:
doc = Document(
page_content=row["text"], metadata=json.loads(row["metadata"])
)
metadata = json.loads(row["metadata"]) or {}
doc = Document(page_content=row["text"], metadata=metadata)
score = self._euclidean_relevance_score_fn(row["distance"])
documents.append((doc, score))

Expand Down Expand Up @@ -207,6 +204,8 @@ def from_texts(

@staticmethod
def create_connection(db_file: str) -> sqlite3.Connection:
import sqlite_vss

connection = sqlite3.connect(db_file)
connection.row_factory = sqlite3.Row
connection.enable_load_extension(True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_sqlitevss() -> None:
"""Test end to end construction and search."""
docsearch = _sqlite_vss_from_texts()
output = docsearch.similarity_search("foo", k=1)
assert output == [Document(page_content="foo", metadata=None)]
assert output == [Document(page_content="foo", metadata={})]


@pytest.mark.requires("sqlite-vss")
Expand All @@ -44,7 +44,7 @@ def test_sqlitevss_with_score() -> None:
Document(page_content="bar", metadata={"page": 1}),
Document(page_content="baz", metadata={"page": 2}),
]
assert scores[0] < scores[1] < scores[2]
assert scores[0] > scores[1] > scores[2]


@pytest.mark.requires("sqlite-vss")
Expand All @@ -53,8 +53,6 @@ def test_sqlitevss_add_extra() -> None:
texts = ["foo", "bar", "baz"]
metadatas = [{"page": i} for i in range(len(texts))]
docsearch = _sqlite_vss_from_texts(metadatas=metadatas)

docsearch.add_texts(texts, metadatas)

output = docsearch.similarity_search("foo", k=10)
assert len(output) == 6