Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#11 Add experimental frontend #12

Merged
merged 8 commits into from
Jun 3, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 150 additions & 5 deletions api/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,18 @@
make server
"""

import json as json_parser
import os
import random as random_module
from operator import itemgetter
from typing import List, Tuple

from fastapi import FastAPI
from elevenlabs.client import ElevenLabs
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from langchain.prompts import ChatPromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import format_document
Expand All @@ -35,6 +41,8 @@
EMBEDDINGS_MODEL = os.getenv('EMBEDDINGS_MODEL', None)
LLM_BASE_URL = os.getenv('LLM_BASE_URL', None)
LANGCHAIN_TRACING_V2 = os.getenv('LANGCHAIN_TRACING_V2', 'false').lower() == 'true'
DOCUMENT_IDS = []
NUMBER_OF_RESULTS = int(os.getenv('NUMBER_OF_RESULTS', '6'))

_TEMPLATE = """Given the following conversation and a follow up question, rephrase the
follow up question to be a standalone question, in its original language.
Expand All @@ -52,7 +60,13 @@
tv show, videogame, artwork, or object.

Please include the ID of the collection item in response if you find relevant information.
Also include a link at the bottom with this format if you find relevant information: 'https://url.acmi.net.au/w/<ID>'
Also include a link at the bottom with this format if you find relevant information:
<a href="https://url.acmi.net.au/w/<ID>"><ID></a>'

Please take on the personality of ACMI museum CEO Seb Chan and reply in a form suitable
to be spoken by a test-to-speech engine.

Please output the response in valid HTML format.

Question: {question}
"""
Expand All @@ -79,6 +93,16 @@ def _format_chat_history(chat_history: List[Tuple]) -> str:
return buffer


def get_random_documents(
document_search,
number_of_documents: int = NUMBER_OF_RESULTS,
) -> List[any]:
"""Fetch random documents from the vector store."""
random_ids = random_module.sample(DOCUMENT_IDS, min(number_of_documents, len(DOCUMENT_IDS)))
random_documents = document_search.get(random_ids)['documents']
return random_documents


if MODEL.startswith('gpt'):
llm = ChatOpenAI(temperature=0, model=MODEL)
embeddings = OpenAIEmbeddings(model=EMBEDDINGS_MODEL or 'text-embedding-ada-002')
Expand All @@ -94,7 +118,8 @@ def _format_chat_history(chat_history: List[Tuple]) -> str:
embedding_function=embeddings,
persist_directory=f'{DATABASE_PATH}{PERSIST_DIRECTORY}',
)
retriever = docsearch.as_retriever()
retriever = docsearch.as_retriever(search_kwargs={'k': NUMBER_OF_RESULTS})
DOCUMENT_IDS = docsearch.get()['ids']

_inputs = RunnableMap(
standalone_question=RunnablePassthrough.assign(
Expand Down Expand Up @@ -132,6 +157,10 @@ class ChatHistory(BaseModel): # pylint: disable=too-few-public-methods
description='A simple ACMI Public API chat server using Langchain\'s Runnable interfaces.',
)

# Load static assets and templates
app.mount('/static', StaticFiles(directory='api/static'), name='static')
templates = Jinja2Templates(directory='api/templates')

# Set all CORS enabled origins
app.add_middleware(
CORSMiddleware,
Expand All @@ -150,9 +179,50 @@ class ChatHistory(BaseModel): # pylint: disable=too-few-public-methods


@app.get('/')
async def root():
async def root(
request: Request,
json: bool = True,
query: str = '',
items: str = '',
random: bool = False,
):
"""Returns the home view."""
return {

results = []
options = [
{
'title': "I'm in a",
'options': [
['happy', False],
['content', False],
['nostalgic', False],
['melancholic', False],
['dark', False],
],
},
{
'title': 'mood looking for',
'options': [
['tv shows', False],
['films', False],
['games', False],
['objects', False],
['art', False],
],
},
{
'title': 'about',
'options': [
['cats', False],
['dogs', False],
['politics', False],
['gender', False],
['sustainability', False],
['space', False],
],
},
]
home_json = {
'message': 'Welcome to the ACMI Collection Chat API.',
'api': sorted({route.path for route in app.routes}),
'acknowledgement':
Expand All @@ -164,6 +234,81 @@ async def root():
'photographs, film, audio recordings or text.',
}

if items:
items = items.split(',')
for index, item in enumerate(items):
query += f'{options[index]["title"]} {item} '
for option in options[index]['options']:
if option[0] == item:
option[1] = True

if query:
results = [json_parser.loads(result.page_content) for result in retriever.invoke(query)]

if random:
results = [json_parser.loads(result) for result in get_random_documents(docsearch)]

if json and query:
return results

if json:
return home_json

return templates.TemplateResponse(
request=request,
name='index.html',
context={'query': query, 'results': results, 'options': options, 'model': MODEL},
)


@app.post('/similar')
async def similar(request: Request):
"""Returns similar items from the vector database to the body string."""

body = await request.body()
body = body.decode('utf-8')
results = [json_parser.loads(result.page_content) for result in retriever.invoke(body)]
return results


@app.post('/speak')
async def speak(request: Request):
"""Returns an audio stream of the body string."""

text_to_speech = ElevenLabs()
body = await request.body()
body = body.decode('utf-8')
return StreamingResponse(text_to_speech.generate(
text=body,
voice='Seb Chan',
model='eleven_multilingual_v2',
stream=True,
))


@app.post('/summarise')
async def summarise(request: Request):
"""Returns a summary of the visitor's query vs the results, including an anecdote."""
body = await request.body()
body = body.decode('utf-8')
llm_prompt = f"""
System: You are an ACMI museum guide. Please compare the user's question to the museum
collection items in the response and provide an overall summary of how you think it did
and why as if you were talking to the user in a short one sentence form suitable for
text-to-speech as it will be converted to audio and read back to the visitor.

Apologise if the results don't match, and provide an anecdote about the data
in one of the collection records.

Example: <summary>. <anecdote>

User's query and context:

{body}
"""
return llm.invoke(llm_prompt).content


