From d5e0a0f3654862bed12d2f9775edfd9c67771e09 Mon Sep 17 00:00:00 2001 From: Yee Kit Date: Tue, 9 Apr 2024 16:28:16 +0800 Subject: [PATCH] V.0.2.1 (#25) * Fixed 'Nonetype' is not subscriptable error * Fixed not getting secure next-auth cookies * Updated search to comply with document set selection * Updated run.py to exit after CREATE_VECTOR_STORE * Updated token limit for memory * Bugfix in checking for wrong value types * Better error handling --- backend/backend/app/api/routers/chat.py | 2 +- backend/backend/app/api/routers/query.py | 10 +++-- backend/backend/app/api/routers/search.py | 11 +++--- backend/backend/app/utils/auth.py | 20 ++++++---- backend/backend/app/utils/index.py | 2 +- backend/backend/run.py | 6 +-- frontend/app/api/status/route.ts | 39 ++++++++++++------- frontend/app/components/header.tsx | 4 +- frontend/app/components/query-section.tsx | 15 +++++-- frontend/app/components/search-section.tsx | 2 +- .../components/ui/search/search-results.tsx | 11 +++--- .../app/components/ui/search/useSearch.tsx | 6 +-- frontend/auth.ts | 8 ++-- frontend/middleware.ts | 2 +- 14 files changed, 84 insertions(+), 54 deletions(-) diff --git a/backend/backend/app/api/routers/chat.py b/backend/backend/app/api/routers/chat.py index 17cc4ed..adad418 100644 --- a/backend/backend/app/api/routers/chat.py +++ b/backend/backend/app/api/routers/chat.py @@ -110,7 +110,7 @@ async def chat( memory = ChatMemoryBuffer.from_defaults( chat_history=messages, - token_limit=3900, + token_limit=4096, ) logger.info(f"Memory: {memory.get()}") diff --git a/backend/backend/app/api/routers/query.py b/backend/backend/app/api/routers/query.py index 3386c1f..b9772b9 100644 --- a/backend/backend/app/api/routers/query.py +++ b/backend/backend/app/api/routers/query.py @@ -4,7 +4,6 @@ from fastapi import APIRouter, Depends, HTTPException, Request, status from fastapi.responses import StreamingResponse from fastapi.websockets import WebSocketDisconnect -from llama_index import VectorStoreIndex from llama_index.llms.types import MessageRole from pydantic import BaseModel @@ -28,6 +27,7 @@ class _Message(BaseModel): class _ChatData(BaseModel): messages: List[_Message] + document: str @r.post("") @@ -36,8 +36,13 @@ async def query( # Note: To support clients sending a JSON object using content-type "text/plain", # we need to use Depends(json_to_model(_ChatData)) here data: _ChatData = Depends(json_to_model(_ChatData)), - index: VectorStoreIndex = Depends(get_index), ): + logger = logging.getLogger("uvicorn") + # get the document set selected from the request body + document_set = data.document + logger.info(f"Document Set: {document_set}") + # get the index for the selected document set + index = get_index(collection_name=document_set) # check preconditions and get last message which is query if len(data.messages) == 0: raise HTTPException( @@ -50,7 +55,6 @@ async def query( status_code=status.HTTP_400_BAD_REQUEST, detail="Last message must be from user", ) - logger = logging.getLogger("uvicorn") logger.info(f"Query: {lastMessage}") # Query index diff --git a/backend/backend/app/api/routers/search.py b/backend/backend/app/api/routers/search.py index bf7efe1..4aaf233 100644 --- a/backend/backend/app/api/routers/search.py +++ b/backend/backend/app/api/routers/search.py @@ -2,7 +2,6 @@ import re from fastapi import APIRouter, Depends, HTTPException, Request, status -from llama_index import VectorStoreIndex from llama_index.postprocessor import SimilarityPostprocessor from llama_index.retrievers import VectorIndexRetriever @@ -22,16 +21,18 @@ @r.get("") async def search( request: Request, - index: VectorStoreIndex = Depends(get_index), query: str = None, + docSelected: str = None, ): # query = request.query_params.get("query") logger = logging.getLogger("uvicorn") - logger.info(f"Search: {query}") - if query is None: + logger.info(f"Document Set: {docSelected} | Search: {query}") + # get the index for the selected document set + index = get_index(collection_name=docSelected) + if query is None or docSelected is None: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="No search info provided", + detail="No search info/document set provided", ) # configure retriever diff --git a/backend/backend/app/utils/auth.py b/backend/backend/app/utils/auth.py index c39db90..c25531b 100644 --- a/backend/backend/app/utils/auth.py +++ b/backend/backend/app/utils/auth.py @@ -71,16 +71,18 @@ def get_user_from_JWT(token: str): ) payload = decodeJWT(token) - user_id = payload["sub"] - if user_id is not None: + if payload is not None: + user_id = payload["sub"] # Try to get the user from the database using the user_id response = supabase.table("users").select("*").eq("id", user_id).execute() # print(response.data) if len(response.data) == 0: return False - return True - return False + else: + return True + else: + return False async def validate_user( @@ -89,13 +91,17 @@ async def validate_user( ): try: logger = logging.getLogger("uvicorn") - # logger.debug(f"Auth Token: {auth_token} | API Key: {api_key}") + # logger.info(f"Auth Token: {auth_token} | API Key: {api_key}") if auth_token is not None or api_key is not None: # If the access token is empty, use the 'X-API-Key' from the header - if auth_token is None: + if auth_token is None or "null" in auth_token: # Access the 'X-API-Key' header directly if BACKEND_API_KEY is None: raise ValueError("Backend API key is not set in Backend Service!") + if "null" in api_key: + raise ValueError( + "Invalid API key provided in the 'X-API-Key' header!" + ) # If the 'X-API-Key' does not match the backend API key, raise an error if api_key != BACKEND_API_KEY: raise ValueError( @@ -123,7 +129,7 @@ async def validate_user( "Invalid token scheme. Please use the format 'Bearer [token]'" ) # Verify the JWT token is valid - if verify_jwt(jwtoken=jwtoken) is None: + if verify_jwt(jwtoken=jwtoken): return "Invalid token. Please provide a valid token." # Check if the user exists in the database if get_user_from_JWT(token=jwtoken): diff --git a/backend/backend/app/utils/index.py b/backend/backend/app/utils/index.py index 6316d79..d9a28dd 100644 --- a/backend/backend/app/utils/index.py +++ b/backend/backend/app/utils/index.py @@ -230,7 +230,7 @@ def load_existing_index(collection_name="PSSCOC"): logger.info(f"Indexing [{collection_name}] vector store...") vector_store._collection.create_index() logger.info(f"Finished indexing [{collection_name}] vector store") - logger.info(vector_store._collection.name) + # logger.info(f"Collection Name: {vector_store._collection.name}") index = VectorStoreIndex.from_vector_store(vector_store=vector_store) logger.info(f"Finished loading [{collection_name}] index from Supabase") logger.info(f"Index ID: {index.index_id}") diff --git a/backend/backend/run.py b/backend/backend/run.py index df211b0..abc1237 100644 --- a/backend/backend/run.py +++ b/backend/backend/run.py @@ -31,11 +31,11 @@ def run_app(): # Create the vector store from backend.app.utils.index import create_index - logger.info("Creating vector stores first...") + logger.info("Indexing Documents & Creating Vector Stores...") create_index() - logger.info("Vector stores created successfully! Running App...") + logger.info("Vector Stores created successfully! Exiting...") # Run the app - run_app() + # run_app() else: # Run the app run_app() diff --git a/frontend/app/api/status/route.ts b/frontend/app/api/status/route.ts index e924262..ec41785 100644 --- a/frontend/app/api/status/route.ts +++ b/frontend/app/api/status/route.ts @@ -1,28 +1,39 @@ -export async function GET(request: Request) { +import { NextRequest, NextResponse } from "next/server"; + +export async function GET(request: NextRequest) { const healthcheck_api = process.env.NEXT_PUBLIC_HEALTHCHECK_API as string; // Retrieve the session token from the request headers let session = request.headers.get('Authorization'); - console.log('Status API - headers:', request.headers); + // console.log('Status API - headers:', request.headers); // Public API key let api_key = null; // If no session, use the public API key - if (!session) { + if (session === null || session === undefined || session.includes('undefined')) { + console.log('No session token found, using public API key'); api_key = process.env.BACKEND_API_KEY as string; + session = null; // Clear the session token } - const res = await fetch(healthcheck_api, { - signal: AbortSignal.timeout(5000), // Abort the request if it takes longer than 5 seconds - headers: { - 'Content-Type': 'application/json', - 'Authorization': session, - 'X-API-Key': api_key, - } as any, - }) - const data = await res.json() - - return Response.json({ data }) + try { + const res = await fetch(healthcheck_api, { + signal: AbortSignal.timeout(5000), // Abort the request if it takes longer than 5 seconds + headers: { + 'Content-Type': 'application/json', + 'Authorization': session, + 'X-API-Key': api_key, + } as any, + }) + const data = await res.json() + if (!res.ok) { + throw new Error(data.detail || 'Unknown Error'); + } + return NextResponse.json({ data }) + } catch (error : any) { + console.error(`${error}`); + return NextResponse.json({ error: error.message }, { status: 500 }) + } } \ No newline at end of file diff --git a/frontend/app/components/header.tsx b/frontend/app/components/header.tsx index 43a20de..ca9d8d3 100644 --- a/frontend/app/components/header.tsx +++ b/frontend/app/components/header.tsx @@ -53,7 +53,7 @@ export default function Header() { const signinPage = "/sign-in?callbackUrl=" + encodedPath; // Get user session for conditional rendering of user profile and logout buttons and for fetching the API status - const { data: session, status } = useSession(); + const { data: session, status } = useSession() // console.log('session:', session, 'status:', status); const supabaseAccessToken = session?.supabaseAccessToken; // Use SWR for API status fetching @@ -94,7 +94,7 @@ export default function Header() { useEffect(() => { setMounted(true); - }, []); + }, [session]); const [isMobileMenuOpen, setMobileMenuOpen] = useState(false); diff --git a/frontend/app/components/query-section.tsx b/frontend/app/components/query-section.tsx index ddd3fbf..3801965 100644 --- a/frontend/app/components/query-section.tsx +++ b/frontend/app/components/query-section.tsx @@ -4,9 +4,12 @@ import { useChat } from "ai/react"; import { ChatInput, ChatMessages } from "@/app/components/ui/chat"; import { AutofillQuestion } from "./ui/autofill-prompt"; import { useSession } from "next-auth/react"; +import { useState } from "react"; export default function QuerySection() { const { data: session } = useSession(); + const supabaseAccessToken = session?.supabaseAccessToken; + const [docSelected, setDocSelected] = useState(''); const { messages, input, @@ -15,12 +18,16 @@ export default function QuerySection() { handleInputChange, reload, stop, - } = useChat({ + } = useChat({ api: process.env.NEXT_PUBLIC_QUERY_API, - // Add the access token to the request headers headers: { - 'Authorization': `Bearer ${session?.supabaseAccessToken}`, - } + // Add the access token to the request headers + 'Authorization': `Bearer ${supabaseAccessToken}`, + }, + body: { + // Add the selected document to the request body + document: docSelected, + }, }); return ( diff --git a/frontend/app/components/search-section.tsx b/frontend/app/components/search-section.tsx index f2331a5..86a8f5a 100644 --- a/frontend/app/components/search-section.tsx +++ b/frontend/app/components/search-section.tsx @@ -19,7 +19,7 @@ const SearchSection: React.FC = () => { const handleSearchSubmit = (e: FormEvent) => { e.preventDefault(); setSearchButtonPressed(true); - handleSearch(query); + handleSearch(query, docSelected); }; return ( diff --git a/frontend/app/components/ui/search/search-results.tsx b/frontend/app/components/ui/search/search-results.tsx index 7df6380..bd71421 100644 --- a/frontend/app/components/ui/search/search-results.tsx +++ b/frontend/app/components/ui/search/search-results.tsx @@ -6,8 +6,8 @@ import 'react-toastify/dist/ReactToastify.css'; import { SearchHandler, SearchResult } from "@/app/components/ui/search/search.interface"; export default function SearchResults( - props: Pick - ) { + props: Pick +) { const [sortedResults, setSortedResults] = useState([]); const [expandedResult, setExpandedResult] = useState(null); @@ -17,11 +17,10 @@ export default function SearchResults( // Reset sortedResults when query is empty setSortedResults([]); } else if (props.query.trim() !== "" && props.searchButtonPressed) { - // if results are empty - if (props.results.length === 0) { + // if results are empty or not an array + if (!Array.isArray(props.results) || props.results.length === 0) { setSortedResults([]); - } - else { + } else { // Sort results by similarity score const sorted = props.results.slice().sort((a, b) => b.similarity_score - a.similarity_score); // Update sortedResults state diff --git a/frontend/app/components/ui/search/useSearch.tsx b/frontend/app/components/ui/search/useSearch.tsx index f79e984..28973b9 100644 --- a/frontend/app/components/ui/search/useSearch.tsx +++ b/frontend/app/components/ui/search/useSearch.tsx @@ -7,7 +7,7 @@ import { useSession } from 'next-auth/react'; interface UseSearchResult { searchResults: SearchResult[]; isLoading: boolean; - handleSearch: (query: string) => Promise; + handleSearch: (query: string, docSelected: string) => Promise; } const search_api = process.env.NEXT_PUBLIC_SEARCH_API; @@ -20,7 +20,7 @@ const useSearch = (): UseSearchResult => { // console.log('session:', session, 'status:', status); const supabaseAccessToken = session?.supabaseAccessToken; - const handleSearch = async (query: string): Promise => { + const handleSearch = async (query: string, docSelected: string): Promise => { setIsSearchButtonPressed(isSearchButtonPressed); setIsLoading(true); @@ -40,7 +40,7 @@ const useSearch = (): UseSearchResult => { setIsLoading(false); return; } - const response = await fetch(`${search_api}?query=${query}`, { + const response = await fetch(`${search_api}?query=${query}&docSelected=${docSelected}`, { signal: AbortSignal.timeout(120000), // Abort the request if it takes longer than 120 seconds // Add the access token to the request headers headers: { diff --git a/frontend/auth.ts b/frontend/auth.ts index 00e50e1..5ea9a5b 100644 --- a/frontend/auth.ts +++ b/frontend/auth.ts @@ -124,11 +124,12 @@ export const config = { token.accessToken = account.access_token token.id = profile?.sub } - return token + return token; }, async session({ session, token, user }) { // Send properties to the client, like an access_token from a provider. const signingSecret = process.env.SUPABASE_JWT_SECRET + // console.log('Signing Secret:', signingSecret); if (signingSecret) { const payload = { aud: "authenticated", @@ -137,11 +138,12 @@ export const config = { // email: user.email, role: "authenticated", } - session.supabaseAccessToken = jwt.sign(payload, signingSecret) + session.supabaseAccessToken = jwt.sign(payload, signingSecret) as string; + // console.log('New Session:', session); // session.jwt = token.jwt as string; // session.id = token.id as string; } - return session + return session; }, } diff --git a/frontend/middleware.ts b/frontend/middleware.ts index 89b9783..609da91 100644 --- a/frontend/middleware.ts +++ b/frontend/middleware.ts @@ -8,7 +8,7 @@ export const middleware = async (request: NextRequest) => { // Add callbackUrl params to the signinPage URL signinPage.searchParams.set('callbackUrl', pathname); // Retrieve the session token from the request cookies - const session = request.cookies.get('next-auth.session-token'); + const session = request.cookies.get('next-auth.session-token') || request.cookies.get('__Secure-next-auth.session-token'); if (session) { // console.log('session:', session);