enhance multihop dataset accuracy (#62)
* add rerank.

* update gaudi base docker image.

---------

Co-authored-by: changwangss <[email protected]>
Co-authored-by: root <[email protected]>
3 people authored Aug 21, 2024
1 parent 7b719de commit dfc2c1e
Showing 3 changed files with 25 additions and 3 deletions.
4 changes: 2 additions & 2 deletions docker/hpu.dockerfile
@@ -1,4 +1,4 @@
-FROM vault.habana.ai/gaudi-docker/1.13.0/ubuntu22.04/habanalabs/pytorch-installer-2.1.0:latest as hpu
+FROM vault.habana.ai/gaudi-docker/1.16.1/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest as hpu

ENV LANG=en_US.UTF-8
ENV PYTHONPATH=/root:/usr/lib/habanalabs/
@@ -24,4 +24,4 @@ RUN cd /GenAIEval && \
pip install --upgrade-strategy eager optimum[habana] && \
pip list

-WORKDIR /GenAIEval/
+WORKDIR /GenAIEval/
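
For reviewers who want to confirm the bumped 1.16.1 Gaudi base image still provides a working HPU PyTorch stack, here is a minimal, hypothetical sanity check (not part of the commit). It assumes the habana_frameworks PyTorch bridge shipped in the image exposes habana_frameworks.torch.hpu.is_available() and device_count(); module paths can vary between SynapseAI releases.

# sanity_check_hpu.py - illustrative only, not part of this commit.
# Assumes the Gaudi base image's habana_frameworks PyTorch bridge is importable;
# the exact module layout may differ across SynapseAI releases.
import torch
import habana_frameworks.torch.hpu as hthpu

if hthpu.is_available():
    print(f"HPU available, device count: {hthpu.device_count()}")
    print(f"torch version: {torch.__version__}")  # should match the 2.2.2 installer in the new image
else:
    print("No HPU visible; check that the container was started with the Habana runtime.")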
22 changes: 22 additions & 0 deletions evals/evaluation/rag_eval/examples/eval_multihop.py
@@ -39,6 +39,22 @@ def get_document(self, data: dict):
        )
        return document

+    def get_reranked_documents(self, query, docs, arguments):
+        data = {
+            "initial_query": query,
+            "retrieved_docs": [{"text": doc} for doc in docs],
+            "top_n": 10,
+        }
+        headers = {"Content-Type": "application/json"}
+
+        response = requests.post(arguments.reranking_endpoint, data=json.dumps(data), headers=headers)
+        if response.ok:
+            reranked_documents = response.json()["documents"]
+            return reranked_documents
+        else:
+            print(f"Request for reranking failed due to {response.text}.")
+            return []
+
    def get_retrieved_documents(self, query, arguments):
        data = {"text": query}
        headers = {"Content-Type": "application/json"}
@@ -77,6 +93,8 @@ def get_retrieval_metrics(self, all_queries, arguments):
                continue
            query = data["query"]
            retrieved_documents = self.get_retrieved_documents(query, arguments)
+            if arguments.rerank:
+                retrieved_documents = self.get_reranked_documents(query, retrieved_documents, arguments)
            golden_context = [each["fact"] for each in data["evidence_list"]]
            test_case = {
                "input": query,
@@ -212,6 +230,10 @@ def args_parser():
    parser.add_argument(
        "--retrieval_endpoint", type=str, default="http://localhost:7000/v1/retrieval", help="Service URL address."
    )
parser.add_argument("--rerank", action="store_true", help="Whether to use rerank microservice.")
parser.add_argument(
"--reranking_endpoint", type=str, default="http://localhost:8000/v1/reranking", help="Service URL address."
)
parser.add_argument("--llm_endpoint", type=str, default=None, help="Service URL address.")
parser.add_argument(
"--show_progress_bar", action="store", default=True, type=bool, help="Whether to show a progress bar"
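
To make the new flow concrete, below is a small, hypothetical sketch (not part of the commit) of calling the reranking microservice the way get_reranked_documents does. The payload shape, the "documents" response field, and the endpoint default are taken from the diff above; the sample query, the sample documents, and the rerank() helper name are invented for illustration.

# rerank_sketch.py - illustrative only; mirrors the request made by get_reranked_documents.
import json

import requests

RERANKING_ENDPOINT = "http://localhost:8000/v1/reranking"  # default of --reranking_endpoint


def rerank(query, docs, top_n=10):
    payload = {
        "initial_query": query,
        "retrieved_docs": [{"text": doc} for doc in docs],
        "top_n": top_n,
    }
    headers = {"Content-Type": "application/json"}
    response = requests.post(RERANKING_ENDPOINT, data=json.dumps(payload), headers=headers)
    response.raise_for_status()
    # The microservice is expected to return the reranked passages under "documents".
    return response.json()["documents"]


if __name__ == "__main__":
    # Invented example data; in eval_multihop.py the docs come from the retrieval microservice.
    docs = ["Passage about topic A.", "Passage about topic B.", "Unrelated passage."]
    print(rerank("example multi-hop question", docs))

With the new arguments, running the evaluation script with --rerank (and --reranking_endpoint if the service is not at the default address) routes every retrieved document list through this reranking step before the retrieval metrics are computed.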
2 changes: 1 addition & 1 deletion evals/metrics/ragas/ragas.py
@@ -31,7 +31,7 @@ def __init__(
        self.model = model
        self.embeddings = embeddings
        self.metrics = metrics
-        self.validated_list = ["answer_relevancy", "faithfulness"]
+        self.validated_list = ["answer_relevancy", "faithfulness", "answer_correctness"]

    async def a_measure(self, test_case: Dict):
        return self.measure(test_case)
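
For context on the ragas.py change: answer_correctness is now accepted by the wrapper's validated_list. The sketch below (not part of the commit) shows the metric being computed directly with the upstream ragas library; the 0.1.x-style API, the column names, and the reliance on a configured judge LLM (for example via OPENAI_API_KEY) are assumptions and may differ from how RagasMetric drives it internally.

# answer_correctness_sketch.py - illustrative only; assumes a ragas 0.1.x-style API
# and a configured judge LLM/embeddings backend (e.g. OPENAI_API_KEY set).
from datasets import Dataset
from ragas import evaluate
from ragas.metrics import answer_correctness

data = {
    "question": ["What does the multihop eval measure?"],
    "answer": ["It measures retrieval and generation quality on multi-hop questions."],
    "ground_truth": ["It evaluates RAG quality on multi-hop question answering."],
}

# answer_correctness compares the generated answer with the ground truth
# (a blend of factual overlap and semantic similarity in ragas).
result = evaluate(Dataset.from_dict(data), metrics=[answer_correctness])
print(result)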
