Skip to content

Commit

Permalink
Adding new metrics to ragas offering (#142)
Browse files Browse the repository at this point in the history
* minimized required fields/columns in user data

Signed-off-by: aasavari <[email protected]>
  • Loading branch information
adkakne authored Oct 8, 2024
1 parent 4af0a62 commit d1c1337
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 44 deletions.
107 changes: 64 additions & 43 deletions evals/metrics/ragas/ragas.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,16 @@
# SPDX-License-Identifier: Apache-2.0
#
import os
import re
from typing import Dict, Optional, Union

from langchain_core.embeddings import Embeddings
from langchain_core.language_models import BaseLanguageModel
from langchain_huggingface import HuggingFaceEndpoint

# A wildcard import (`import *`) is only permitted at module level in Python.
from ragas.metrics import *


def format_ragas_metric_name(name: str):
    """Build the display label for a metric by appending " (ragas)".

    Args:
        name: Metric name to label.

    Returns:
        The given name followed by a space and the literal "(ragas)".
    """
    return "{} (ragas)".format(name)
Expand All @@ -29,16 +33,17 @@ def __init__(
self.model = model
self.embeddings = embeddings
self.metrics = metrics
self.validated_list = [
"answer_correctness",
"answer_relevancy",
"answer_similarity",
"context_precision",
"context_recall",
"faithfulness",
"context_utilization",
# "reference_free_rubrics_score",
]

# self.validated_list = [
# "answer_correctness",
# "answer_relevancy",
# "answer_similarity",
# "context_precision",
# "context_recall",
# "faithfulness",
# "context_utilization",
# # "reference_free_rubrics_score",
# ]

async def a_measure(self, test_case: Dict):
    """Coroutine wrapper around :meth:`measure`.

    The call to ``self.measure`` is made directly (it is not awaited), and
    whatever it returns is handed back unchanged.

    Args:
        test_case: Mapping passed straight through to ``self.measure``.

    Returns:
        The value returned by ``self.measure(test_case)``.
    """
    outcome = self.measure(test_case)
    return outcome
Expand All @@ -47,37 +52,51 @@ def measure(self, test_case: Dict):
# sends to server
try:
from ragas import evaluate
from ragas.metrics import (
answer_correctness,
answer_relevancy,
answer_similarity,
context_precision,
context_recall,
context_utilization,
faithfulness,
)
from ragas.metrics import ALL_METRICS

self.metric_names = [metric.__class__.__name__ for metric in ALL_METRICS]
self.metric_names = [re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower() for name in self.metric_names]
self.metric_names = list(set(self.metric_names))
# Note - the summarization score metric does not work with the best open-source LLMs,
# which is why it is removed from our offering for now.
self.metric_names.remove("summarization_score")
self.metric_instances = {}
for metric in self.metric_names:
try:
self.metric_instances[metric] = eval(metric)
except:
pass
# from ragas.metrics import (
# answer_correctness,
# answer_relevancy,
# answer_similarity,
# context_precision,
# context_recall,
# context_utilization,
# faithfulness,
# )
except ModuleNotFoundError:
raise ModuleNotFoundError("Please install ragas to use this metric. `pip install ragas`.")
try:
from datasets import Dataset
except ModuleNotFoundError:
raise ModuleNotFoundError("Please install dataset")
self.metrics_instance = {
"answer_correctness": answer_correctness,
"answer_relevancy": answer_relevancy,
"answer_similarity": answer_similarity,
"context_precision": context_precision,
"context_recall": context_recall,
"faithfulness": faithfulness,
"context_utilization": context_utilization,
# "reference_free_rubrics_score": reference_free_rubrics_score,
}
# self.metrics_instance = {
# "answer_correctness": answer_correctness,
# "answer_relevancy": answer_relevancy,
# "answer_similarity": answer_similarity,
# "context_precision": context_precision,
# "context_recall": context_recall,
# "faithfulness": faithfulness,
# "context_utilization": context_utilization,
# # "reference_free_rubrics_score": reference_free_rubrics_score,
# }
# Set LLM model
openai_key = os.getenv("OPENAI_API_KEY", None)
if openai_key is not None:
print("OPENAI_API_KEY is provided, ragas initializes the model by OpenAI.")
self.model = None
if isinstance(self.model, str):
self.chat_model = None
elif isinstance(self.model, str):
print("LLM endpoint: ", self.model)
self.chat_model = HuggingFaceEndpoint(
endpoint_url=self.model,
Expand All @@ -92,36 +111,38 @@ def measure(self, test_case: Dict):
tmp_metrics = []
# check supported list
for metric in self.metrics:
if metric not in self.validated_list:
if metric not in self.metric_names:
raise ValueError(
"metric should be in supported list {}. ".format(self.validated_list)
"metric should be in supported list {}. ".format(self.metric_names)
+ "ClientResponseError raised with LangchainLLM "
+ "when context_precision, context_recall ran. "
+ "Here are the related issues described in ragas "
"https://github.com/explodinggradients/ragas/issues/934, "
+ "https://github.com/explodinggradients/ragas/issues/664."
)
else:
if metric == "answer_relevancy" and self.embeddings is None:
raise ValueError("answer_relevancy metric need provide embeddings model.")
if metric == "AnswerRelevancy" and self.embeddings is None:
raise ValueError("AnswerRelevancy metric need provide embeddings model.")
tmp_metrics.append(self.metrics_instance[metric])
self.metrics = tmp_metrics
else:
self.metrics = [
answer_relevancy,
faithfulness,
answer_correctness,
answer_similarity,
context_precision,
context_recall,
]
self.metrics = list(self.metric_instances.values())
# self.metrics = [
# answer_relevancy,
# faithfulness,
# answer_correctness,
# answer_similarity,
# context_precision,
# context_recall,
# ]
# Find necessary input fields using the given metrics
_required_columns = set()
column_map = { # this column maps new naming style in ragas to their old naming style
"user_input": "question",
"response": "answer",
"reference": "ground_truth",
"retrieved_contexts": "contexts",
"reference_contexts": "reference_contexts",
}
for metric in self.metrics:
if hasattr(metric, "_required_columns"):
Expand Down
2 changes: 1 addition & 1 deletion tests/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@ jieba
langchain_community
langchain_huggingface
lm-eval==0.4.3
ragas
ragas==0.1.19
3 changes: 3 additions & 0 deletions tests/test_ragas.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,18 @@ def test_ragas(self):

# Replace this with the actual retrieved context from your RAG pipeline
retrieval_context = ["All customers are eligible for a 30 day full refund at no extra cost."]
reference_context = ["We can only process full refund upto 30 day after the purchase."]
from langchain_community.embeddings import HuggingFaceBgeEmbeddings

embeddings = HuggingFaceBgeEmbeddings(model_name="BAAI/bge-base-en-v1.5")

metric = RagasMetric(threshold=0.5, model=f"http://{host_ip}:{port}", embeddings=embeddings)
test_case = {
"question": ["What if these shoes don't fit?"],
"answer": [actual_output],
"ground_truth": [expected_output],
"contexts": [retrieval_context],
"reference_contexts": [reference_context],
}

metric.measure(test_case)
Expand Down

0 comments on commit d1c1337

Please sign in to comment.