Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement AI chat endpoint #539

Merged
merged 45 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from 41 commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
bcae9ab
Add dummy chat endpoint
DavidMStraub Aug 12, 2024
164cb26
Implement RAG chat endpoint
DavidMStraub Aug 27, 2024
de2c537
Fix function signature
DavidMStraub Aug 27, 2024
1e35281
Install ai dependencies
DavidMStraub Aug 27, 2024
2c1096b
Add missing future import
DavidMStraub Aug 27, 2024
11efde2
Docker: torch cpu-only; add ST model
DavidMStraub Sep 2, 2024
35d3b5f
Add columns for AI quota/usage
DavidMStraub Sep 3, 2024
73eae34
Add AI quota
DavidMStraub Sep 3, 2024
cd18201
Faster chunking of indexer
DavidMStraub Sep 4, 2024
9894344
Add dummy test case
DavidMStraub Sep 4, 2024
235f7d3
Use build instead of sdist
DavidMStraub Sep 4, 2024
54a2507
Add optional dependencies in toml
DavidMStraub Sep 4, 2024
212d530
Add semantic search status to metadata
DavidMStraub Sep 6, 2024
ddf1514
Update semantic index on add/update/delete
DavidMStraub Sep 6, 2024
998bace
Add app_has_semantic_search
DavidMStraub Sep 6, 2024
4936cfb
Remove commented code
DavidMStraub Sep 6, 2024
bffd665
Start adding tests for chat
DavidMStraub Sep 6, 2024
717da3f
Fix tests
DavidMStraub Sep 6, 2024
c8daa19
Fix test
DavidMStraub Sep 6, 2024
b337783
Add future import
DavidMStraub Sep 6, 2024
5491839
Add to test
DavidMStraub Sep 7, 2024
2e55b14
Add openai mock test
DavidMStraub Sep 7, 2024
c88ba89
Limit chat endpoint to user groups
DavidMStraub Sep 7, 2024
af231c1
Limit LLM context length
DavidMStraub Sep 8, 2024
b14e8c9
Refactor text_semantic
DavidMStraub Sep 8, 2024
a58b19d
Refactor text_semantic
DavidMStraub Sep 8, 2024
9c03db0
Fix syntax error
DavidMStraub Sep 8, 2024
9e3c487
More refactoring & fixes
DavidMStraub Sep 8, 2024
584bdae
Improve text_semantic
DavidMStraub Sep 10, 2024
35e721e
More improvements to text_semantic
DavidMStraub Sep 10, 2024
c447455
More improvements to text_semantic
DavidMStraub Sep 10, 2024
faecaa9
Add doc strings
DavidMStraub Sep 12, 2024
630292f
Remove unused function
DavidMStraub Sep 12, 2024
07b7836
More improvements to text_semantic
DavidMStraub Sep 14, 2024
7a7950c
Amend alembic migration
DavidMStraub Sep 14, 2024
4e88070
Add more info to metadata
DavidMStraub Sep 14, 2024
dc70555
Update api spec for metadata
DavidMStraub Sep 15, 2024
018b22b
Improve progress callback for search indexer
DavidMStraub Sep 15, 2024
9680181
Raise error if semantic search produces no results
DavidMStraub Sep 15, 2024
d09a371
Prevent permanently granting UseChat
DavidMStraub Sep 17, 2024
fa6ca7a
More logging for reindex
DavidMStraub Sep 17, 2024
0ceaf94
Fix missing argument in callback
DavidMStraub Sep 26, 2024
cbf0907
Change embedding model in Dockerfile
DavidMStraub Sep 26, 2024
328a691
Preload embedding model on app init...n
DavidMStraub Sep 27, 2024
02c7c0b
Fix annotations for Python 3.8
DavidMStraub Sep 27, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/pypi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: python -m pip install --upgrade pip setuptools wheel
run: python -m pip install --upgrade pip setuptools wheel build
- name: Create the source distribution
run: python setup.py sdist
run: python -m build
- name: Publish distribution to PyPI
uses: pypa/gh-action-pypi-publish@master
with:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
python -m pip install --upgrade pip wheel setuptools
pip install opencv-python
pip install -r requirements-dev.txt
pip install .
pip install .[ai]
pip list
- name: Test with pytest
run: pytest
11 changes: 10 additions & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,19 @@ RUN wget https://github.com/gramps-project/addons/archive/refs/heads/master.zip
RUN python3 -m pip install --break-system-packages --no-cache-dir --extra-index-url https://www.piwheels.org/simple \
gunicorn

# install PyTorch - CPU only
RUN python3 -m pip install --break-system-packages --no-cache-dir --index-url https://download.pytorch.org/whl/cpu \
torch

# copy package source and install
COPY . /app/src
RUN python3 -m pip install --break-system-packages --no-cache-dir --extra-index-url https://www.piwheels.org/simple \
/app/src
/app/src[ai]

# download and cache sentence transformer model
RUN python3 -c "\
from sentence_transformers import SentenceTransformer; \
model = SentenceTransformer('intfloat/multilingual-e5-small')"

