Skip to content

Commit

Permalink
Restore embedding model caching (#581)
Browse files Browse the repository at this point in the history
* Restore embedding model caching

* Remove import
  • Loading branch information
DavidMStraub authored Nov 27, 2024
1 parent b6af8f6 commit 2b33c73
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 18 deletions.
10 changes: 2 additions & 8 deletions gramps_webapi/api/search/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from flask import current_app

from .indexer import SearchIndexer, SemanticSearchIndexer, SearchIndexerBase
from .embeddings import embedding_function_factory


def get_search_indexer(tree: str, semantic: bool = False) -> SearchIndexerBase:
Expand All @@ -50,15 +49,10 @@ def get_search_indexer(tree: str, semantic: bool = False) -> SearchIndexerBase:
if not path.exists() and not path.parent.exists():
path.parent.mkdir(parents=True, exist_ok=True)
if semantic:
model = current_app.config.get("VECTOR_EMBEDDING_MODEL")
model = current_app.config.get("_INITIALIZED_VECTOR_EMBEDDING_MODEL")
if not model:
raise ValueError("VECTOR_EMBEDDING_MODEL option not set")
try:
embedding_function = embedding_function_factory(model)
except OSError:
raise ValueError(f"Failed initializing model {model}")
# cache on app instance
return SemanticSearchIndexer(
db_url=db_url, tree=tree, embedding_function=embedding_function
db_url=db_url, tree=tree, embedding_function=model.encode
)
return SearchIndexer(db_url=db_url, tree=tree)
9 changes: 0 additions & 9 deletions gramps_webapi/api/search/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,6 @@
from ..util import get_logger


def embedding_function_factory(model_name: str):
    """Build an embedding function backed by the model named ``model_name``.

    The model is obtained from ``load_model`` once, when the factory is
    called. The returned callable closes over that model and forwards each
    list of query strings to the model's ``encode`` method.
    """
    loaded_model = load_model(model_name)

    def embedding_function(queries: list[str]):
        """Return the result of encoding ``queries`` with the loaded model."""
        return loaded_model.encode(queries)

    return embedding_function


def load_model(model_name: str):
"""Load the sentence transformer model.
Expand Down
4 changes: 3 additions & 1 deletion gramps_webapi/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,9 @@ def close_user_db_connection(exception) -> None:
user_db.session.remove() # pylint: disable=no-member

if app.config.get("VECTOR_EMBEDDING_MODEL"):
load_model(app.config["VECTOR_EMBEDDING_MODEL"])
app.config["_INITIALIZED_VECTOR_EMBEDDING_MODEL"] = load_model(
app.config["VECTOR_EMBEDDING_MODEL"]
)

@app.route("/ready", methods=["GET"])
def ready():
Expand Down

0 comments on commit 2b33c73

Please sign in to comment.