-
Notifications
You must be signed in to change notification settings - Fork 41
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
enable autorag to automatically generate the evaluation dataset and e…
…valuate the RAG system (#36) Signed-off-by: XuhuiRen <[email protected]>
- Loading branch information
Showing
11 changed files
with
858 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
ground_truth_file: ./ground_truth.jsonl | ||
use_openai_key: False | ||
search_type: [similarity, mmr] | ||
k: [1] | ||
fetch_k: [5] | ||
score_threshold: [0.3] | ||
top_n: [1] | ||
temperature: [0.01] | ||
top_k: [1, 3, 5] | ||
top_p: [0.1] | ||
repetition_penalty: [1.0] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
{"question": "What are Nike's primary business activities as of the fiscal year ended May 31, 2023?", "context": ["Our principal business activity is the design, development and worldwide marketing and selling of athletic footwear, apparel, equipment, accessories and services."], "ground_truth": "Nike's primary business activities include the design, development, worldwide marketing, and selling of athletic footwear, apparel, equipment, accessories, and services."} | ||
{"question": "How does Nike categorize its product offerings?", "context": ["Our NIKE Brand product offerings are aligned around our consumer construct focused on Men's, Women's and Kids'. We also design products specifically for the Jordan Brand and Converse."], "ground_truth": "Nike categorizes its product offerings around consumer constructs focused on Men's, Women's, and Kids'. They also design products specifically for the Jordan Brand and Converse."} | ||
{"question": "What was Nike's total revenue from non-U.S. operations for fiscal year 2023?", "context": ["For fiscal 2023, non-U.S. NIKE Brand and Converse sales accounted for approximately 57% of total revenues."], "ground_truth": "For fiscal year 2023, non-U.S. operations accounted for approximately 57% of Nike's total revenues."} | ||
{"question": "How does Nike ensure the innovation and quality of its products?", "context": ["We place considerable emphasis on innovation and high-quality construction in the development and manufacturing of our products."], "ground_truth": "Nike emphasizes technical innovation and high-quality construction in the development and manufacturing of its products. They employ specialists in various fields and utilize research committees and advisory boards comprising athletes and other experts."} | ||
{"question": "What are the risks associated with Nike's international operations?", "context": ["Our international operations and sources of supply are subject to the usual risks of doing business abroad, such as the implementation of, or potential changes in, foreign and domestic trade policies."], "ground_truth": "Nike's international operations are subject to risks such as changes in foreign and domestic trade policies, increases in import duties, and political and economic instability, among others."} | ||
{"question": "How does Nike view the role of intellectual property in its business strategy?", "context": ["We believe that our intellectual property rights are important to our brand, our success and our competitive position."], "ground_truth": "Nike considers its intellectual property rights critical to its brand, success, and competitive position. They actively pursue protection of these rights and vigorously defend them against third-party infringement."} | ||
{"question": "What is Nike's approach to diversity, equity, and inclusion within its workforce?", "context": ["Diversity, equity and inclusion ('DE&I') is a strategic priority for NIKE and we are committed to having an increasingly diverse team and culture."], "ground_truth": "Nike prioritizes fostering an inclusive and accessible workplace, aiming to expand representation across all dimensions of diversity. They have specific goals for increasing representation among women globally and U.S. racial and ethnic minorities by fiscal 2025."} | ||
{"question": "How does Nike address the environmental impact of its operations?", "context": ["Our mission is aligned with our deep commitment to maintaining an environment where all NIKE employees have the opportunity to reach their full potential."], "ground_truth": "Nike is focused on sustainability, aiming to create products more sustainably, such as through using environmentally friendly materials and processes, and investing in global communities to promote a more equitable future."} | ||
{"question": "What financial impact did Nike's U.S. operations have in fiscal year 2023?", "context": ["For fiscal 2023, NIKE Brand and Converse sales in the United States accounted for approximately 43% of total revenues."], "ground_truth": "Nike Brand and Converse sales in the United States accounted for approximately 43% of total revenues for fiscal 2023."} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
|
||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
set -x | ||
|
||
function main { | ||
|
||
init_params "$@" | ||
run_benchmark | ||
|
||
} | ||
|
||
# init params | ||
function init_params { | ||
search_type="similarity" | ||
k=1 | ||
fetch_k=5 | ||
score_threshold=0.3 | ||
top_n=1 | ||
max_chuck_size=256 | ||
temperature=0.01 | ||
top_k=1 | ||
top_p=0.1 | ||
repetition_penalty=1.0 | ||
|
||
for var in "$@" | ||
do | ||
case $var in | ||
--ground_truth_file=*) | ||
ground_truth_file=$(echo $var |cut -f2 -d=) | ||
;; | ||
--use_openai_key=*) | ||
use_openai_key=$(echo $var |cut -f2 -d=) | ||
;; | ||
--search_type=*) | ||
search_type=$(echo $var |cut -f2 -d=) | ||
;; | ||
--k=*) | ||
k=$(echo $var |cut -f2 -d=) | ||
;; | ||
--fetch_k=*) | ||
fetch_k=$(echo $var |cut -f2 -d=) | ||
;; | ||
--score_threshold=*) | ||
score_threshold=$(echo ${var} |cut -f2 -d=) | ||
;; | ||
--top_n=*) | ||
top_n=$(echo ${var} |cut -f2 -d=) | ||
;; | ||
--temperature=*) | ||
temperature=$(echo $var |cut -f2 -d=) | ||
;; | ||
--top_k=*) | ||
top_k=$(echo $var |cut -f2 -d=) | ||
;; | ||
--top_p=*) | ||
top_p=$(echo $var |cut -f2 -d=) | ||
;; | ||
--repetition_penalty=*) | ||
repetition_penalty=$(echo ${var} |cut -f2 -d=) | ||
;; | ||
esac | ||
done | ||
|
||
} | ||
|
||
# run_benchmark | ||
function run_benchmark { | ||
|
||
if [[ ${use_openai_key} == True ]]; then | ||
use_openai_key="--use_openai_key" | ||
else | ||
use_openai_key="" | ||
fi | ||
|
||
python -u ../evaluation/autorag/evaluation/ragas_evaluation_benchmark.py \ | ||
--ground_truth_file ${ground_truth_file} \ | ||
--input_path ${input_path} \ | ||
--use_openai_key ${use_openai_key} \ | ||
--search_type ${search_type} \ | ||
--k ${k} \ | ||
--fetch_k ${fetch_k} \ | ||
--score_threshold ${score_threshold} \ | ||
--top_n ${top_n} \ | ||
--temperature ${temperature} \ | ||
--top_k ${top_k} \ | ||
--top_p ${top_p} \ | ||
--repetition_penalty ${repetition_penalty} | ||
} | ||
|
||
main "$@" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import argparse | ||
import os | ||
import subprocess | ||
|
||
import jsonlines | ||
import yaml | ||
|
||
|
||
def read_yaml_file(file_path): | ||
with open(file_path, "r") as stream: | ||
try: | ||
return yaml.safe_load(stream) | ||
except yaml.YAMLError as exc: | ||
print(exc) | ||
|
||
|
||
if __name__ == "__main__": | ||
if os.path.exists("result_ragas.jsonl"): | ||
os.remove("result_ragas.jsonl") | ||
script_path = "ragas_benchmark.sh" | ||
|
||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--config", type=str) | ||
args = parser.parse_args() | ||
|
||
data = read_yaml_file(args.config) | ||
data = {k: [str(item) for item in v] if isinstance(v, list) else str(v) for k, v in data.items()} | ||
|
||
ground_truth_file = data["ground_truth_file"] | ||
use_openai_key = data["use_openai_key"] | ||
search_types = data["search_type"] | ||
ks = data["k"] | ||
fetch_ks = data["fetch_k"] | ||
score_thresholds = data["score_threshold"] | ||
top_ns = data["top_n"] | ||
temperatures = data["temperature"] | ||
top_ks = data["top_k"] | ||
top_ps = data["top_p"] | ||
repetition_penaltys = data["repetition_penalty"] | ||
|
||
for search_type in search_types: | ||
for k in ks: | ||
for fetch_k in fetch_ks: | ||
for score_threshold in score_thresholds: | ||
for top_n in top_ns: | ||
for temperature in temperatures: | ||
for top_k in top_ks: | ||
for top_p in top_ps: | ||
for repetition_penalty in repetition_penaltys: | ||
subprocess.run( | ||
[ | ||
"bash", | ||
script_path, | ||
"--ground_truth_file=" + ground_truth_file, | ||
"--use_openai_key=" + str(use_openai_key), | ||
"--search_type=" + search_type, | ||
"--k=" + k, | ||
"--fetch_k=" + fetch_k, | ||
"--score_threshold=" + score_threshold, | ||
"--top_n=" + top_n, | ||
"--temperature=" + temperature, | ||
"--top_k=" + top_k, | ||
"--top_p=" + top_p, | ||
"--repetition_penalty=" + repetition_penalty, | ||
], | ||
stdout=subprocess.DEVNULL, | ||
stderr=subprocess.DEVNULL, | ||
) |
62 changes: 62 additions & 0 deletions
62
evals/evaluation/autorag/data_generation/gen_answer_dataset.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import logging | ||
import re | ||
|
||
import jsonlines | ||
import torch | ||
from modelscope import AutoModelForCausalLM, AutoTokenizer # pylint: disable=E0401 | ||
|
||
from .prompt_dict import TRUTHGENERATE_PROMPT | ||
|
||
|
||
def load_documents(document_file_jsonl_path): | ||
document_list = [] | ||
with open(document_file_jsonl_path) as file: | ||
for stu in jsonlines.Reader(file): | ||
passages = [stu["query"], stu["pos"][0]] | ||
document_list.append(passages) | ||
return document_list | ||
|
||
|
||
def answer_generate(llm, base_dir, file_json_path, generation_config): | ||
documents = load_documents(base_dir) | ||
|
||
try: | ||
if isinstance(llm, str): | ||
use_endpoint = False | ||
tokenizer = AutoTokenizer.from_pretrained(llm) | ||
llm = AutoModelForCausalLM.from_pretrained(llm, device_map="auto", torch_dtype=torch.float16) | ||
llm.eval() | ||
else: | ||
use_endpoint = True | ||
llm = llm | ||
except: | ||
print("Please check the setting llm!") | ||
|
||
for question, context in enumerate(documents): | ||
if context and question: | ||
prompt = TRUTHGENERATE_PROMPT.format(question=question, context=context) | ||
if not use_endpoint: | ||
with torch.no_grad(): | ||
model_input = tokenizer(prompt, return_tensors="pt") | ||
res = llm.generate(**model_input, generation_config=generation_config)[0] | ||
res = tokenizer.decode(res, skip_special_tokens=True) | ||
else: | ||
res = llm.invoke(prompt) | ||
|
||
res = res[res.find("Generated ground_truth:") :] | ||
res = re.sub("Generated ground_truth:", "", res) | ||
res = re.sub("---", "", res) | ||
|
||
result_str = res.replace("#", " ").replace(r"\t", " ").replace("\n", " ").replace("\n\n", " ").strip() | ||
|
||
if result_str and not result_str.isspace(): | ||
data = { | ||
"question": question, | ||
"context": [context], | ||
"ground_truth": result_str, | ||
} | ||
with jsonlines.open(file_json_path, "a") as file_json: | ||
file_json.write(data) |
105 changes: 105 additions & 0 deletions
105
evals/evaluation/autorag/data_generation/gen_eval_dataset.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import argparse | ||
import os | ||
|
||
from comps.dataprep.utils import document_loader | ||
from langchain_community.llms import HuggingFaceEndpoint | ||
from sentence_transformers import SentenceTransformer | ||
from transformers import GenerationConfig | ||
|
||
from .gen_answer_dataset import answer_generate | ||
from .gen_hard_negative import mine_hard_negatives | ||
from .llm_generate_raw_data import raw_data_generation | ||
from .utils import similarity_check | ||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--llm", type=str) | ||
parser.add_argument("--embedding_model", type=str) | ||
parser.add_argument("--input", type=str) | ||
parser.add_argument("--output", type=str, default="./data") | ||
|
||
parser.add_argument("--temperature", type=float, default=0.8) | ||
parser.add_argument("--top_p", type=float, default=0.9) | ||
parser.add_argument("--top_k", type=int, default=40) | ||
parser.add_argument("--repetition_penalty", type=float, default=2.0) | ||
parser.add_argument("--max_new_tokens", type=int, default=48) | ||
parser.add_argument("--do_sample", type=bool, default=True) | ||
parser.add_argument("--num_beams", type=int, default=2) | ||
parser.add_argument("--num_return_sequences", type=int, default=2) | ||
parser.add_argument("--use_cache", type=bool, default=True) | ||
|
||
parser.add_argument("--range_for_sampling", type=str, default="2-10") | ||
parser.add_argument("--negative_number", type=int, default=5) | ||
parser.add_argument("--use_gpu_for_searching", type=bool, default=False) | ||
|
||
parser.add_argument("--similarity_threshold", type=float, default=0.6) | ||
|
||
args = parser.parse_args() | ||
|
||
llm_model = args.llm | ||
input_path = args.input | ||
output = args.output | ||
|
||
generation_config = GenerationConfig( | ||
temperature=args.temperature, | ||
top_p=args.top_p, | ||
top_k=args.top_k, | ||
repetition_penalty=args.repetition_penalty, | ||
max_new_tokens=args.max_new_tokens, | ||
do_sample=args.do_sample, | ||
num_beams=args.num_beams, | ||
num_return_sequences=args.num_return_sequences, | ||
use_cache=args.use_cache, | ||
) | ||
|
||
embedding_model = SentenceTransformer(args.embedding_model) | ||
|
||
try: | ||
llm_endpoint = os.getenv("TGI_LLM_ENDPOINT", "http://localhost:8080") | ||
llm = HuggingFaceEndpoint( | ||
endpoint_url=llm_endpoint, | ||
max_new_tokens=512, | ||
top_k=args.top_k, | ||
top_p=args.top_p, | ||
typical_p=args.typical_p, | ||
temperature=args.temperature, | ||
repetition_penalty=args.repetition_penalty, | ||
streaming=args.streaming, | ||
timeout=600, | ||
) | ||
except: | ||
print("Did not find the llm endpoint service, load model from huggingface hub as instead.") | ||
|
||
try: | ||
if not os.path.exists(output): | ||
os.mkdir(output) | ||
else: | ||
if os.path.exists(os.path.join(output, "raw.jsonl")): | ||
os.remove(os.path.join(output, "raw.jsonl")) | ||
if os.path.exists(os.path.join(output, "minedHN.jsonl")): | ||
os.remove(os.path.join(output, "minedHN.jsonl")) | ||
if os.path.exists(os.path.join(output, "minedHN_split.jsonl")): | ||
os.remove(os.path.join(output, "minedHN_split.jsonl")) | ||
except: | ||
pass | ||
|
||
output_path = os.path.join(output, "raw_query.jsonl") | ||
raw_data_generation(llm, input_path, output_path, generation_config) | ||
|
||
output_hn_path = os.path.join(output, "query_doc.jsonl") | ||
mine_hard_negatives( | ||
embedding_model, | ||
output_path, | ||
output_hn_path, | ||
args.range_for_sampling, | ||
args.negative_number, | ||
) | ||
|
||
output_json_split_path = os.path.join(output, "query_doc_cleaned.jsonl") | ||
similarity_check(output_hn_path, output_json_split_path, embedding_model, args.similarity_threshold) | ||
|
||
output_answer_path = os.path.join(output, "answer.jsonl") | ||
answer_generate(llm, input, output, generation_config) |
Oops, something went wrong.