EXPOSE 5000

Expand Down
40 changes: 40 additions & 0 deletions alembic_users/versions/a8e57fe0d82e_add_columns_for_ai_quota.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""Add coloumns for AI quota

Revision ID: a8e57fe0d82e
Revises: 84960b7d968c
Create Date: 2024-09-03 18:48:00.917543

"""

from alembic import op
import sqlalchemy as sa
from sqlalchemy.engine.reflection import Inspector


# revision identifiers, used by Alembic.
revision = "a8e57fe0d82e"
down_revision = "84960b7d968c"
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
conn = op.get_bind()
inspector = Inspector.from_engine(conn)
columns = [col["name"] for col in inspector.get_columns("trees")]
if "quota_ai" not in columns:
op.add_column("trees", sa.Column("quota_ai", sa.Integer(), nullable=True))
if "usage_ai" not in columns:
op.add_column("trees", sa.Column("usage_ai", sa.Integer(), nullable=True))
if "min_role_ai" not in columns:
op.add_column("trees", sa.Column("min_role_ai", sa.Integer(), nullable=True))
# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("trees", "usage_ai")
op.drop_column("trees", "quota_ai")
op.drop_column("trees", "min_role_ai")
# ### end Alembic commands ###
21 changes: 16 additions & 5 deletions gramps_webapi/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,13 @@

"""Command line interface for the Gramps web API."""

from __future__ import annotations

import logging
import os
import subprocess
import sys
import time
import warnings

import click
Expand Down Expand Up @@ -120,8 +123,13 @@ def migrate_db(ctx):

@cli.group("search", help="Manage the full-text search index.")
@click.option("--tree", help="Tree ID", default=None)
@click.option(
"--semantic/--fulltext",
help="Semantic rather than full-text search index",
default=False,
)
@click.pass_context
def search(ctx, tree):
def search(ctx, tree, semantic):
app = ctx.obj["app"]
if not tree:
if app.config["TREE"] == TREE_MULTI:
Expand All @@ -135,14 +143,16 @@ def search(ctx, tree):
tree = dbmgr.dirname
with app.app_context():
ctx.obj["db_manager"] = get_db_manager(tree=tree)
ctx.obj["search_indexer"] = get_search_indexer(tree=tree)
ctx.obj["search_indexer"] = get_search_indexer(tree=tree, semantic=semantic)


def progress_callback_count(current: int, total: int) -> None:
def progress_callback_count(current: int, total: int, prev: int | None = None) -> None:
if total == 0:
return
pct = int(100 * current / total)
pct_prev = int(100 * (current - 1) / total)
if prev is None:
prev = current - 1
pct_prev = int(100 * prev / total)
if current == 0 or pct != pct_prev:
LOG.info(f"Progress: {pct}%")

Expand All @@ -156,13 +166,14 @@ def index_full(ctx):
indexer = ctx.obj["search_indexer"]
db = db_manager.get_db().db

t0 = time.time()
try:
indexer.reindex_full(db, progress_cb=progress_callback_count)
except:
LOG.exception("Error during indexing")
finally:
db.close()
LOG.info("Done building search index.")
LOG.info(f"Done building search index in {time.time() - t0:.0f} seconds.")


@search.command("index-incremental")
Expand Down
3 changes: 2 additions & 1 deletion gramps_webapi/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#

__version__ = "2.4.2"
# make sure to match this version with the one in apispec.yaml
__version__ = "2.5.0"
4 changes: 4 additions & 0 deletions gramps_webapi/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
BookmarkResource,
BookmarksResource,
)
from .resources.chat import ChatResource
from .resources.citations import CitationResource, CitationsResource
from .resources.config import ConfigResource, ConfigsResource
from .resources.dna import PersonDnaMatchesResource
Expand Down Expand Up @@ -330,6 +331,9 @@ def register_endpt(resource: Type[Resource], url: str, name: str):
register_endpt(SearchResource, "/search/", "search")
register_endpt(SearchIndexResource, "/search/index/", "search_index")

# Chat
register_endpt(ChatResource, "/chat/", "chat")

# Config
register_endpt(
ConfigsResource,
Expand Down
160 changes: 160 additions & 0 deletions gramps_webapi/api/llm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
"""Functions for working with large language models (LLMs)."""

from __future__ import annotations

from flask import current_app
from openai import OpenAI, RateLimitError, APIError

from ..search import get_search_indexer
from ..util import abort_with_message, get_logger


def get_client(config: dict) -> OpenAI:
"""Get an OpenAI client instance."""
if not config.get("LLM_MODEL"):
raise ValueError("No LLM specified")
return OpenAI(base_url=config.get("LLM_BASE_URL"))


def answer_prompt(prompt: str, system_prompt: str, config: dict | None = None) -> str:
"""Answer a question given a system prompt."""
if not config:
if current_app:
config = current_app.config
else:
raise ValueError("Outside of the app context, config needs to be provided")

messages = []

if system_prompt:
messages.append(
{
"role": "system",
"content": str(system_prompt),
}
)

messages.append(
{
"role": "user",
"content": str(prompt),
}
)

client = get_client(config=config)
model = config.get("LLM_MODEL")

try:
response = client.chat.completions.create(
messages=messages,
model=model,
)
except RateLimitError:
abort_with_message(500, "Chat API rate limit exceeded.")
except APIError:
abort_with_message(500, "Chat API error encountered.")
except Exception:
abort_with_message(500, "Unexpected error.")

try:
answer = response.to_dict()["choices"][0]["message"]["content"]
except (KeyError, IndexError):
abort_with_message(500, "Error parsing chat API response.")

return answer


def answer_prompt_with_context(prompt: str, context: str) -> str:

system_prompt = (
"You are an assistant for answering questions about a user's family history. "
"Use the following pieces of context retrieved from a genealogical database "
"to answer the question. "
"If you don't know the answer, just say that you don't know. "
"Use three sentences maximum and keep the answer concise."
"In your answer, preserve relative Markdown links."
)

system_prompt = f"""{system_prompt}\n\n{context}"""
return answer_prompt(prompt=prompt, system_prompt=system_prompt)


def contextualize_prompt(prompt: str, context: str) -> str:

system_prompt = (
"Given a chat history and the latest user question "
"which might reference context in the chat history, "
"formulate a standalone question which can be understood "
"without the chat history. Do NOT answer the question, "
"just reformulate it if needed and otherwise return it as is."
)

system_prompt = f"""{system_prompt}\n\n{context}"""

return answer_prompt(prompt=prompt, system_prompt=system_prompt)


def retrieve(tree: str, prompt: str, include_private: bool, num_results: int = 10):
searcher = get_search_indexer(tree, semantic=True)
total, hits = searcher.search(
query=prompt,
page=1,
pagesize=num_results,
include_private=include_private,
include_content=True,
)
return [hit["content"] for hit in hits]


def answer_prompt_retrieve(
prompt: str,
tree: str,
include_private: bool,
history: list | None = None,
) -> str:
logger = get_logger()

if not history:
# no chat history present - we directly retrieve the context

search_results = retrieve(
prompt=prompt, tree=tree, include_private=include_private, num_results=20
)
if not search_results:
abort_with_message("Unexpected problem while retrieving context")

context = ""
max_length = current_app.config["LLM_MAX_CONTEXT_LENGTH"]
for search_result in search_results:
if len(context) + len(search_result) > max_length:
break
context += search_result + "\n\n"
context = context.strip()

logger.debug("Answering prompt '%s' with context '%s'", prompt, context)
logger.debug("Context length: %s characters", len(context))
return answer_prompt_with_context(prompt=prompt, context=context)

# chat history is present - we first need to call the LLM to merge the history
# and the prompt into a new, standalone prompt.

context = ""
for message in history:
if "role" not in message or "message" not in message:
raise ValueError(f"Invalid message format: {message}")
if message["role"].lower() in ["ai", "system", "assistant"]:
context += f"*Assistant message:* {message['message']}\n\n"
elif message["role"].lower() == "error":
pass
else:
context += f"*Human message:* {message['message']}\n\n"
context = context.strip()

logger.debug("Contextualizing prompt '%s' with context '%s'", prompt, context)
new_prompt = contextualize_prompt(prompt=prompt, context=context)
logger.debug("New prompt: '%s'", new_prompt)

# we can now feed the standalone prompt into the same function but without history.
return answer_prompt_retrieve(
prompt=new_prompt, tree=tree, include_private=include_private
)
13 changes: 13 additions & 0 deletions gramps_webapi/api/resources/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
from .util import (
abort_with_message,
add_object,
app_has_semantic_search,
filter_missing_files,
fix_object_dict,
get_backlinks,
Expand Down Expand Up @@ -294,6 +295,12 @@ def put(self, handle: str) -> Response:
handle = _trans_dict["handle"]
class_name = _trans_dict["_class"]
indexer.add_or_update_object(handle, db_handle, class_name)
if app_has_semantic_search():
indexer: SearchIndexer = get_search_indexer(tree, semantic=True)
for _trans_dict in trans_dict:
handle = _trans_dict["handle"]
class_name = _trans_dict["_class"]
indexer.add_or_update_object(handle, db_handle, class_name)
return self.response(200, trans_dict, total_items=len(trans_dict))


Expand Down Expand Up @@ -471,6 +478,12 @@ def post(self) -> Response:
handle = _trans_dict["handle"]
class_name = _trans_dict["_class"]
indexer.add_or_update_object(handle, db_handle, class_name)
if app_has_semantic_search():
indexer: SearchIndexer = get_search_indexer(tree, semantic=True)
for _trans_dict in trans_dict:
handle = _trans_dict["handle"]
class_name = _trans_dict["_class"]
indexer.add_or_update_object(handle, db_handle, class_name)
return self.response(201, trans_dict, total_items=len(trans_dict))


Expand Down
Loading