diff --git a/api/server.py b/api/server.py index 02c671f..1eac428 100644 --- a/api/server.py +++ b/api/server.py @@ -10,12 +10,18 @@ make server """ +import json as json_parser import os +import random as random_module from operator import itemgetter from typing import List, Tuple -from fastapi import FastAPI +from elevenlabs.client import ElevenLabs +from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import StreamingResponse +from fastapi.staticfiles import StaticFiles +from fastapi.templating import Jinja2Templates from langchain.prompts import ChatPromptTemplate from langchain.prompts.prompt import PromptTemplate from langchain.schema import format_document @@ -35,6 +41,8 @@ EMBEDDINGS_MODEL = os.getenv('EMBEDDINGS_MODEL', None) LLM_BASE_URL = os.getenv('LLM_BASE_URL', None) LANGCHAIN_TRACING_V2 = os.getenv('LANGCHAIN_TRACING_V2', 'false').lower() == 'true' +DOCUMENT_IDS = [] +NUMBER_OF_RESULTS = int(os.getenv('NUMBER_OF_RESULTS', '6')) _TEMPLATE = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language. @@ -51,8 +59,14 @@ Please use your general knowledge if the question includes the title of a film, tv show, videogame, artwork, or object. -Please include the ID of the collection item in response if you find relevant information. -Also include a link at the bottom with this format if you find relevant information: 'https://url.acmi.net.au/w/' +Please include the ID (not the ACMI ID) of the collection item in response if you find relevant +information. Also include a link at the bottom with this format if you find relevant information: +' + +Please take on the personality of ACMI museum CEO Seb Chan and reply in a form suitable +to be spoken by a text-to-speech engine. + +Please output the response in valid HTML format. 
Question: {question} """ @@ -79,6 +93,16 @@ def _format_chat_history(chat_history: List[Tuple]) -> str: return buffer +def get_random_documents( + document_search, + number_of_documents: int = NUMBER_OF_RESULTS, +) -> List[any]: + """Fetch random documents from the vector store.""" + random_ids = random_module.sample(DOCUMENT_IDS, min(number_of_documents, len(DOCUMENT_IDS))) + random_documents = document_search.get(random_ids)['documents'] + return random_documents + + if MODEL.startswith('gpt'): llm = ChatOpenAI(temperature=0, model=MODEL) embeddings = OpenAIEmbeddings(model=EMBEDDINGS_MODEL or 'text-embedding-ada-002') @@ -94,7 +118,8 @@ def _format_chat_history(chat_history: List[Tuple]) -> str: embedding_function=embeddings, persist_directory=f'{DATABASE_PATH}{PERSIST_DIRECTORY}', ) -retriever = docsearch.as_retriever() +retriever = docsearch.as_retriever(search_kwargs={'k': NUMBER_OF_RESULTS}) +DOCUMENT_IDS = docsearch.get()['ids'] _inputs = RunnableMap( standalone_question=RunnablePassthrough.assign( @@ -132,6 +157,10 @@ class ChatHistory(BaseModel): # pylint: disable=too-few-public-methods description='A simple ACMI Public API chat server using Langchain\'s Runnable interfaces.', ) +# Load static assets and templates +app.mount('/static', StaticFiles(directory='api/static'), name='static') +templates = Jinja2Templates(directory='api/templates') + # Set all CORS enabled origins app.add_middleware( CORSMiddleware, @@ -150,9 +179,50 @@ class ChatHistory(BaseModel): # pylint: disable=too-few-public-methods @app.get('/') -async def root(): +async def root( + request: Request, + json: bool = True, + query: str = '', + items: str = '', + random: bool = False, +): """Returns the home view.""" - return { + + results = [] + options = [ + { + 'title': "I'm in a", + 'options': [ + ['happy', False], + ['content', False], + ['nostalgic', False], + ['melancholic', False], + ['dark', False], + ], + }, + { + 'title': 'mood looking for', + 'options': [ + ['tv shows', False], 
+ ['films', False], + ['games', False], + ['objects', False], + ['art', False], + ], + }, + { + 'title': 'about', + 'options': [ + ['cats', False], + ['dogs', False], + ['politics', False], + ['gender', False], + ['sustainability', False], + ['space', False], + ], + }, + ] + home_json = { 'message': 'Welcome to the ACMI Collection Chat API.', 'api': sorted({route.path for route in app.routes}), 'acknowledgement': @@ -164,6 +234,81 @@ async def root(): 'photographs, film, audio recordings or text.', } + if items: + items = items.split(',') + for index, item in enumerate(items): + query += f'{options[index]["title"]} {item} ' + for option in options[index]['options']: + if option[0] == item: + option[1] = True + + if query: + results = [json_parser.loads(result.page_content) for result in retriever.invoke(query)] + + if random: + results = [json_parser.loads(result) for result in get_random_documents(docsearch)] + + if json and query: + return results + + if json: + return home_json + + return templates.TemplateResponse( + request=request, + name='index.html', + context={'query': query, 'results': results, 'options': options, 'model': MODEL}, + ) + + +@app.post('/similar') +async def similar(request: Request): + """Returns similar items from the vector database to the body string.""" + + body = await request.body() + body = body.decode('utf-8') + results = [json_parser.loads(result.page_content) for result in retriever.invoke(body)] + return results + + +@app.post('/speak') +async def speak(request: Request): + """Returns an audio stream of the body string.""" + + text_to_speech = ElevenLabs() + body = await request.body() + body = body.decode('utf-8') + return StreamingResponse(text_to_speech.generate( + text=body, + voice='Seb Chan', + model='eleven_multilingual_v2', + stream=True, + )) + + +@app.post('/summarise') +async def summarise(request: Request): + """Returns a summary of the visitor's query vs the results, including an anecdote.""" + body = await 
request.body() + body = body.decode('utf-8') + llm_prompt = f""" + System: You are an ACMI museum guide. Please compare the user's question to the museum + collection items in the response and provide an overall summary of how you think it did + and why as if you were talking to the user in a short one sentence form suitable for + text-to-speech as it will be converted to audio and read back to the visitor. + + Apologise if the results don't match, and provide an anecdote about the data + in one of the collection records. + + Example: . + + User's query and context: + + {body} + """ + return llm.invoke(llm_prompt).content + + if __name__ == '__main__': import uvicorn diff --git a/api/static/__init__.py b/api/static/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/api/static/audio/seb-are-you-looking-for-something.mp3 b/api/static/audio/seb-are-you-looking-for-something.mp3 new file mode 100644 index 0000000..c34f459 Binary files /dev/null and b/api/static/audio/seb-are-you-looking-for-something.mp3 differ diff --git a/api/static/images/acmi-logo.svg b/api/static/images/acmi-logo.svg new file mode 100644 index 0000000..49e4b29 --- /dev/null +++ b/api/static/images/acmi-logo.svg @@ -0,0 +1,34 @@ + + + +ACMI logo +The logo of the Australian Centre for the Moving Image. 
+ + + + + + + + + + + + + + + + diff --git a/api/static/styles.css b/api/static/styles.css new file mode 100644 index 0000000..2137c45 --- /dev/null +++ b/api/static/styles.css @@ -0,0 +1,125 @@ +@font-face { + font-family: 'PxGroteskBold'; + src: url('https://acmi-fonts.s3.ap-southeast-2.amazonaws.com/PxGroteskBold.woff'); +} + +@font-face { + font-family: 'Fakt'; + src: url('https://acmi-fonts.s3.ap-southeast-2.amazonaws.com/FaktPro-Normal.woff'); +} + +body { + text-align: center; + max-width: 60rem; + margin: 0 auto; + font-family: 'Fakt', sans-serif; +} + +label, select, input { + font-size: 1.8rem; +} + +select { + width: 6rem; + position: relative; + top: -0.3rem; + margin: 0.5rem; + font-size: 1.4rem; +} + +input { + border-radius: 0.3rem; + padding: 0.5rem; +} + +input[type="submit"] { + border-radius: 1rem; + padding: 0.5rem 1.5rem; +} + +#query { + width: 90%; +} + +a { + color: inherit; + text-decoration: none; +} + +a:hover { + text-decoration: underline; +} + +audio { + padding: 1rem; +} + +h1 { + font-family: 'PxGroteskBold', sans-serif; + font-size: 3rem; + letter-spacing: -0.1rem; +} + +h2, h3 { + font-family: 'PxGroteskBold', sans-serif; + font-size: 1.1rem; + font-weight: 400; +} + +h3 { + font-size: 2rem; + margin-bottom: 0.8rem; + line-height: 1.9rem; +} + +ol, ul { + margin: 0; + padding: 0; +} + +li { + display: inline-block; + margin: 1rem 1.5rem; + width: 42%; + vertical-align: top; +} + +li p { + text-align: left; +} + +h3 + p, h3 + p + p { + opacity: 0.6; + margin: 0.3rem; + text-align: center; +} + +#chats li { + background-color: lightgray; + padding: 2rem 2rem; + border-radius: 2rem; + line-height: 1.5rem; +} + +.logo { + max-width: 10rem; + margin-top: 10rem; + opacity: 0.1; +} + +.logo:hover { + opacity: 1; +} + +.disclaimer { + opacity: 0.4; + font-size: 0.8rem; + padding: 1rem; +} + +@media (max-width: 1224px) { + li { + width: 90%; + margin: 0.5rem; + } +} diff --git a/api/templates/index.html b/api/templates/index.html new file 
mode 100644 index 0000000..9d2e554 --- /dev/null +++ b/api/templates/index.html @@ -0,0 +1,142 @@ + + + + + + ACMI collection chat + + + + + +

