-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Separate endpoints for vector and graph-only options * Vector chain updated to create a vector index if none is already present in the database * Dependencies updated * LLM var moved from chains to a config file * Changelog added * README updated with notes on data requirements for vector support
- Loading branch information
Showing
9 changed files
with
666 additions
and
463 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
# Changelog | ||
|
||
All notable changes to this project will be documented in this file. | ||
|
||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), | ||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). | ||
|
||
## [0.1.2] - 2024-07-27 | ||
|
||
### Added | ||
|
||
- Separate endpoints for vector and graph-only options | ||
|
||
### Changed | ||
|
||
- Vector chain updated to create a vector index if none is already present in the database | ||
- Mode option in POST payload, now only requires the 'message' key-value | ||
- Dependencies updated | ||
|
||
## [0.1.1] - 2024-06-05 | ||
|
||
### Added | ||
|
||
- CORS middleware | ||
- Neo4j exception middleware | ||
|
||
### Changed | ||
|
||
- Replaced deprecated LLMChain implementation | ||
- Vector chain simplified to use RetrievalQA chain | ||
- Dependencies updated | ||
|
||
## [0.1.0] - 2024-04-05 | ||
|
||
### Added | ||
|
||
- Initial release. | ||
- Core functionality implemented, including: | ||
- FastAPI wrapper | ||
- Vector chain example | ||
- Graph chain example | ||
- Simple Agent example that aggregates results of the Vector and Graph retrievers |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
import os | ||
|
||
# Neo4j Credentials | ||
NEO4J_URI = os.getenv("NEO4J_URI") | ||
NEO4J_DATABASE = os.getenv("NEO4J_DATABASE") | ||
NEO4J_USERNAME = os.getenv("NEO4J_USERNAME") | ||
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD") | ||
|
||
# ================== | ||
# Change models here | ||
# ================== | ||
from langchain_openai import ChatOpenAI, OpenAIEmbeddings | ||
|
||
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") | ||
LLM = ChatOpenAI(temperature=0, openai_api_key=OPENAI_API_KEY) | ||
EMBEDDINGS = OpenAIEmbeddings(openai_api_key=OPENAI_API_KEY) | ||
# ================== |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,113 +1,98 @@ | ||
from __future__ import annotations | ||
from typing import Union | ||
from app.graph_chain import graph_chain, CYPHER_GENERATION_PROMPT | ||
from app.vector_chain import vector_chain, VECTOR_PROMPT | ||
from app.simple_agent import simple_agent_chain | ||
from fastapi import FastAPI, Request, Response | ||
from fastapi.middleware.cors import CORSMiddleware | ||
from starlette.middleware.base import BaseHTTPMiddleware | ||
from fastapi import FastAPI | ||
from typing import Union, Optional | ||
from pydantic import BaseModel, Field | ||
from neo4j import exceptions | ||
import logging | ||
|
||
|
||
class ApiChatPostRequest(BaseModel): | ||
message: str = Field(..., description="The chat message to send") | ||
mode: str = Field( | ||
"agent", | ||
description='The mode of the chat message. Current options are: "vector", "graph", "agent". Default is "agent"', | ||
) | ||
|
||
|
||
class ApiChatPostResponse(BaseModel): | ||
response: str | ||
|
||
|
||
class Neo4jExceptionMiddleware(BaseHTTPMiddleware): | ||
async def dispatch(self, request: Request, call_next): | ||
try: | ||
response = await call_next(request) | ||
return response | ||
except exceptions.AuthError as e: | ||
msg = f"Neo4j Authentication Error: {e}" | ||
logging.warning(msg) | ||
return Response(content=msg, status_code=400, media_type="text/plain") | ||
except exceptions.ServiceUnavailable as e: | ||
msg = f"Neo4j Database Unavailable Error: {e}" | ||
logging.warning(msg) | ||
return Response(content=msg, status_code=400, media_type="text/plain") | ||
except Exception as e: | ||
msg = f"Neo4j Uncaught Exception: {e}" | ||
logging.error(msg) | ||
return Response(content=msg, status_code=400, media_type="text/plain") | ||
|
||
|
||
# Allowed CORS origins | ||
origins = [ | ||
"http://127.0.0.1:8000", # Alternative localhost address | ||
"http://localhost:8000", | ||
] | ||
message: Optional[str] = Field(None, description="The chat message response") | ||
|
||
|
||
app = FastAPI() | ||
|
||
# Add CORS middleware to allow cross-origin requests | ||
app.add_middleware( | ||
CORSMiddleware, | ||
allow_origins=origins, | ||
allow_credentials=True, | ||
allow_methods=["*"], | ||
allow_headers=["*"], | ||
|
||
@app.post( | ||
"/api/chat", | ||
response_model=None, | ||
responses={"201": {"model": ApiChatPostResponse}}, | ||
tags=["chat"], | ||
description="Endpoint utilizing a simple agent to composite responses from the Vector and Graph chains interfacing with a Neo4j instance.", | ||
) | ||
# Add Neo4j exception handling middleware | ||
app.add_middleware(Neo4jExceptionMiddleware) | ||
def send_chat_message(body: ApiChatPostRequest) -> Union[None, ApiChatPostResponse]: | ||
""" | ||
Send a chat message | ||
""" | ||
|
||
question = body.message | ||
|
||
v_response = vector_chain().invoke( | ||
{"question": question}, prompt=VECTOR_PROMPT, return_only_outputs=True | ||
) | ||
g_response = graph_chain().invoke( | ||
{"query": question}, prompt=CYPHER_GENERATION_PROMPT, return_only_outputs=True | ||
) | ||
|
||
# Return an answer from a chain that composites both the Vector and Graph responses | ||
response = simple_agent_chain().invoke( | ||
{ | ||
"question": question, | ||
"vector_result": v_response, | ||
"graph_result": g_response, | ||
} | ||
) | ||
|
||
return f"{response}", 200 | ||
|
||
|
||
@app.post( | ||
"/api/chat", | ||
"/api/chat/vector", | ||
response_model=None, | ||
responses={"201": {"model": ApiChatPostResponse}}, | ||
tags=["chat"], | ||
description="Endpoint for utilizing only vector index for querying Neo4j instance.", | ||
) | ||
def send_chat_vector_message( | ||
body: ApiChatPostRequest, | ||
) -> Union[None, ApiChatPostResponse]: | ||
""" | ||
Send a chat message | ||
""" | ||
|
||
question = body.message | ||
|
||
response = vector_chain().invoke( | ||
{"question": question}, prompt=VECTOR_PROMPT, return_only_outputs=True | ||
) | ||
|
||
return f"{response}", 200 | ||
|
||
|
||
@app.post( | ||
"/api/chat/graph", | ||
response_model=None, | ||
responses={"201": {"model": ApiChatPostResponse}}, | ||
tags=["chat"], | ||
description="Endpoint using only Text2Cypher for querying with Neo4j instance.", | ||
) | ||
async def send_chat_message(body: ApiChatPostRequest): | ||
def send_chat_graph_message( | ||
body: ApiChatPostRequest, | ||
) -> Union[None, ApiChatPostResponse]: | ||
""" | ||
Send a chat message | ||
""" | ||
|
||
question = body.message | ||
|
||
# Simple exception check. See https://neo4j.com/docs/api/python-driver/current/api.html#errors for full set of driver exceptions | ||
|
||
if body.mode == "vector": | ||
# Return only the Vector answer | ||
v_response = vector_chain().invoke( | ||
{"query": question}, prompt=VECTOR_PROMPT, return_only_outputs=True | ||
) | ||
response = v_response | ||
elif body.mode == "graph": | ||
# Return only the Graph (text2Cypher) answer | ||
g_response = graph_chain().invoke( | ||
{"query": question}, | ||
prompt=CYPHER_GENERATION_PROMPT, | ||
return_only_outputs=True, | ||
) | ||
response = g_response["result"] | ||
else: | ||
# Return both vector + graph answers | ||
v_response = vector_chain().invoke( | ||
{"query": question}, prompt=VECTOR_PROMPT, return_only_outputs=True | ||
) | ||
g_response = graph_chain().invoke( | ||
{"query": question}, | ||
prompt=CYPHER_GENERATION_PROMPT, | ||
return_only_outputs=True, | ||
)["result"] | ||
|
||
# Synthesize a composite of both the Vector and Graph responses | ||
response = simple_agent_chain().invoke( | ||
{ | ||
"question": question, | ||
"vector_result": v_response, | ||
"graph_result": g_response, | ||
} | ||
) | ||
|
||
return response, 200 | ||
response = graph_chain().invoke( | ||
{"query": question}, prompt=CYPHER_GENERATION_PROMPT, return_only_outputs=True | ||
) | ||
|
||
return f"{response}", 200 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.