Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/cog 946 abstract eval dataset #418

Merged
merged 10 commits into from
Jan 14, 2025
37 changes: 11 additions & 26 deletions evals/eval_on_hotpot.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,17 @@
import argparse
import asyncio
import json
import statistics
from pathlib import Path

import deepeval.metrics
import wget
from deepeval.dataset import EvaluationDataset
from deepeval.test_case import LLMTestCase
from tqdm import tqdm

import cognee
import evals.deepeval_metrics
from cognee.api.v1.search import SearchType
from cognee.base_config import get_base_config
from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt
from evals.qa_dataset_utils import load_qa_dataset


async def answer_without_cognee(instance):
Expand All @@ -40,12 +36,8 @@ async def answer_with_cognee(instance):
await cognee.prune.prune_system(metadata=True)

for title, sentences in instance["context"]:
await cognee.add("\n".join(sentences), dataset_name="HotPotQA")

for n in range(1, 4):
print(n)

await cognee.cognify("HotPotQA")
await cognee.add("\n".join(sentences), dataset_name="QA")
await cognee.cognify("QA")

search_results = await cognee.search(SearchType.INSIGHTS, query_text=instance["question"])
search_results_second = await cognee.search(
Expand Down Expand Up @@ -85,20 +77,10 @@ async def eval_answers(instances, answers, eval_metric):
return eval_results


async def eval_on_hotpotQA(answer_provider, num_samples, eval_metric):
base_config = get_base_config()
data_root_dir = base_config.data_root_directory

if not Path(data_root_dir).exists():
Path(data_root_dir).mkdir()

filepath = data_root_dir / Path("hotpot_dev_fullwiki_v1.json")
if not filepath.exists():
url = "http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_dev_fullwiki_v1.json"
wget.download(url, out=data_root_dir)

with open(filepath, "r") as file:
dataset = json.load(file)
async def eval_on_QA_dataset(
dataset_name_or_filename: str, answer_provider, num_samples, eval_metric
):
dataset = load_qa_dataset(dataset_name_or_filename)

instances = dataset if not num_samples else dataset[:num_samples]
answers = []
Expand All @@ -117,6 +99,7 @@ async def eval_on_hotpotQA(answer_provider, num_samples, eval_metric):
if __name__ == "__main__":
parser = argparse.ArgumentParser()

parser.add_argument("--dataset", type=str, help="Which dataset to evaluate on")
parser.add_argument("--with_cognee", action="store_true")
parser.add_argument("--num_samples", type=int, default=500)
parser.add_argument(
Expand All @@ -142,5 +125,7 @@ async def eval_on_hotpotQA(answer_provider, num_samples, eval_metric):
else:
answer_provider = answer_without_cognee

avg_score = asyncio.run(eval_on_hotpotQA(answer_provider, args.num_samples, metric))
avg_score = asyncio.run(
eval_on_QA_dataset(args.dataset, answer_provider, args.num_samples, metric)
)
print(f"Average {args.metric}: {avg_score}")
82 changes: 82 additions & 0 deletions evals/qa_dataset_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from cognee.root_dir import get_absolute_path
import json
import requests
from jsonschema import ValidationError, validate
from pathlib import Path


# Registry of supported QA benchmark datasets.
# Maps a lowercase dataset name to:
#   "filename" - path, relative to the data root, where the dataset lives locally
#   "URL"      - location the dataset can be fetched from
qa_datasets = {
    "hotpotqa": {
        "filename": "hotpot_dev_fullwiki_v1.json",
        "URL": "http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_dev_fullwiki_v1.json",
    },
    # NOTE: 2wikimultihop cannot be fetched automatically (see download_qa_dataset);
    # "filename" points inside the manually unzipped data.zip archive.
    "2wikimultihop": {
        "filename": "data/dev.json",
        "URL": "https://www.dropbox.com/scl/fi/heid2pkiswhfaqr5g0piw/data.zip?rlkey=ira57daau8lxfj022xvk1irju&e=1",
    },
}

# JSON Schema used to sanity-check a loaded dataset: a list of QA records,
# each required to carry at least "answer", "question" and "context".
# Extra per-record fields are allowed (additionalProperties: True), since
# different benchmarks ship different metadata alongside these three keys.
qa_json_schema = {
    "type": "array",
    "items": {
        "type": "object",
        "properties": {
            "answer": {"type": "string"},
            "question": {"type": "string"},
            "context": {"type": "array"},
        },
        "required": ["answer", "question", "context"],
        "additionalProperties": True,
    },
}


def download_qa_dataset(dataset_name: str, filepath: Path) -> None:
    """Download a supported QA dataset to *filepath*.

    Args:
        dataset_name: Key into ``qa_datasets`` (e.g. ``"hotpotqa"``).
        filepath: Destination path for the downloaded file.

    Raises:
        ValueError: If *dataset_name* is not a supported dataset.
        Exception: For ``"2wikimultihop"``, which must be fetched manually.
        RuntimeError: If the HTTP download fails.
    """
    if dataset_name not in qa_datasets:
        raise ValueError(f"{dataset_name} is not a supported dataset.")

    url = qa_datasets[dataset_name]["URL"]

    if dataset_name == "2wikimultihop":
        # The Dropbox link cannot be fetched programmatically; the archive has
        # to be downloaded and unzipped by hand.
        raise Exception(
            "Please download 2wikimultihop dataset (data.zip) manually from "
            "https://www.dropbox.com/scl/fi/heid2pkiswhfaqr5g0piw/data.zip?rlkey=ira57daau8lxfj022xvk1irju&e=1 "
            "and unzip it."
        )

    response = requests.get(url, stream=True)

    if response.status_code == 200:
        # Stream to disk in chunks so large datasets need not fit in memory.
        with open(filepath, "wb") as file:
            for chunk in response.iter_content(chunk_size=8192):
                file.write(chunk)
        print(f"Dataset {dataset_name} downloaded and saved to {filepath}")
    else:
        # Fail loudly: silently printing here previously let the caller fall
        # through to open() on a missing file with a confusing FileNotFoundError.
        raise RuntimeError(
            f"Failed to download {dataset_name}. Status code: {response.status_code}"
        )


def load_qa_dataset(dataset_name_or_filename: str):
    """Load a QA dataset by registered name or from a local JSON file.

    Args:
        dataset_name_or_filename: Either a key of ``qa_datasets`` (the file is
            downloaded into the data root if not already present) or a path to
            a local JSON file.

    Returns:
        The parsed dataset (a list of QA records). The data is checked against
        ``qa_json_schema``; a validation failure is reported on stdout but is
        deliberately non-fatal (best-effort check).

    Raises:
        OSError: If the dataset file cannot be read.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    if dataset_name_or_filename in qa_datasets:
        dataset_name = dataset_name_or_filename
        filename = qa_datasets[dataset_name]["filename"]

        data_root_dir = get_absolute_path("../.data")
        # parents/exist_ok: create the whole directory chain and tolerate a
        # concurrent creation (a plain mkdir() raised if the parent was missing
        # or the directory already existed by the time we got here).
        Path(data_root_dir).mkdir(parents=True, exist_ok=True)

        filepath = Path(data_root_dir) / filename
        if not filepath.exists():
            download_qa_dataset(dataset_name, filepath)
    else:
        # Not a registered name: treat the argument as a path to a local file.
        filepath = Path(dataset_name_or_filename)

    with open(filepath, "r", encoding="utf-8") as file:
        dataset = json.load(file)

    try:
        validate(instance=dataset, schema=qa_json_schema)
    except ValidationError as e:
        # Deliberately non-fatal: warn and return the data anyway.
        print("File is not a valid QA dataset:", e.message)

    return dataset
Loading