Implement AI chat endpoint (#539)
* Add dummy chat endpoint

* Implement RAG chat endpoint

* Fix function signature

* Install ai dependencies

* Add missing future import

* Docker: torch cpu-only; add ST model

* Add columns for AI quota/usage

* Add AI quota

* Faster chunking of indexer

* Add dummy test case

* Use build instead of sdist

* Add optional dependencies in toml

* Add semantic search status to metadata

* Update semantic index on add/update/delete

* Add app_has_semantic_search

* Remove commented code

* Start adding tests for chat

* Fix tests

* Fix test

* Add future import

* Add to test

* Add openai mock test

* Limit chat endpoint to user groups

* Limit LLM context length

* Refactor text_semantic

* Refactor text_semantic

* Fix syntax error

* More refactoring & fixes

* Improve text_semantic

* More improvements to text_semantic

* More improvements to text_semantic

* Add doc strings

* Remove unused function

* More improvements to text_semantic

* Amend alembic migration

* Add more info to metadata

* Update api spec for metadata

* Improve progress callback for search indexer

* Raise error if semantic search produces no results

* Prevent permanently granting UseChat

* More logging for reindex

* Fix missing argument in callback

* Change embedding model in Dockerfile

* Preload embedding model on app init...

* Fix annotations for Python 3.8
DavidMStraub authored Sep 27, 2024
1 parent e05c31e commit 82cfa86
Showing 39 changed files with 64,376 additions and 131 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/pypi.yml
@@ -18,9 +18,9 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
-        run: python -m pip install --upgrade pip setuptools wheel
+        run: python -m pip install --upgrade pip setuptools wheel build
- name: Create the source distribution
-        run: python setup.py sdist
+        run: python -m build
- name: Publish distribution to PyPI
uses: pypa/gh-action-pypi-publish@master
with:
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
@@ -23,7 +23,7 @@ jobs:
python -m pip install --upgrade pip wheel setuptools
pip install opencv-python
pip install -r requirements-dev.txt
-        pip install .
+        pip install .[ai]
pip list
- name: Test with pytest
run: pytest
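
A note on the new install target: pip install .[ai] pulls in the optional AI dependency group added to pyproject.toml in this commit ("Add optional dependencies in toml"). A minimal sketch of how server code can degrade gracefully when those extras are absent; the HAS_AI flag and the guarded import are illustrative assumptions, not code from this diff:

# Hypothetical guard for the optional [ai] extras (names are assumptions):
try:
    import sentence_transformers  # present only after: pip install .[ai]
    HAS_AI = True
except ImportError:
    HAS_AI = False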
11 changes: 10 additions & 1 deletion Dockerfile
@@ -63,10 +63,19 @@ RUN wget https://github.com/gramps-project/addons/archive/refs/heads/master.zip
RUN python3 -m pip install --break-system-packages --no-cache-dir --extra-index-url https://www.piwheels.org/simple \
gunicorn

+# install PyTorch - CPU only
+RUN python3 -m pip install --break-system-packages --no-cache-dir --index-url https://download.pytorch.org/whl/cpu \
+    torch

# copy package source and install
COPY . /app/src
RUN python3 -m pip install --break-system-packages --no-cache-dir --extra-index-url https://www.piwheels.org/simple \
-    /app/src
+    /app/src[ai]

+# download and cache sentence transformer model
+RUN python3 -c "\
+    from sentence_transformers import SentenceTransformer; \
+    model = SentenceTransformer('sentence-transformers/distiluse-base-multilingual-cased-v2')"

EXPOSE 5000

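Baking the model into the image means a container's first semantic search or chat request does not trigger a download. The commit message "Preload embedding model on app init..." suggests the app warms the same model at startup; a minimal sketch under that assumption (the function name is illustrative):

# Hedged sketch: warm the embedding model once at app startup.
from sentence_transformers import SentenceTransformer

MODEL_NAME = "sentence-transformers/distiluse-base-multilingual-cased-v2"  # must match the Dockerfile

def preload_embedding_model() -> SentenceTransformer:
    # Loads from the local cache created by the Docker build step above.
    return SentenceTransformer(MODEL_NAME)
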
40 changes: 40 additions & 0 deletions alembic_users/versions/a8e57fe0d82e_add_columns_for_ai_quota.py
@@ -0,0 +1,40 @@
"""Add coloumns for AI quota
Revision ID: a8e57fe0d82e
Revises: 84960b7d968c
Create Date: 2024-09-03 18:48:00.917543
"""

from alembic import op
import sqlalchemy as sa
from sqlalchemy.engine.reflection import Inspector


# revision identifiers, used by Alembic.
revision = "a8e57fe0d82e"
down_revision = "84960b7d968c"
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
conn = op.get_bind()
inspector = Inspector.from_engine(conn)
columns = [col["name"] for col in inspector.get_columns("trees")]
if "quota_ai" not in columns:
op.add_column("trees", sa.Column("quota_ai", sa.Integer(), nullable=True))
if "usage_ai" not in columns:
op.add_column("trees", sa.Column("usage_ai", sa.Integer(), nullable=True))
if "min_role_ai" not in columns:
op.add_column("trees", sa.Column("min_role_ai", sa.Integer(), nullable=True))
# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("trees", "usage_ai")
op.drop_column("trees", "quota_ai")
op.drop_column("trees", "min_role_ai")
# ### end Alembic commands ###
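
The three new columns imply a per-tree quota check before each AI request. A sketch of plausible enforcement logic; the helper name and the treatment of NULL as "unlimited" are assumptions, not part of this migration:

def ai_request_allowed(quota_ai, usage_ai):
    # Assumption: a NULL (None) quota means unlimited AI usage.
    if quota_ai is None:
        return True
    return (usage_ai or 0) < quota_ai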
21 changes: 16 additions & 5 deletions gramps_webapi/__main__.py
@@ -19,10 +19,13 @@

"""Command line interface for the Gramps web API."""

+from __future__ import annotations

import logging
import os
import subprocess
import sys
+import time
import warnings

import click
@@ -120,8 +123,13 @@ def migrate_db(ctx):

@cli.group("search", help="Manage the full-text search index.")
@click.option("--tree", help="Tree ID", default=None)
+@click.option(
+    "--semantic/--fulltext",
+    help="Semantic rather than full-text search index",
+    default=False,
+)
@click.pass_context
-def search(ctx, tree):
+def search(ctx, tree, semantic):
app = ctx.obj["app"]
if not tree:
if app.config["TREE"] == TREE_MULTI:
@@ -135,14 +143,16 @@ def search(ctx, tree):
tree = dbmgr.dirname
with app.app_context():
ctx.obj["db_manager"] = get_db_manager(tree=tree)
-        ctx.obj["search_indexer"] = get_search_indexer(tree=tree)
+        ctx.obj["search_indexer"] = get_search_indexer(tree=tree, semantic=semantic)


-def progress_callback_count(current: int, total: int) -> None:
+def progress_callback_count(current: int, total: int, prev: int | None = None) -> None:
if total == 0:
return
pct = int(100 * current / total)
-    pct_prev = int(100 * (current - 1) / total)
+    if prev is None:
+        prev = current - 1
+    pct_prev = int(100 * prev / total)
if current == 0 or pct != pct_prev:
LOG.info(f"Progress: {pct}%")

@@ -156,13 +166,14 @@ def index_full(ctx):
indexer = ctx.obj["search_indexer"]
db = db_manager.get_db().db

+    t0 = time.time()
try:
indexer.reindex_full(db, progress_cb=progress_callback_count)
except:
LOG.exception("Error during indexing")
finally:
db.close()
-    LOG.info("Done building search index.")
+    LOG.info(f"Done building search index in {time.time() - t0:.0f} seconds.")


@search.command("index-incremental")
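For reference, a plausible invocation of the new flag, and how the revised callback throttles log output (the tree ID and environment setup are illustrative):

# Build the semantic index from the command line (assumed invocation):
#   python3 -m gramps_webapi search --tree my_tree --semantic index-full
#
# The callback logs only when the integer percentage advances:
progress_callback_count(current=0, total=400)  # logs "Progress: 0%"
progress_callback_count(current=1, total=400)  # still 0% -> silent
progress_callback_count(current=4, total=400)  # logs "Progress: 1%"
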
3 changes: 2 additions & 1 deletion gramps_webapi/_version.py
@@ -17,4 +17,5 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#

-__version__ = "2.4.2"
+# make sure to match this version with the one in apispec.yaml
+__version__ = "2.5.0"
4 changes: 4 additions & 0 deletions gramps_webapi/api/__init__.py
@@ -34,6 +34,7 @@
BookmarkResource,
BookmarksResource,
)
+from .resources.chat import ChatResource
from .resources.citations import CitationResource, CitationsResource
from .resources.config import ConfigResource, ConfigsResource
from .resources.dna import PersonDnaMatchesResource
@@ -330,6 +331,9 @@ def register_endpt(resource: Type[Resource], url: str, name: str):
register_endpt(SearchResource, "/search/", "search")
register_endpt(SearchIndexResource, "/search/index/", "search_index")

+    # Chat
+    register_endpt(ChatResource, "/chat/", "chat")

# Config
register_endpt(
ConfigsResource,
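Once registered, the chat endpoint is reachable under the API prefix. A hedged example call; the /api prefix, the "query" payload field, and the auth header are assumptions based on the surrounding code, not verified against the apispec changes in this commit:

import requests

resp = requests.post(
    "http://localhost:5000/api/chat/",  # assumed URL prefix
    json={"query": "Where was my grandmother born?"},  # assumed payload shape
    headers={"Authorization": "Bearer <access token>"},
)
print(resp.json())
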
160 changes: 160 additions & 0 deletions gramps_webapi/api/llm/__init__.py
@@ -0,0 +1,160 @@
"""Functions for working with large language models (LLMs)."""

from __future__ import annotations

from flask import current_app
from openai import OpenAI, RateLimitError, APIError

from ..search import get_search_indexer
from ..util import abort_with_message, get_logger


def get_client(config: dict) -> OpenAI:
"""Get an OpenAI client instance."""
if not config.get("LLM_MODEL"):
raise ValueError("No LLM specified")
return OpenAI(base_url=config.get("LLM_BASE_URL"))


def answer_prompt(prompt: str, system_prompt: str, config: dict | None = None) -> str:
"""Answer a question given a system prompt."""
if not config:
if current_app:
config = current_app.config
else:
raise ValueError("Outside of the app context, config needs to be provided")

messages = []

if system_prompt:
messages.append(
{
"role": "system",
"content": str(system_prompt),
}
)

messages.append(
{
"role": "user",
"content": str(prompt),
}
)

client = get_client(config=config)
model = config.get("LLM_MODEL")

try:
response = client.chat.completions.create(
messages=messages,
model=model,
)
except RateLimitError:
abort_with_message(500, "Chat API rate limit exceeded.")
except APIError:
abort_with_message(500, "Chat API error encountered.")
except Exception:
abort_with_message(500, "Unexpected error.")

try:
answer = response.to_dict()["choices"][0]["message"]["content"]
except (KeyError, IndexError):
abort_with_message(500, "Error parsing chat API response.")

return answer


def answer_prompt_with_context(prompt: str, context: str) -> str:
    """Answer a question, passing the retrieved context to the LLM as a system prompt."""
system_prompt = (
"You are an assistant for answering questions about a user's family history. "
"Use the following pieces of context retrieved from a genealogical database "
"to answer the question. "
"If you don't know the answer, just say that you don't know. "
"Use three sentences maximum and keep the answer concise."
"In your answer, preserve relative Markdown links."
)

system_prompt = f"""{system_prompt}\n\n{context}"""
return answer_prompt(prompt=prompt, system_prompt=system_prompt)


def contextualize_prompt(prompt: str, context: str) -> str:
    """Rewrite the latest user question as a standalone question, given the chat history."""
system_prompt = (
"Given a chat history and the latest user question "
"which might reference context in the chat history, "
"formulate a standalone question which can be understood "
"without the chat history. Do NOT answer the question, "
"just reformulate it if needed and otherwise return it as is."
)

system_prompt = f"""{system_prompt}\n\n{context}"""

return answer_prompt(prompt=prompt, system_prompt=system_prompt)


def retrieve(tree: str, prompt: str, include_private: bool, num_results: int = 10):
    """Return the text content of the top semantic-search hits for the prompt."""
    searcher = get_search_indexer(tree, semantic=True)
total, hits = searcher.search(
query=prompt,
page=1,
pagesize=num_results,
include_private=include_private,
include_content=True,
)
return [hit["content"] for hit in hits]


def answer_prompt_retrieve(
prompt: str,
tree: str,
include_private: bool,
history: list | None = None,
) -> str:
    """Answer a prompt via retrieval-augmented generation on the tree's semantic index."""
    logger = get_logger()

if not history:
# no chat history present - we directly retrieve the context

search_results = retrieve(
prompt=prompt, tree=tree, include_private=include_private, num_results=20
)
if not search_results:
        abort_with_message(500, "Unexpected problem while retrieving context")

context = ""
max_length = current_app.config["LLM_MAX_CONTEXT_LENGTH"]
for search_result in search_results:
if len(context) + len(search_result) > max_length:
break
context += search_result + "\n\n"
context = context.strip()

logger.debug("Answering prompt '%s' with context '%s'", prompt, context)
logger.debug("Context length: %s characters", len(context))
return answer_prompt_with_context(prompt=prompt, context=context)

# chat history is present - we first need to call the LLM to merge the history
# and the prompt into a new, standalone prompt.

context = ""
for message in history:
if "role" not in message or "message" not in message:
raise ValueError(f"Invalid message format: {message}")
if message["role"].lower() in ["ai", "system", "assistant"]:
context += f"*Assistant message:* {message['message']}\n\n"
elif message["role"].lower() == "error":
pass
else:
context += f"*Human message:* {message['message']}\n\n"
context = context.strip()

logger.debug("Contextualizing prompt '%s' with context '%s'", prompt, context)
new_prompt = contextualize_prompt(prompt=prompt, context=context)
logger.debug("New prompt: '%s'", new_prompt)

# we can now feed the standalone prompt into the same function but without history.
return answer_prompt_retrieve(
prompt=new_prompt, tree=tree, include_private=include_private
)
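
Taken together, answer_prompt_retrieve implements a two-step RAG loop: when history is present it first asks the LLM for a standalone question, then recurses without history to retrieve context and answer. A usage sketch inside a Flask app context; the tree ID and chat history are illustrative:

history = [
    {"role": "human", "message": "Who was Anna Meyer's father?"},
    {"role": "ai", "message": "Her father was Johann Meyer."},
]
answer = answer_prompt_retrieve(
    prompt="When was he born?",  # pronoun is resolved by the contextualize step
    tree="my_tree_id",  # illustrative tree ID
    include_private=False,
    history=history,
)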
13 changes: 13 additions & 0 deletions gramps_webapi/api/resources/base.py
@@ -56,6 +56,7 @@
from .util import (
abort_with_message,
add_object,
+    app_has_semantic_search,
filter_missing_files,
fix_object_dict,
get_backlinks,
@@ -294,6 +295,12 @@ def put(self, handle: str) -> Response:
handle = _trans_dict["handle"]
class_name = _trans_dict["_class"]
indexer.add_or_update_object(handle, db_handle, class_name)
+        if app_has_semantic_search():
+            indexer: SearchIndexer = get_search_indexer(tree, semantic=True)
+            for _trans_dict in trans_dict:
+                handle = _trans_dict["handle"]
+                class_name = _trans_dict["_class"]
+                indexer.add_or_update_object(handle, db_handle, class_name)
return self.response(200, trans_dict, total_items=len(trans_dict))


@@ -471,6 +478,12 @@ def post(self) -> Response:
handle = _trans_dict["handle"]
class_name = _trans_dict["_class"]
indexer.add_or_update_object(handle, db_handle, class_name)
+        if app_has_semantic_search():
+            indexer: SearchIndexer = get_search_indexer(tree, semantic=True)
+            for _trans_dict in trans_dict:
+                handle = _trans_dict["handle"]
+                class_name = _trans_dict["_class"]
+                indexer.add_or_update_object(handle, db_handle, class_name)
return self.response(201, trans_dict, total_items=len(trans_dict))
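
Both write paths now mirror every change into the semantic index when it is enabled. A plausible shape for the app_has_semantic_search helper used above; this is a sketch, and the config key is an assumption rather than code from this diff:

from flask import current_app

def app_has_semantic_search() -> bool:
    # Assumption: semantic search counts as enabled when an embedding model is configured.
    return bool(current_app.config.get("VECTOR_EMBEDDING_MODEL"))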