From 1bbb49e8612922a832ca6a50e17f9296024eb362 Mon Sep 17 00:00:00 2001 From: Philippe Oger Date: Thu, 31 Aug 2023 14:23:17 +0200 Subject: [PATCH 1/2] Adjust for mypy and tests --- .../langchain/vectorstores/sqlitevss.py | 30 ++++++++++--------- .../vectorstores/test_sqlitevss.py | 6 ++-- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/libs/langchain/langchain/vectorstores/sqlitevss.py b/libs/langchain/langchain/vectorstores/sqlitevss.py index a9adad2828603..59567fc0174eb 100644 --- a/libs/langchain/langchain/vectorstores/sqlitevss.py +++ b/libs/langchain/langchain/vectorstores/sqlitevss.py @@ -5,7 +5,6 @@ import sqlite3 import warnings from typing import ( - TYPE_CHECKING, Any, Iterable, List, @@ -18,8 +17,6 @@ from langchain.embeddings.base import Embeddings from langchain.vectorstores.base import VectorStore -if TYPE_CHECKING: - import sqlite_vss # noqa # pylint: disable=unused-import logger = logging.getLogger(__name__) @@ -67,10 +64,10 @@ def create_table_if_not_exists(self) -> None: f""" CREATE TABLE IF NOT EXISTS {self._table} ( - text_id INT PRIMARY KEY AUTOINCREMENT, - text text, - metadata blob, - text_embedding blob + rowid INTEGER PRIMARY KEY AUTOINCREMENT, + text TEXT, + metadata BLOB, + text_embedding BLOB ) ; """ @@ -108,8 +105,11 @@ def add_texts( kwargs: vectorstore specific parameters """ max_id = self._connection.execute( - f"SELECT max(text_id) as text_id FROM {self._table}" - ).fetchone()["text_id"] + f"SELECT max(rowid) as rowid FROM {self._table}" + ).fetchone()["rowid"] + if max_id is None: # no text added yet + max_id = 0 + embeds = self._embedding.embed_documents(list(texts)) if not metadatas: metadatas = [{} for _ in texts] @@ -123,12 +123,11 @@ def add_texts( data_input, ) self._connection.commit() - # pulling every ids we just inserted results = self._connection.execute( - f"SELECT text_id FROM {self._table} WHERE text_id > {max_id}" + f"SELECT rowid FROM {self._table} WHERE rowid > {max_id}" ) - return [row["text_id"] for row in results] + return [row["rowid"] for row in results] def similarity_search_with_score_by_vector( self, embedding: List[float], k: int = 4, **kwargs: Any @@ -139,7 +138,7 @@ def similarity_search_with_score_by_vector( metadata, distance FROM {self._table} e - INNER JOIN vss_{self._table} v on v.text_id = e.text_id + INNER JOIN vss_{self._table} v on v.rowid = e.rowid WHERE vss_search( v.text_embedding, vss_search_params('{json.dumps(embedding)}', {k}) @@ -151,8 +150,10 @@ def similarity_search_with_score_by_vector( documents = [] for row in results: + metadata = json.loads(row["metadata"]) or {} doc = Document( - page_content=row["text"], metadata=json.loads(row["metadata"]) + page_content=row["text"], + metadata=metadata ) score = self._euclidean_relevance_score_fn(row["distance"]) documents.append((doc, score)) @@ -207,6 +208,7 @@ def from_texts( @staticmethod def create_connection(db_file: str) -> sqlite3.Connection: + import sqlite_vss connection = sqlite3.connect(db_file) connection.row_factory = sqlite3.Row connection.enable_load_extension(True) diff --git a/libs/langchain/tests/integration_tests/vectorstores/test_sqlitevss.py b/libs/langchain/tests/integration_tests/vectorstores/test_sqlitevss.py index fe124a9d821ed..ce0a304a9ad79 100644 --- a/libs/langchain/tests/integration_tests/vectorstores/test_sqlitevss.py +++ b/libs/langchain/tests/integration_tests/vectorstores/test_sqlitevss.py @@ -27,7 +27,7 @@ def test_sqlitevss() -> None: """Test end to end construction and search.""" docsearch = _sqlite_vss_from_texts() output = docsearch.similarity_search("foo", k=1) - assert output == [Document(page_content="foo", metadata=None)] + assert output == [Document(page_content="foo", metadata={})] @pytest.mark.requires("sqlite-vss") @@ -44,7 +44,7 @@ def test_sqlitevss_with_score() -> None: Document(page_content="bar", metadata={"page": 1}), Document(page_content="baz", metadata={"page": 2}), ] - assert scores[0] < scores[1] < scores[2] + assert scores[0] > scores[1] > scores[2] @pytest.mark.requires("sqlite-vss") @@ -53,8 +53,6 @@ def test_sqlitevss_add_extra() -> None: texts = ["foo", "bar", "baz"] metadatas = [{"page": i} for i in range(len(texts))] docsearch = _sqlite_vss_from_texts(metadatas=metadatas) - docsearch.add_texts(texts, metadatas) - output = docsearch.similarity_search("foo", k=10) assert len(output) == 6 From 29f3c162ca9e368316f404161877d8a51b5284fa Mon Sep 17 00:00:00 2001 From: Philippe Oger Date: Thu, 31 Aug 2023 14:34:12 +0200 Subject: [PATCH 2/2] Formatting --- libs/langchain/langchain/vectorstores/sqlitevss.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/libs/langchain/langchain/vectorstores/sqlitevss.py b/libs/langchain/langchain/vectorstores/sqlitevss.py index 59567fc0174eb..5234ba8c2a330 100644 --- a/libs/langchain/langchain/vectorstores/sqlitevss.py +++ b/libs/langchain/langchain/vectorstores/sqlitevss.py @@ -17,7 +17,6 @@ from langchain.embeddings.base import Embeddings from langchain.vectorstores.base import VectorStore - logger = logging.getLogger(__name__) @@ -151,10 +150,7 @@ def similarity_search_with_score_by_vector( documents = [] for row in results: metadata = json.loads(row["metadata"]) or {} - doc = Document( - page_content=row["text"], - metadata=metadata - ) + doc = Document(page_content=row["text"], metadata=metadata) score = self._euclidean_relevance_score_fn(row["distance"]) documents.append((doc, score)) @@ -209,6 +205,7 @@ def from_texts( @staticmethod def create_connection(db_file: str) -> sqlite3.Connection: import sqlite_vss + connection = sqlite3.connect(db_file) connection.row_factory = sqlite3.Row connection.enable_load_extension(True)