if __name__ == '__main__':
import uvicorn

Expand Down
Empty file added api/static/__init__.py
Empty file.
Binary file not shown.
34 changes: 34 additions & 0 deletions api/static/images/acmi-logo.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
125 changes: 125 additions & 0 deletions api/static/styles.css
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
@font-face {
font-family: 'PxGroteskBold';
src: url('https://acmi-fonts.s3.ap-southeast-2.amazonaws.com/PxGroteskBold.woff');
}

@font-face {
font-family: 'Fakt';
src: url('https://acmi-fonts.s3.ap-southeast-2.amazonaws.com/FaktPro-Normal.woff');
}

body {
text-align: center;
max-width: 60rem;
margin: 0 auto;
font-family: 'Fakt', sans-serif;
}

label, select, input {
font-size: 1.8rem;
}

select {
width: 6rem;
position: relative;
top: -0.3rem;
margin: 0.5rem;
font-size: 1.4rem;
}

input {
border-radius: 0.3rem;
padding: 0.5rem;
}

input[type="submit"] {
border-radius: 1rem;
padding: 0.5rem 1.5rem;
}

#query {
width: 90%;
}

a {
color: inherit;
text-decoration: none;
}

a:hover {
text-decoration: underline;
}

audio {
padding: 1rem;
}

h1 {
font-family: 'PxGroteskBold', sans-serif;
font-size: 3rem;
letter-spacing: -0.1rem;
}

h2, h3 {
font-family: 'PxGroteskBold', sans-serif;
font-size: 1.1rem;
font-weight: 400;
}

h3 {
font-size: 2rem;
margin-bottom: 0.8rem;
line-height: 1.9rem;
}

ol, ul {
margin: 0;
padding: 0;
}

li {
display: inline-block;
margin: 1rem 1.5rem;
width: 42%;
vertical-align: top;
}

li p {
text-align: left;
}

h3 + p, h3 + p + p {
opacity: 0.6;
margin: 0.3rem;
text-align: center;
}

#chats li {
background-color: lightgray;
padding: 2rem 2rem;
border-radius: 2rem;
line-height: 1.5rem;
}

.logo {
max-width: 10rem;
margin-top: 10rem;
opacity: 0.1;
}

.logo:hover {
opacity: 1;
}

.disclaimer {
opacity: 0.4;
font-size: 0.8rem;
padding: 1rem;
}

@media (max-width: 1224px) {
li {
width: 90%;
margin: 0.5rem;
}
}
Loading
Loading