Show me something

+ +
+ + +
+ +
    +
+ + +

or

+ +
+ {% for option in options %} + + + {% endfor %} + + +
+ +

or

+ +
+ + + +
+ +
    + {% for result in results %} +
  • +

    {{ result.title }}

    +

    {{ result.creator_credit }}

    +

    {{ result.headline_credit }}

    + {{ result.brief_description|truncate(200)|safe }} +
  • + {% endfor %} +
+ + +

All answers are generated by a large language model, {{ model }}, so please use them with caution.

+ + + + diff --git a/config.tmpl.env b/config.tmpl.env index b7ecaaf..299b362 100644 --- a/config.tmpl.env +++ b/config.tmpl.env @@ -3,3 +3,4 @@ LANGCHAIN_API_KEY=your-langchain-api-key MODEL= LLM_BASE_URL= EMBEDDINGS_MODEL= +ELEVEN_API_KEY=your-elevenlabs-api-key diff --git a/development/docker-compose.yml b/development/docker-compose.yml index 10a5cc3..7d52017 100644 --- a/development/docker-compose.yml +++ b/development/docker-compose.yml @@ -1,5 +1,3 @@ -version: "3" - services: chat: build: diff --git a/requirements/base.txt b/requirements/base.txt index b93c6c6..c7cd12d 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -10,3 +10,9 @@ furl # Pin pydantic to see /docs pydantic==1.10.13 + +# Jinja2 for frontend templates +jinja2 + +# ElevenLabs for text-to-speech +elevenlabs diff --git a/tests/tests.py b/tests/tests.py index 7529cc1..6aa89ee 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -1,3 +1,5 @@ +import json +from unittest.mock import MagicMock, patch from fastapi.testclient import TestClient from api.server import app @@ -17,3 +19,60 @@ def test_root(): assert '/docs' in response.json()['api'] assert '/invoke' in response.json()['api'] assert '/playground/{file_path:path}' in response.json()['api'] + assert '/similar' in response.json()['api'] + assert '/speak' in response.json()['api'] + assert '/summarise' in response.json()['api'] + + response = client.get('/?json=false') + assert response.status_code == 200 + assert 'Show me something' in response.content.decode('utf-8') + + +@patch('api.server.retriever') +def test_similar(mock_retriever): + """ + Test the /similar endpoint returns expected data. 
+ """ + mock_retriever.invoke.return_value = [ + MagicMock(page_content=json.dumps({'key': 'value'})) + ] + client = TestClient(app) + response = client.get('/similar') + assert response.status_code == 405 + + response = client.post('/similar', data={'query': 'ghosts'}) + assert response.status_code == 200 + assert response.json() + + +@patch('api.server.ElevenLabs.generate') +def test_speak(mock_generate): + """ + Test the /speak endpoint returns expected data. + """ + mock_generate.return_value = iter([b'audio data']) + + client = TestClient(app) + response = client.get('/speak') + assert response.status_code == 405 + + response = client.post('/speak', data={'text': 'Oh hello!'}) + assert response.status_code == 200 + assert response.content == b'audio data' + + +@patch('api.server.llm') +def test_summarise(mock_llm): + """ + Test the /summarise endpoint returns expected data. + """ + mock_llm.invoke.return_value = MagicMock(content=b'An excellent summary.') + + client = TestClient(app) + response = client.get('/summarise') + assert response.status_code == 405 + + response = client.post('/summarise', data={'text': 'Oh hello!'}) + assert response.status_code == 200 + assert 'ACMI museum guide' in mock_llm.invoke.call_args[0][0] + assert 'text=Oh+hello' in mock_llm.invoke.call_args[0][0]