diff --git a/evals/evaluation/rag_eval/evaluator.py b/evals/evaluation/rag_eval/evaluator.py
index 8e8632db..85d7af81 100644
--- a/evals/evaluation/rag_eval/evaluator.py
+++ b/evals/evaluation/rag_eval/evaluator.py
@@ -129,6 +129,9 @@ def remove_invalid(self, results: list[dict]) -> list[dict]:
"""Remove invalid results from the list and return the cleaned results."""
return [result for result in results if result["valid"]]
+ def get_template(self):
+ raise NotImplementedError("Depends on the specific dataset.")
+
def send_request(self, data, arguments):
service_url = arguments.service_url
headers = {"Content-Type": "application/json"}
@@ -138,14 +141,18 @@ def send_request(self, data, arguments):
json_data["stream"] = False
json_data["temperature"] = arguments.temperature
json_data["max_new_tokens"] = arguments.max_new_tokens
+ json_data["chat_template"] = self.get_template()
json_data = json.dumps(json_data)
response = requests.post(service_url, data=json_data, headers=headers)
if response.ok:
- return response.json()["choices"][0]["message"]["content"]
+ return self.post_process(response.json()["choices"][0]["message"]["content"])
else:
print(f"Request for pipeline failed due to {response.text}.")
return ""
+ def post_process(self, result):
+ return result
+
def get_retrieved_documents(self, data, arguments):
query = self.get_query(data)
data = {"text": query}
@@ -203,7 +210,7 @@ def evaluate(self, arguments, sort=True, show_progress_bar=False, contain_origin
data["retrieved_documents"] = retrieved_documents
generated_text = self.send_request(data, arguments)
data["generated_text"] = generated_text
- result = {"id": data["ID"], **self.scoring(data, arguments.llm_endpoint)}
+ result = {"id": data["ID"], **self.scoring(data)}
if contain_original_data:
result["original_data"] = data
results.append(result)
diff --git a/evals/evaluation/rag_eval/examples/eval_crud.py b/evals/evaluation/rag_eval/examples/eval_crud.py
index 80e67173..4a4ac8e6 100644
--- a/evals/evaluation/rag_eval/examples/eval_crud.py
+++ b/evals/evaluation/rag_eval/examples/eval_crud.py
@@ -9,6 +9,7 @@
import os
from evals.evaluation.rag_eval import Evaluator
+from evals.evaluation.rag_eval.template import CRUDTemplate
class CRUD_Evaluator(Evaluator):
@@ -60,6 +61,23 @@ def get_document(self, data: dict):
)
return document
+ def get_template(self):
+ if self.task == "summarization":
+ template = CRUDTemplate.get_summarization_template()
+ elif self.task == "question_answering":
+ template = CRUDTemplate.get_question_answering_template()
+ elif self.task == "continuation":
+ template = CRUDTemplate.get_continuation_template()
+ else:
+ raise NotImplementedError(
+ f"Unknown task {self.task}, only support "
+ "summarization, question_answering, continuation and hallucinated_modified."
+ )
+ return template
+
+ def post_process(self, result):
+ return result.split("")[-1].split("")[0].strip()
+
def args_parser():
parser = argparse.ArgumentParser()
@@ -128,7 +146,7 @@ def main():
)
output_save_path = os.path.join(args.output_dir, f"{task}.json")
evaluator = CRUD_Evaluator(
- dataset=dataset, output_save_path=output_save_path, task=task, llm_endpoint=args.llm_endpoint
+ dataset=dataset, output_path=output_save_path, task=task, llm_endpoint=args.llm_endpoint
)
if args.ingest_docs:
CRUD_Evaluator.ingest_docs(args.docs_path, args.database_endpoint, args.chunk_size, args.chunk_overlap)
diff --git a/evals/evaluation/rag_eval/template.py b/evals/evaluation/rag_eval/template.py
new file mode 100644
index 00000000..a06ffd62
--- /dev/null
+++ b/evals/evaluation/rag_eval/template.py
@@ -0,0 +1,59 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+
+class CRUDTemplate:
+ @staticmethod
+ def get_question_answering_template():
+ return """你是一位新闻编辑,现在,你被提供了1个问题,和根据这些问题检索到的文档,请分别检索内容和你自身的知识回答这些问题。以下是个例子:
+
+问题:上海和成都市体育局在促进体育消费和全民健身运动方面有哪些相似和不同的措施?
+
+检索文档: 在第15个全民健身日来临之际,上海市体育局将联合美团、大众点评发放500万元体育消费券,3000多家上海本地运动门店参与其中,共同点燃全民健身运动热情,促进体育消费增长。▲8月5日上午10点,上海市体育局将联合美团、大众点评发放新一轮体育消费券2023年上海体育消费券以“全民优惠健身,共享美好生活”为主题,在8月5日-9月3日期间分四期进行发放。第一期消费券发放时间为8月5日10:00-8月13日24:00,第二期消费券发放时间为8月14日-8月20日,第三期8月21日-8月27日,第四期8月28日-9月3日。实时定位在上海的消费者,可以在发放时间内进入美团、大众点评App,搜索“上海体育消费券”进行领取。为满足消费者更多个性化的需求,本轮体育消费券活动准备了满200减80、满120减50、满60减30、满40减20、满20减10和满10减5共六个面额的消费券,消费者可按需领用,先到先得。每位消费者每期最多可领取3张消费券,且每位消费者同时最多可持有3张。据“上海体育”公众号介绍,本次体育消费券适用场景多、覆盖范围广、优惠力度大。在发布会上,成都市体育局副局长陈志介绍,以成都大运会筹办举办为契机,成都积极开展“爱成都·迎大运”“运动成都·悦动生活”“万千商家齐参与”等主题体育消费促进活动,发放各类体育消费券和惠民运动券,促进体育消费持续稳步增长。2022年成都体育消费总规模为578.6亿元,居民人均体育消费为2720.6元。 ▲8月4日,成都大运会体操项目女子个人全能决赛看台上,观众为比赛队员加油 资料配图 摄影 陶轲 为持续激发体育消费活力和增长潜力,下一步,成都将持续深化体育消费试点工作,积极推进体育消费提质扩容。启动户外运动季活动,发布十大最受欢迎时尚运动消费场景。 具体而言,陈志介绍说,成都将加快推动“体育+会展+消费”平台建设,办好中国(成都)生活体育大会、“巴山蜀水·运动川渝”体育旅游休闲消费季、世界赛事名城发展大会、中国国际体育用品博览会等重大体育展会活动,为城市体育消费增长提供更多资源链接。
+
+回答:上海市体育局联合美团、大众点评发放了总额500万元的体育消费券,覆盖3000多家本地运动门店,并设置了不同面额的消费券供消费者领取。而成都市体育局则是利用成都大运会的契机发放各类体育消费券和惠民运动券,同时计划通过举办大型体育展会活动和推动“体育+会展+消费”平台建设来进一步促进体育消费的提质扩容。
+
+问题:{question}
+
+检索到的文档:{context}
+
+请给出你的回答(回答的文本写在之间。
+"""
+
+ @staticmethod
+ def get_summarization_template():
+ return """你是一名新闻工作者。我希望你能根据新闻事件,以及检索到的有关该事件的报告,生成这个新闻事件的摘要。摘要的格式:
+
+
+福建省防指于2023年7月26日召开视频会议,郭宁宁总指挥再次强调了防御超强台风“杜苏芮”的重要性,并进行了再部署和再落实工作。会议要求各级各部门要紧密合作,全力保障人民生命财产安全。同时,国家防总办公室也组织了防汛防台风专题视频会商调度,与相关部门共同研判台风“杜苏芮”的发展态势,并调度部署重点地区的防汛防台风工作。根据会商研判,台风“杜苏芮”将于28日早晨到上午在福建福清到广东惠来一带沿海登陆,具有风浪大、降雨强度大、影响范围广等特点,可能影响多个省市。防汛防台风形势严峻,任务艰巨。
+
+
+现在新闻事件是:
+
+{question}
+
+现在我检索到的文档是:
+
+{context}
+
+请你完成要该事件的摘要(摘要的文本写在之间):
+"""
+
+ @staticmethod
+ def get_continuation_template():
+ return """你是一名新华社新闻工作者。我希望你能辅助我完成一篇新闻的撰写。
+
+请你根据我已经写好的文本,和检索到的文档,为我续写一段话。
+
+请注意,续写文本的长度与已经写好的文本长度大致相当,续写的文本不要出现已经写好的文本的内容!续写的文本要和已经写好的文本具有连贯性!
+
+现在我检索到的文档是:
+
+{context}
+
+现在我已经写好的文本是:
+
+{question}
+
+续写文本:
+"""
diff --git a/evals/metrics/utils.py b/evals/metrics/utils.py
index 42711047..4af93caa 100644
--- a/evals/metrics/utils.py
+++ b/evals/metrics/utils.py
@@ -6,6 +6,7 @@
from typing import Any, List, Optional, Tuple, Union
import evaluate
+import jieba
from pydantic import BaseModel
@@ -84,10 +85,13 @@ def wrapper(*args, **kwargs):
return wrapper
+tokenizer = lambda text: list(jieba.cut(text))
+
+
@catch_all_exceptions
def bleu_score(continuation: str, reference: str, with_penalty=False) -> float:
bleu = evaluate.load(os.path.join(os.path.dirname(__file__), "bleu"))
- results = bleu.compute(predictions=[continuation], references=[[reference]])
+ results = bleu.compute(predictions=[continuation], references=[[reference]], tokenizer=tokenizer)
bleu_avg = results["bleu"]
bleu1 = results["precisions"][0]
@@ -105,6 +109,8 @@ def bleu_score(continuation: str, reference: str, with_penalty=False) -> float:
@catch_all_exceptions
def rougeL_score(continuation: str, reference: str) -> float:
rouge = evaluate.load(os.path.join(os.path.dirname(__file__), "rouge"))
- results = rouge.compute(predictions=[continuation], references=[[reference]], rouge_types=["rougeL"])
+ results = rouge.compute(
+ predictions=[continuation], references=[[reference]], tokenizer=tokenizer, rouge_types=["rougeL"]
+ )
score = results["rougeL"]
return score
diff --git a/requirements.txt b/requirements.txt
index 637f8a96..b9b71e0a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,6 @@
bigcode-eval@git+https://github.com/bigcode-project/bigcode-evaluation-harness.git@e5c2f31625223431d7987f43b70b75b9d26ba118
evaluate
+jieba
langchain_community
langchain_huggingface
lm-eval==0.4.3