[DO NOT MERGE] Implemented double LLM call to significantly reduce the number of false positives #212

Open
wants to merge 2 commits into Dev
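In brief: after the first model drafts an answer, a second `ChatMistralAI` instance (running under its own API key) re-reads the retrieved context and answers True/False on whether the draft is actually supported; unsupported low-confidence answers are replaced with the standard fallback. A minimal sketch of the added flow, using the names defined in RAG.py below:

```python
# Sketch of the verification pass this PR adds (names match the diff below)
if score < 0.45:  # retrieval score too low to trust the draft outright
    verifier = ChatMistralAI(model='open-mistral-7b', temperature=0,
                             mistral_api_key=MISTRAL_SECOND_API_KEY)
    if not is_answer_relevant(verifier, response, retrieved_documents):
        response = "I don't have enough information to answer this question."
        source = "Unknown"
```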
74 changes: 69 additions & 5 deletions RAG.py
@@ -25,6 +25,7 @@

load_dotenv()
MISTRAL_API_KEY = os.environ.get("MISTRAL_API_KEY")
MISTRAL_SECOND_API_KEY = os.environ.get("MISTRAL_SECOND_API_KEY")
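# (Reviewer note, assumption: a separate API key lets the verification call run
# under its own rate limit; both keys are expected to be set in .env)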
MILVUS_URI = "/app/milvus/milvus_vector.db"
MODEL_NAME = "sentence-transformers/all-MiniLM-L12-v2"
MAX_TEXT_LENGTH = 5000
@@ -42,6 +43,57 @@ def get_embedding_model():
        EMBEDDING_MODEL = SentenceTransformer(MODEL_NAME)
    return EMBEDDING_MODEL

def recheck_prompt():
    """
    Create the prompt template for the answer-verification (recheck) step
    Returns:
        PromptTemplate: The prompt template used to verify an answer against its context
    """
    # Define the prompt template
    PROMPT_TEMPLATE = """
    You are an AI assistant tasked with verifying if an answer is supported by the given context. Follow these steps:
    - Check if the answer in the <answer> tags is directly supported by the context in <context> tags.
    - If the answer is not supported by the context or if the context lacks relevant information, respond with "False".
    - If the answer is fully supported by the context, respond with "True".
    - Do not assume or infer any information not explicitly mentioned in the context.
    - Respond with only "True" or "False". Do not provide any additional explanations or comments.
    """

    prompt = ChatPromptTemplate.from_messages([
        ("system", PROMPT_TEMPLATE),
        ("human", "<answer>{input}</answer>\n\n<context>{context}</context>"),
    ])
    print("Rechecking prompt created")

    return prompt

def is_answer_relevant(model, answer, context):
    """
    Call the Mistral LLM to check if the answer is based on the given context.
    Args:
        model (ChatMistralAI): The chat model used for the recheck call.
        answer (str): The generated answer to verify.
        context (str): The context used to generate the answer.
    Returns:
        bool: True if the LLM judges the answer to be supported by the context, False otherwise.
    """
    try:
        prompt = recheck_prompt()

        llm_input = prompt.format_messages(
            input=answer,
            context=context
        )
        recheck_response = model.invoke(llm_input)
        print("Recheck Response: ", recheck_response)
        recheck_output = recheck_response.content

        # The verifier is instructed to reply with only "True" or "False"
        return "true" in recheck_output.lower()
    except HTTPStatusError as e:
        # Fail open: if the recheck call itself errors, keep the original answer
        print(f"HTTPStatusError: {e}")
        return True
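# Usage sketch (illustrative, not part of this diff): with any configured
# ChatMistralAI instance,
#     verifier = ChatMistralAI(model='open-mistral-7b', temperature=0)
#     is_answer_relevant(verifier, "Paris is the capital of France.",
#                        "France's capital city is Paris.")  # -> True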

def is_filtered_query(query):
    patterns = [
        r"\b(hi|hello|hey|hiya|howdy|greetings|yo)\b",  # Common greetings
@@ -80,15 +132,15 @@ def query_rag(query):
return f"Hi there! I'm an AI assistant powered by {CORPUS_SOURCE}. I'm here to help with any questions you might have. How can I assist you today?", "Unknown"

    # Define the model
    chat_model = ChatMistralAI(model='open-mistral-7b', temperature = 0.2)
    chat_model = ChatMistralAI(model='open-mistral-7b', temperature = 0)
    print("Model Loaded")

    # Create the prompt and components for the RAG model
    prompt = create_prompt()
    print("Before embedding model")
    model = get_embedding_model()
    print("After embedding model")
    retriever = ScoreThresholdRetriever(score_threshold=0.2, k=3)
    retriever = ScoreThresholdRetriever(score_threshold=0.15, k=3)
    document_chain = create_stuff_documents_chain(chat_model, prompt)
    query_embedding = np.array(model.encode(query), dtype=np.float32).tolist()
    collection = Collection(re.sub(r'\W+', '', CORPUS_SOURCE))
@@ -103,6 +155,7 @@ def query_rag(query):
    most_relevant_document = retrieved_documents[0]
    source = most_relevant_document.metadata.get("source", "Unknown")
    title = most_relevant_document.metadata.get("title", "Untitled").replace("\n", " ")
    score = most_relevant_document.metadata.get("score", 0)

    print("Most Relevant Document Retrieved")

@@ -112,9 +165,20 @@
"context": retrieved_documents
})

print("First Response Generated", response)

    # Add the source to the response if available
    if isinstance(source, str) and source != "Unknown":
        response += f"\n\nSource: [{title}]({source})"
        if score < 0.45:
            second_chat_model = ChatMistralAI(model='open-mistral-7b', temperature = 0, mistral_api_key = MISTRAL_SECOND_API_KEY)
            result = is_answer_relevant(second_chat_model, response, retrieved_documents)
            if not result:
                response = "I don't have enough information to answer this question."
                source = "Unknown"
            else:
                response += f"\n\nSource: [{title}]({source})"
        else:
            response += f"\n\nSource: [{title}]({source})"
print("Response Generated")

return response, source
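# Illustrative call (assumes the Milvus collection and both Mistral API keys
# are configured; the query string is hypothetical):
#     answer, src = query_rag("What does the corpus say about enrollment?")
#     # src is "Unknown" when the query is filtered or the recheck rejects the answer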
@@ -134,8 +198,8 @@ def create_prompt():
"""
# Define the prompt template
PROMPT_TEMPLATE = """
You are an AI assistant that provides answers strictly based on the provided context. Adhere to these guidelines:
- Only answer questions based on the content within the <context> tags.
You are an AI assistant that provides answers strictly based on the provided context.
- Only answer questions based on the content within the <context> tags.
- If the <context> does not contain information related to the question, respond only with: "I don't have enough information to answer this question."
- For unclear questions or questions that lack specific context, request clarification from the user.
- Provide specific, concise ansewrs. Where relevant information includes statistics or numbers, include them in the response.
Binary file modified milvus/milvus_vector.db
Binary file not shown.