Skip to content

Commit

Permalink
QA eval dataset as argument, with hotpot and 2wikimultihop as options…
Browse files Browse the repository at this point in the history
…. JSON schema validation for datasets.
  • Loading branch information
alekszievr committed Jan 8, 2025
1 parent a6dfff8 commit a67512d
Showing 1 changed file with 65 additions and 13 deletions.
78 changes: 65 additions & 13 deletions evals/eval_on_hotpot.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,68 @@
import wget
from deepeval.dataset import EvaluationDataset
from deepeval.test_case import LLMTestCase
from jsonschema import ValidationError, validate
from tqdm import tqdm

import cognee
import evals.deepeval_metrics
from cognee.api.v1.search import SearchType
from cognee.base_config import get_base_config
from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt
from cognee.root_dir import get_absolute_path

# Registry of supported QA benchmark datasets.
# Each entry maps a dataset identifier to the local filename it is stored
# under and the URL it can be fetched from.
qa_datasets = {
    "hotpotqa": {
        "filename": "hotpot_dev_fullwiki_v1.json",
        "URL": "http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_dev_fullwiki_v1.json",
    },
    "2wikimultihop": {
        # NOTE(review): filename includes a subdirectory — presumably the
        # manually-unzipped data.zip layout; confirm against download step.
        "filename": "data/dev.json",
        "URL": "https://www.dropbox.com/scl/fi/heid2pkiswhfaqr5g0piw/data.zip?rlkey=ira57daau8lxfj022xvk1irju&e=1",
    },
}

# JSON Schema for a QA dataset file: a list of records, each of which must
# carry at least a question, its answer, and an array of context passages.
qa_json_schema = {
    "type": "array",
    "items": {
        "type": "object",
        "properties": {
            "answer": {"type": "string"},
            "question": {"type": "string"},
            "context": {"type": "array"},
        },
        "required": ["answer", "question", "context"],
        # Extra per-record fields (ids, metadata, ...) are tolerated.
        "additionalProperties": True,
    },
}


def download_qa_dataset(dataset_name: str, dir: str):
    """Download the raw file of a supported QA dataset into *dir*.

    Args:
        dataset_name: Key into ``qa_datasets`` ("hotpotqa" or "2wikimultihop").
        dir: Target directory the file is written to.

    Raises:
        ValueError: If *dataset_name* is not a supported dataset.
        Exception: For "2wikimultihop", which must be downloaded manually
            from its Dropbox link and unzipped by the user.
    """
    if dataset_name not in qa_datasets:
        raise ValueError(f"{dataset_name} is not a supported dataset.")

    # Guard before the URL lookup: the Dropbox archive cannot be fetched
    # automatically here, so ask the user to download and unzip it manually.
    # (Original message embedded stray line-continuation whitespace; use
    # implicit string concatenation instead.)
    if dataset_name == "2wikimultihop":
        raise Exception(
            "Please download 2wikimultihop dataset (data.zip) manually from "
            "https://www.dropbox.com/scl/fi/heid2pkiswhfaqr5g0piw/data.zip?rlkey=ira57daau8lxfj022xvk1irju&e=1 "
            "and unzip it."
        )

    url = qa_datasets[dataset_name]["URL"]
    wget.download(url, out=dir)


def load_qa_dataset(filepath: Path):
    """Load a QA dataset from a JSON file and validate its structure.

    Args:
        filepath: Path to a JSON file expected to match ``qa_json_schema``
            (a list of records with "question", "answer" and "context").

    Returns:
        The parsed dataset (a list of dict records).

    Raises:
        ValueError: If the file content does not conform to the QA dataset
            schema. The original code only printed the validation error and
            returned the invalid data, which would surface later as confusing
            failures inside the evaluation loop — fail loudly here instead.
    """
    with open(filepath, "r", encoding="utf-8") as file:
        dataset = json.load(file)

    try:
        validate(instance=dataset, schema=qa_json_schema)
    except ValidationError as e:
        raise ValueError(f"File is not a valid QA dataset: {e.message}") from e

    return dataset

async def answer_without_cognee(instance):
args = {
Expand All @@ -39,9 +92,8 @@ async def answer_with_cognee(instance):
await cognee.prune.prune_system(metadata=True)

for (title, sentences) in instance["context"]:
await cognee.add("\n".join(sentences), dataset_name = "HotPotQA")

await cognee.cognify("HotPotQA")
await cognee.add("\n".join(sentences), dataset_name = "QA")
await cognee.cognify("QA")

search_results = await cognee.search(
SearchType.INSIGHTS, query_text=instance["question"]
Expand Down Expand Up @@ -80,20 +132,19 @@ async def eval_answers(instances, answers, eval_metric):

return eval_results

async def eval_on_hotpotQA(answer_provider, num_samples, eval_metric):
base_config = get_base_config()
data_root_dir = base_config.data_root_directory
async def eval_on_QA_dataset(dataset_name: str, answer_provider, num_samples, eval_metric):

data_root_dir = get_absolute_path("../.data")

if not Path(data_root_dir).exists():
Path(data_root_dir).mkdir()

filepath = data_root_dir / Path("hotpot_dev_fullwiki_v1.json")
filename = qa_datasets[dataset_name]["filename"]
filepath = data_root_dir / Path(filename)
if not filepath.exists():
url = 'http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_dev_fullwiki_v1.json'
wget.download(url, out=data_root_dir)
download_qa_dataset(dataset_name, data_root_dir)

with open(filepath, "r") as file:
dataset = json.load(file)
dataset = load_qa_dataset(filepath)

instances = dataset if not num_samples else dataset[:num_samples]
answers = []
Expand All @@ -109,6 +160,7 @@ async def eval_on_hotpotQA(answer_provider, num_samples, eval_metric):
if __name__ == "__main__":
parser = argparse.ArgumentParser()

parser.add_argument("--dataset", type=str, choices=list(qa_datasets.keys()), help="Which dataset to evaluate on")
parser.add_argument("--with_cognee", action="store_true")
parser.add_argument("--num_samples", type=int, default=500)
parser.add_argument("--metric", type=str, default="correctness_metric",
Expand All @@ -130,5 +182,5 @@ async def eval_on_hotpotQA(answer_provider, num_samples, eval_metric):
else:
answer_provider = answer_without_cognee

avg_score = asyncio.run(eval_on_hotpotQA(answer_provider, args.num_samples, metric))
avg_score = asyncio.run(eval_on_QA_dataset(args.dataset, answer_provider, args.num_samples, metric))
print(f"Average {args.metric}: {avg_score}")

0 comments on commit a67512d

Please sign in to comment.