Commit: Dev (#3)
* Dependencies updated

* CORS middleware added

* Neo4j exception middleware added

* Replaced deprecated LLMChain implementation

* Vector chain simplified to use RetrievalQA chain
jalakoo authored Jun 5, 2024
1 parent c1dbc32 commit 31f11db
Showing 7 changed files with 1,244 additions and 586 deletions.
36 changes: 28 additions & 8 deletions README.md
@@ -1,17 +1,32 @@
# Neo4j LangChain Starter Kit

This kit provides a simple [FastAPI](https://fastapi.tiangolo.com/) backend service connected to [OpenAI](https://platform.openai.com/docs/overview) and [Neo4j](https://neo4j.com/developer/) for powering GenAI projects. The Neo4j interface leverages both [Vector Indexes](https://python.langchain.com/docs/integrations/vectorstores/neo4jvector) and [Text2Cypher](https://python.langchain.com/docs/use_cases/graph/integrations/graph_cypher_qa) chains to provide more accurate results.

![alt text](https://res.cloudinary.com/dk0tizgdn/image/upload/v1711042573/langchain_starter_kit_sample_jgvnfb.gif "Testing Neo4j LangChain Starter Kit")

## Requirements

- [Poetry](https://python-poetry.org/) for virtual environment management
- [LangChain](https://python.langchain.com/docs/get_started/introduction)
- An [OpenAI API Key](https://openai.com/blog/openai-api)
- A running [local](https://neo4j.com/download/) or [cloud](https://neo4j.com/cloud/platform/aura-graph-database/) Neo4j database


## Usage

Add a `.env` file to the root folder with the following keys and your own credentials (or use the public, read-only credentials shown below):

```
NEO4J_URI=neo4j+ssc://9fcf58c6.databases.neo4j.io
NEO4J_DATABASE=neo4j
NEO4J_USERNAME=public
NEO4J_PASSWORD=read_only
OPENAI_API_KEY=<your_openai_key_here>
```

Then run: `poetry run uvicorn app.server:app --reload --port=8000`

Or add env variables at runtime:

```
NEO4J_URI=neo4j+ssc://9fcf58c6.databases.neo4j.io \
NEO4J_DATABASE=neo4j \
@@ -21,24 +36,26 @@ OPENAI_API_KEY=<add_your_openai_key_here> \
poetry run uvicorn app.server:app --reload --port=8000 --log-config=log_conf.yaml
```

_NOTE_ the above Neo4j credentials are for read-only access to a hosted sample dataset. Your own OpenAI API key will be needed to run this server.

_NOTE_ the `NEO4J_URI` value can use either the neo4j or [bolt](https://neo4j.com/docs/bolt/current/bolt/) URI scheme. For more details on which to use, see this [example](https://neo4j.com/docs/driver-manual/4.0/client-applications/#driver-configuration-examples).

A FastAPI server should now be running locally, with the chat endpoint available at http://localhost:8000/api/chat.

## Custom Database Setup

If you would like to load your own instance with a subset of this information, add your own OpenAI key to the Cypher code in the [edgar_import.cypher](edgar_import.cypher) file and run it in your instance's [Neo4j browser](https://neo4j.com/docs/browser-manual/current/).

For more information on how this load script works, see [this notebook](https://github.com/neo4j-examples/sec-edgar-notebooks/blob/main/notebooks/kg-construction/1-mvg.ipynb).
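
Once the import completes, a quick sanity check can confirm the data is in place. Below is a minimal sketch using the official Neo4j Python driver; the `Chunk` label matches the node label used by the vector chain, and the connection values are placeholders for your own instance:

```
from neo4j import GraphDatabase

# Placeholder credentials for your own Neo4j instance
driver = GraphDatabase.driver("neo4j://localhost:7687", auth=("neo4j", "password"))

with driver.session(database="neo4j") as session:
    # Count the Chunk nodes the import script should have created
    count = session.run("MATCH (c:Chunk) RETURN count(c) AS n").single()["n"]
    print(f"Found {count} Chunk nodes")

driver.close()
```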


## Docs

FastAPI serves interactive endpoint documentation, with the ability to test requests from a browser, at http://localhost:8000/docs

## Testing

Alternatively, once the server is running, the endpoint can be tested with a curl command:

```
curl --location 'http://127.0.0.1:8000/api/chat' \
--header 'Content-Type: application/json' \
@@ -47,10 +64,13 @@ curl --location 'http://127.0.0.1:8000/api/chat' \
```
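
The same request can be made from Python; a minimal sketch using the `requests` library (the `message` and `mode` fields mirror `ApiChatPostRequest` in `app/server.py`; the question text is only an example):

```
import requests

resp = requests.post(
    "http://127.0.0.1:8000/api/chat",
    json={
        "message": "What companies are in the dataset?",  # example question
        "mode": "agent",  # or "vector" / "graph"
    },
)
print(resp.status_code, resp.text)
```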

## Feedback

Please provide feedback and report bugs as [GitHub issues](https://github.com/neo4j-examples/langchain-starter-kit/issues)

## Contributing

Want to improve this kit? See the [contributing guide](./CONTRIBUTING.md)

## Learn More
At [Neo4j GraphAcademy](https://graphacademy.neo4j.com), we offer a wide range of courses completely free of charge, including [Neo4j & LLM Fundamentals](https://graphacademy.neo4j.com/courses/llm-fundamentals/) and [Build a Neo4j-backed Chatbot using Python](https://graphacademy.neo4j.com/courses/llm-chatbot-python/).
17 changes: 9 additions & 8 deletions app/graph_chain.py
@@ -41,6 +41,7 @@
input_variables=["schema", "question"], template=CYPHER_GENERATION_TEMPLATE
)


def graph_chain() -> Runnable:

NEO4J_URI = os.getenv("NEO4J_URI")
@@ -56,20 +57,20 @@ def graph_chain() -> Runnable:
username=NEO4J_USERNAME,
password=NEO4J_PASSWORD,
database=NEO4J_DATABASE,
sanitize = True
sanitize=True,
)

graph.refresh_schema()

# Official API doc for GraphCypherQAChain at: https://api.python.langchain.com/en/latest/chains/langchain.chains.graph_qa.base.GraphQAChain.html#
graph_chain = GraphCypherQAChain.from_llm(
cypher_llm = LLM,
qa_llm = LLM,
validate_cypher= True,
cypher_llm=LLM,
qa_llm=LLM,
validate_cypher=True,
graph=graph,
verbose=True,
return_intermediate_steps = True,
return_direct = True,
verbose=True,
return_intermediate_steps=True,
# return_direct = True,
)

return graph_chain
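
For reference, a minimal sketch of calling this chain directly (assuming the Neo4j and OpenAI environment variables from the README are set); `GraphCypherQAChain` returns a dict with a `result` key, plus `intermediate_steps` since that option is enabled above:

```
from app.graph_chain import graph_chain

chain = graph_chain()
# Example question; the chain generates Cypher, runs it, then summarizes the result
output = chain.invoke({"query": "How many companies are in the database?"})
print(output["result"])              # natural-language answer
print(output["intermediate_steps"])  # generated Cypher and query context
```
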
119 changes: 86 additions & 33 deletions app/server.py
@@ -1,60 +1,113 @@
from __future__ import annotations
from typing import Union
from app.graph_chain import graph_chain, CYPHER_GENERATION_PROMPT
from app.vector_chain import vector_chain, VECTOR_PROMPT
from app.simple_agent import simple_agent_chain
from fastapi import FastAPI
from typing import Union, Optional
from fastapi import FastAPI, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware.base import BaseHTTPMiddleware
from pydantic import BaseModel, Field
from neo4j import exceptions
import logging


class ApiChatPostRequest(BaseModel):
message: str = Field(..., description='The chat message to send')
mode: str = Field('agent', description='The mode of the chat message. Current options are: "vector", "graph", "agent". Default is "agent"')
message: str = Field(..., description="The chat message to send")
mode: str = Field(
"agent",
description='The mode of the chat message. Current options are: "vector", "graph", "agent". Default is "agent"',
)


class ApiChatPostResponse(BaseModel):
message: Optional[str] = Field(None, description='The chat message response')
response: str


class Neo4jExceptionMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
try:
response = await call_next(request)
return response
except exceptions.AuthError as e:
msg = f"Neo4j Authentication Error: {e}"
logging.warning(msg)
return Response(content=msg, status_code=400, media_type="text/plain")
except exceptions.ServiceUnavailable as e:
msg = f"Neo4j Database Unavailable Error: {e}"
logging.warning(msg)
return Response(content=msg, status_code=400, media_type="text/plain")
except Exception as e:
msg = f"Neo4j Uncaught Exception: {e}"
logging.error(msg)
return Response(content=msg, status_code=400, media_type="text/plain")


# Allowed CORS origins
origins = [
"http://127.0.0.1:8000", # Alternative localhost address
"http://localhost:8000",
]

app = FastAPI()

# Add CORS middleware to allow cross-origin requests
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Add Neo4j exception handling middleware
app.add_middleware(Neo4jExceptionMiddleware)

app = FastAPI()

@app.post(
'/api/chat',
"/api/chat",
response_model=None,
responses={'201': {'model': ApiChatPostResponse}},
tags=['chat'],
responses={"201": {"model": ApiChatPostResponse}},
tags=["chat"],
)
def send_chat_message(body: ApiChatPostRequest) -> Union[None, ApiChatPostResponse]:
async def send_chat_message(body: ApiChatPostRequest):
"""
Send a chat message
"""

question = body.message

v_response = vector_chain().invoke(
{"question":question},
prompt = VECTOR_PROMPT,
return_only_outputs = True
)
g_response = graph_chain().invoke(
{"query":question},
prompt = CYPHER_GENERATION_PROMPT,
return_only_outputs = True
)

if body.mode == 'vector':
# Simple exception check. See https://neo4j.com/docs/api/python-driver/current/api.html#errors for full set of driver exceptions

if body.mode == "vector":
# Return only the Vector answer
v_response = vector_chain().invoke(
{"query": question}, prompt=VECTOR_PROMPT, return_only_outputs=True
)
response = v_response
elif body.mode == 'graph':
elif body.mode == "graph":
# Return only the Graph (text2Cypher) answer
response = g_response
g_response = graph_chain().invoke(
{"query": question},
prompt=CYPHER_GENERATION_PROMPT,
return_only_outputs=True,
)
response = g_response["result"]
else:
# Return an answer from a chain that composites both the Vector and Graph responses
response = simple_agent_chain().invoke({
"question":question,
"vector_result":v_response,
"graph_result":g_response
})["text"]

return f"{response}", 200
# Return both vector + graph answers
v_response = vector_chain().invoke(
{"query": question}, prompt=VECTOR_PROMPT, return_only_outputs=True
)
g_response = graph_chain().invoke(
{"query": question},
prompt=CYPHER_GENERATION_PROMPT,
return_only_outputs=True,
)["result"]

# Synthesize a composite of both the Vector and Graph responses
response = simple_agent_chain().invoke(
{
"question": question,
"vector_result": v_response,
"graph_result": g_response,
}
)

return response, 200
23 changes: 12 additions & 11 deletions app/simple_agent.py
@@ -1,13 +1,13 @@
from langchain.chains import LLMChain
from langchain.chains.conversation.memory import ConversationBufferMemory
from langchain_core.output_parsers import StrOutputParser
from langchain.prompts import PromptTemplate
from langchain.schema.runnable import Runnable
from langchain_openai import ChatOpenAI
from langchain.chains import ConversationChain
from langchain_core.prompts import PromptTemplate
import os

def simple_agent_chain() -> Runnable:

MEMORY = ConversationBufferMemory(memory_key="agent_history", input_key='question', output_key='text', return_messages=True)
def simple_agent_chain() -> Runnable:

final_prompt = """You are a helpful question-answering agent. Your task is to analyze
and synthesize information from two sources: the top result from a similarity search
@@ -19,14 +19,15 @@ def simple_agent_chain() -> Runnable:
Structured information: {graph_result}.
"""

prompt = PromptTemplate.from_template(final_prompt)
prompt = PromptTemplate(
input_variables=["question", "vector_result", "graph_result"],
template=final_prompt,
)

OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
LLM = ChatOpenAI(temperature=0, openai_api_key=OPENAI_API_KEY)
output_parser = StrOutputParser()

simple_agent_chain = prompt | LLM | output_parser

simple_agent_chain = LLMChain(
prompt=prompt,
llm=LLM,
memory = MEMORY)

return simple_agent_chain
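
With the deprecated `LLMChain` replaced by an LCEL pipeline (`prompt | LLM | output_parser`), the chain now returns a plain string rather than a dict. A minimal usage sketch, assuming `OPENAI_API_KEY` is set and with illustrative inputs:

```
from app.simple_agent import simple_agent_chain

# Illustrative inputs matching the prompt's variables
answer = simple_agent_chain().invoke(
    {
        "question": "What are the main risk factors?",
        "vector_result": "Top similarity-search passage goes here.",
        "graph_result": "Structured Cypher query result goes here.",
    }
)
print(answer)  # plain string, thanks to StrOutputParser
```
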
62 changes: 21 additions & 41 deletions app/vector_chain.py
@@ -1,6 +1,6 @@
from langchain.prompts.prompt import PromptTemplate
from langchain.vectorstores.neo4j_vector import Neo4jVector
from langchain.chains import RetrievalQAWithSourcesChain
from langchain_community.vectorstores import Neo4jVector
from langchain.chains import RetrievalQAWithSourcesChain, RetrievalQA
from langchain.schema.runnable import Runnable
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
import logging
@@ -28,9 +28,10 @@
Assistant:"""

VECTOR_PROMPT = PromptTemplate(
input_variables=["input","context"], template=VECTOR_PROMPT_TEMPLATE
input_variables=["input", "context"], template=VECTOR_PROMPT_TEMPLATE
)


def vector_chain() -> Runnable:

NEO4J_URI = os.getenv("NEO4J_URI")
@@ -49,47 +50,26 @@ def vector_chain() -> Runnable:

# Neo4jVector API: https://api.python.langchain.com/en/latest/vectorstores/langchain_community.vectorstores.neo4j_vector.Neo4jVector.html#langchain_community.vectorstores.neo4j_vector.Neo4jVector

try:
logging.debug(f'Attempting to retrieve existing vector index: {index_name}...')
vector_store = Neo4jVector.from_existing_index(
embedding=EMBEDDINGS,
url=NEO4J_URI,
username=NEO4J_USERNAME,
password=NEO4J_PASSWORD,
database=NEO4J_DATABASE,
index_name=index_name,
embedding_node_property=node_property_name,
)
logging.debug(f'Using existing index: {index_name}')
except:
logging.debug(f'No existing index found. Attempting to create a new vector index named {index_name}...')
try:
vector_store = Neo4jVector.from_existing_graph(
embedding=EMBEDDINGS,
url=NEO4J_URI,
username=NEO4J_USERNAME,
password=NEO4J_PASSWORD,
database=NEO4J_DATABASE,
index_name=index_name,
node_label="Chunk",
text_node_properties=["text"],
embedding_node_property=node_property_name,
)
logging.debug(f'Created new index: {index_name}')
except Exception as e:
logging.error(f'Failed to retrieve existing or to create a Neo4jVector: {e}')

if vector_store is None:
logging.error(f'Failed to retrieve or create a Neo4jVector. Exiting.')
exit()
# try:
logging.debug(
f"Attempting to retrieve existing vector index'{index_name}' from Neo4j instance at {NEO4J_URI}..."
)
vector_store = Neo4jVector.from_existing_index(
embedding=EMBEDDINGS,
url=NEO4J_URI,
username=NEO4J_USERNAME,
password=NEO4J_PASSWORD,
database=NEO4J_DATABASE,
index_name=index_name,
embedding_node_property=node_property_name,
)
logging.debug(f"Using existing index: {index_name}")

vector_retriever = vector_store.as_retriever()

vector_chain = RetrievalQAWithSourcesChain.from_chain_type(
vector_chain = RetrievalQA.from_chain_type(
LLM,
chain_type="stuff",
chain_type="stuff",
retriever=vector_retriever,
reduce_k_below_max_tokens = True,
max_tokens_limit=2000
)
return vector_chain
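
A minimal sketch of calling the simplified vector chain directly (mirroring the invocation in `app/server.py`, and assuming the environment variables from the README are set); `RetrievalQA` expects a `query` input and returns a dict with a `result` key:

```
from app.vector_chain import vector_chain

chain = vector_chain()
# Example question against the hosted EDGAR sample dataset
output = chain.invoke({"query": "What are the company's main products?"})
print(output["result"])
```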