* Add dummy chat endpoint
* Implement RAG chat endpoint
* Fix function signature
* Install ai dependencies
* Add missing future import
* Docker: torch cpu-only; add ST model
* Add columns for AI quota/usage
* Add AI quota
* Faster chunking of indexer
* Add dummy test case
* Use build instead of sdist
* Add optional dependencies in toml
* Add semantic search status to metadata
* Update semantic index on add/update/delete
* Add app_has_semantic_search
* Remove commented code
* Start adding tests for chat
* Fix tests
* Fix test
* Add future import
* Add to test
* Add openai mock test
* Limit chat endpoint to user groups
* Limit LLM context length
* Refactor text_semantic
* Refactor text_semantic
* Fix syntax error
* More refactoring & fixes
* Improve text_semantic
* More improvements to text_semantic
* More improvements to text_semantic
* Add doc strings
* Remove unused function
* More improvements to text_semantic
* Amend alembic migration
* Add more info to metadata
* Update api spec for metadata
* Improve progress callback for search indexer
* Raise error if semantic search produces no results
* Prevent permanently granting UseChat
* More logging for reindex
* Fix missing argument in callback
* Change embedding model in Dockerfile
* Preload embedding model on app init...
* Fix annotations for Python 3.8
1 parent e05c31e · commit 82cfa86

Showing 39 changed files with 64,376 additions and 131 deletions.
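For orientation, a rough sketch of how a client might exercise the new RAG chat endpoint once deployed; the host, route, payload keys and auth header below are illustrative assumptions, not taken from this commit:

import requests

# Hypothetical request against the chat endpoint added in this commit
# (URL, path and payload shape are assumptions for illustration).
response = requests.post(
    "https://gramps.example.com/api/chat/",
    json={"query": "Where was my great-grandmother born?"},
    headers={"Authorization": "Bearer <access token>"},
    timeout=60,
)
print(response.json())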
alembic_users/versions/a8e57fe0d82e_add_columns_for_ai_quota.py
40 changes: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
"""Add coloumns for AI quota | ||
Revision ID: a8e57fe0d82e | ||
Revises: 84960b7d968c | ||
Create Date: 2024-09-03 18:48:00.917543 | ||
""" | ||
|
||
from alembic import op | ||
import sqlalchemy as sa | ||
from sqlalchemy.engine.reflection import Inspector | ||
|
||
|
||
# revision identifiers, used by Alembic. | ||
revision = "a8e57fe0d82e" | ||
down_revision = "84960b7d968c" | ||
branch_labels = None | ||
depends_on = None | ||
|
||
|
||
def upgrade(): | ||
# ### commands auto generated by Alembic - please adjust! ### | ||
conn = op.get_bind() | ||
inspector = Inspector.from_engine(conn) | ||
columns = [col["name"] for col in inspector.get_columns("trees")] | ||
if "quota_ai" not in columns: | ||
op.add_column("trees", sa.Column("quota_ai", sa.Integer(), nullable=True)) | ||
if "usage_ai" not in columns: | ||
op.add_column("trees", sa.Column("usage_ai", sa.Integer(), nullable=True)) | ||
if "min_role_ai" not in columns: | ||
op.add_column("trees", sa.Column("min_role_ai", sa.Integer(), nullable=True)) | ||
# ### end Alembic commands ### | ||
|
||
|
||
def downgrade(): | ||
# ### commands auto generated by Alembic - please adjust! ### | ||
op.drop_column("trees", "usage_ai") | ||
op.drop_column("trees", "quota_ai") | ||
op.drop_column("trees", "min_role_ai") | ||
# ### end Alembic commands ### |
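The three new columns are nullable integers on the trees table: quota_ai (the allowed number of AI requests), usage_ai (requests consumed so far) and min_role_ai (presumably the minimum user role allowed to use the chat feature). A rough sketch of a quota check built on them, purely illustrative and not part of this commit, might look like:

def ai_quota_exceeded(tree_row) -> bool:
    """Illustrative only: True if the tree's AI quota is set and used up."""
    if tree_row.quota_ai is None:  # NULL is assumed to mean "no limit"
        return False
    return (tree_row.usage_ai or 0) >= tree_row.quota_ai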
@@ -0,0 +1,160 @@
"""Functions for working with large language models (LLMs).""" | ||
|
||
from __future__ import annotations | ||
|
||
from flask import current_app | ||
from openai import OpenAI, RateLimitError, APIError | ||
|
||
from ..search import get_search_indexer | ||
from ..util import abort_with_message, get_logger | ||
|
||
|
||
def get_client(config: dict) -> OpenAI: | ||
"""Get an OpenAI client instance.""" | ||
if not config.get("LLM_MODEL"): | ||
raise ValueError("No LLM specified") | ||
return OpenAI(base_url=config.get("LLM_BASE_URL")) | ||
|
||
|
||


def answer_prompt(prompt: str, system_prompt: str, config: dict | None = None) -> str:
    """Answer a question given a system prompt."""
    if not config:
        if current_app:
            config = current_app.config
        else:
            raise ValueError("Outside of the app context, config needs to be provided")

    messages = []

    if system_prompt:
        messages.append(
            {
                "role": "system",
                "content": str(system_prompt),
            }
        )

    messages.append(
        {
            "role": "user",
            "content": str(prompt),
        }
    )

    client = get_client(config=config)
    model = config.get("LLM_MODEL")

    try:
        response = client.chat.completions.create(
            messages=messages,
            model=model,
        )
    except RateLimitError:
        abort_with_message(500, "Chat API rate limit exceeded.")
    except APIError:
        abort_with_message(500, "Chat API error encountered.")
    except Exception:
        abort_with_message(500, "Unexpected error.")

    try:
        answer = response.to_dict()["choices"][0]["message"]["content"]
    except (KeyError, IndexError):
        abort_with_message(500, "Error parsing chat API response.")

    return answer


def answer_prompt_with_context(prompt: str, context: str) -> str:
    """Answer a question, passing the retrieved context in the system prompt."""

    system_prompt = (
        "You are an assistant for answering questions about a user's family history. "
        "Use the following pieces of context retrieved from a genealogical database "
        "to answer the question. "
        "If you don't know the answer, just say that you don't know. "
        "Use three sentences maximum and keep the answer concise. "
        "In your answer, preserve relative Markdown links."
    )

    system_prompt = f"""{system_prompt}\n\n{context}"""
    return answer_prompt(prompt=prompt, system_prompt=system_prompt)


def contextualize_prompt(prompt: str, context: str) -> str:
    """Reformulate the prompt as a standalone question, using the chat history as context."""

    system_prompt = (
        "Given a chat history and the latest user question "
        "which might reference context in the chat history, "
        "formulate a standalone question which can be understood "
        "without the chat history. Do NOT answer the question, "
        "just reformulate it if needed and otherwise return it as is."
    )

    system_prompt = f"""{system_prompt}\n\n{context}"""

    return answer_prompt(prompt=prompt, system_prompt=system_prompt)


def retrieve(tree: str, prompt: str, include_private: bool, num_results: int = 10):
    """Retrieve the content of the top semantic search hits for the prompt."""
    searcher = get_search_indexer(tree, semantic=True)
    total, hits = searcher.search(
        query=prompt,
        page=1,
        pagesize=num_results,
        include_private=include_private,
        include_content=True,
    )
    return [hit["content"] for hit in hits]


def answer_prompt_retrieve(
    prompt: str,
    tree: str,
    include_private: bool,
    history: list | None = None,
) -> str:
    """Answer a prompt with retrieval-augmented generation, optionally using chat history."""
    logger = get_logger()

    if not history:
        # no chat history present - we directly retrieve the context

        search_results = retrieve(
            prompt=prompt, tree=tree, include_private=include_private, num_results=20
        )
        if not search_results:
            abort_with_message(500, "Unexpected problem while retrieving context")

        context = ""
        max_length = current_app.config["LLM_MAX_CONTEXT_LENGTH"]
        for search_result in search_results:
            if len(context) + len(search_result) > max_length:
                break
            context += search_result + "\n\n"
        context = context.strip()

        logger.debug("Answering prompt '%s' with context '%s'", prompt, context)
        logger.debug("Context length: %s characters", len(context))
        return answer_prompt_with_context(prompt=prompt, context=context)

    # chat history is present - we first need to call the LLM to merge the history
    # and the prompt into a new, standalone prompt.

    context = ""
    for message in history:
        if "role" not in message or "message" not in message:
            raise ValueError(f"Invalid message format: {message}")
        if message["role"].lower() in ["ai", "system", "assistant"]:
            context += f"*Assistant message:* {message['message']}\n\n"
        elif message["role"].lower() == "error":
            pass
        else:
            context += f"*Human message:* {message['message']}\n\n"
    context = context.strip()

    logger.debug("Contextualizing prompt '%s' with context '%s'", prompt, context)
    new_prompt = contextualize_prompt(prompt=prompt, context=context)
    logger.debug("New prompt: '%s'", new_prompt)

    # we can now feed the standalone prompt into the same function but without history.
    return answer_prompt_retrieve(
        prompt=new_prompt, tree=tree, include_private=include_private
    )
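A minimal usage sketch of the retrieval-augmented flow above, assuming a Flask app context with LLM_MODEL, LLM_BASE_URL and LLM_MAX_CONTEXT_LENGTH configured, an API key available to the OpenAI client, and a populated semantic search index; the tree ID and chat history contents are illustrative:

# Illustrative call only; "example_tree" and the history messages are made up.
history = [
    {"role": "human", "message": "Who were the parents of John Smith?"},
    {"role": "assistant", "message": "John Smith's parents were Alice and Robert Smith."},
]
answer = answer_prompt_retrieve(
    prompt="And where was his mother born?",
    tree="example_tree",
    include_private=False,
    history=history,
)
print(answer)

With a non-empty history, the function first asks the LLM to rewrite the follow-up into a standalone question, then recurses without history to retrieve context and answer it.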