From 259414add09de4444da88688dcf88f46ebd151c2 Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Tue, 14 Jan 2025 15:32:27 +0100 Subject: [PATCH 1/2] docs: Update LlamaIndex integration notebook --- .../llama_index_cognee_integration.ipynb | 64 ++++++++----------- 1 file changed, 25 insertions(+), 39 deletions(-) diff --git a/notebooks/llama_index_cognee_integration.ipynb b/notebooks/llama_index_cognee_integration.ipynb index 772c0a8c7..6df6a5980 100644 --- a/notebooks/llama_index_cognee_integration.ipynb +++ b/notebooks/llama_index_cognee_integration.ipynb @@ -1,5 +1,10 @@ { "cells": [ + { + "metadata": {}, + "cell_type": "markdown", + "source": "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1EpokQ8Y_5jIJ7HdixZms81Oqgh2sp7-E?usp=sharing)" + }, { "metadata": {}, "cell_type": "markdown", @@ -45,16 +50,14 @@ "### 1. Setting Up the Environment\n", "\n", "Start by importing the required libraries and defining the environment:" - ], - "id": "d0d7a82d729bbef6" + ] }, { "metadata": {}, "cell_type": "code", "outputs": [], "execution_count": null, - "source": "!pip install llama-index-graph-rag-cognee==0.1.1", - "id": "598b52e384086512" + "source": "!pip install llama-index-graph-rag-cognee==0.1.2" }, { "metadata": {}, @@ -69,8 +72,7 @@ "\n", "if \"OPENAI_API_KEY\" not in os.environ:\n", " os.environ[\"OPENAI_API_KEY\"] = \"\"" - ], - "id": "892a1b1198ec662f" + ] }, { "metadata": {}, @@ -81,8 +83,7 @@ "### 2. Preparing the Dataset\n", "\n", "We’ll use a brief profile of an individual as our sample dataset:" - ], - "id": "a1f16f5ca5249ebb" + ] }, { "metadata": {}, @@ -98,8 +99,7 @@ " text=\"David Thompson, Creative Graphic Designer with over 8 years of experience in visual design and branding.\"\n", " ),\n", " ]" - ], - "id": "198022c34636a3a0" + ] }, { "metadata": {}, @@ -108,8 +108,7 @@ "### 3. Initializing CogneeGraphRAG\n", "\n", "Instantiate the Cognee framework with configurations for LLM, graph, and database providers:" - ], - "id": "781ae78e52ff49a" + ] }, { "metadata": {}, @@ -126,8 +125,7 @@ " relational_db_provider=\"sqlite\",\n", " relational_db_name=\"cognee_db\",\n", ")" - ], - "id": "17e466821ab88d50" + ] }, { "metadata": {}, @@ -136,16 +134,14 @@ "### 4. Adding Data to Cognee\n", "\n", "Load the dataset into the cognee framework:" - ], - "id": "2a55d5be9de0ce81" + ] }, { "metadata": {}, "cell_type": "code", "outputs": [], "execution_count": null, - "source": "await cogneeRAG.add(documents, \"test\")", - "id": "238b716429aba541" + "source": "await cogneeRAG.add(documents, \"test\")" }, { "metadata": {}, @@ -156,16 +152,14 @@ "### 5. Processing Data into a Knowledge Graph\n", "\n", "Transform the data into a structured knowledge graph:" - ], - "id": "23e5316aa7e5dbc7" + ] }, { "metadata": {}, "cell_type": "code", "outputs": [], "execution_count": null, - "source": "await cogneeRAG.process_data(\"test\")", - "id": "c3b3063d428b07a2" + "source": "await cogneeRAG.process_data(\"test\")" }, { "metadata": {}, @@ -176,8 +170,7 @@ "### 6. 
Performing Searches\n", "\n", "### Answer prompt based on knowledge graph approach:" - ], - "id": "e32327de54e98dc8" + ] }, { "metadata": {}, @@ -190,14 +183,12 @@ "print(\"\\n\\nAnswer based on knowledge graph:\\n\")\n", "for result in search_results:\n", " print(f\"{result}\\n\")" - ], - "id": "fddbf5916d1e50e5" + ] }, { "metadata": {}, "cell_type": "markdown", - "source": "### Answer prompt based on RAG approach:", - "id": "9246aed7f69ceb7e" + "source": "### Answer prompt based on RAG approach:" }, { "metadata": {}, @@ -210,14 +201,12 @@ "print(\"\\n\\nAnswer based on RAG:\\n\")\n", "for result in search_results:\n", " print(f\"{result}\\n\")" - ], - "id": "fe77c7a7c57fe4e4" + ] }, { "metadata": {}, "cell_type": "markdown", - "source": "In conclusion, the results demonstrate a significant advantage of the knowledge graph-based approach (Graphrag) over the RAG approach. Graphrag successfully identified all the mentioned individuals across multiple documents, showcasing its ability to aggregate and infer information from a global context. In contrast, the RAG approach was limited to identifying individuals within a single document due to its chunking-based processing constraints. This highlights Graphrag's superior capability in comprehensively resolving queries that span across a broader corpus of interconnected data.", - "id": "89cc99628392eb99" + "source": "In conclusion, the results demonstrate a significant advantage of the knowledge graph-based approach (Graphrag) over the RAG approach. Graphrag successfully identified all the mentioned individuals across multiple documents, showcasing its ability to aggregate and infer information from a global context. In contrast, the RAG approach was limited to identifying individuals within a single document due to its chunking-based processing constraints. This highlights Graphrag's superior capability in comprehensively resolving queries that span across a broader corpus of interconnected data." }, { "metadata": {}, @@ -226,8 +215,7 @@ "### 7. Finding Related Nodes\n", "\n", "Explore relationships in the knowledge graph:" - ], - "id": "44c9b67c09763610" + ] }, { "metadata": {}, @@ -240,8 +228,7 @@ "print(\"\\n\\nRelated nodes are:\\n\")\n", "for node in related_nodes:\n", " print(f\"{node}\\n\")" - ], - "id": "efbc1511586f46fe" + ] }, { "metadata": {}, @@ -274,9 +261,8 @@ "\n", "Try running it yourself\n", "\n", - "Join cognee community" - ], - "id": "d0f82c2c6eb7793" + "[join the cognee community](https://discord.gg/tV7pr5XSj7)" + ] } ], "metadata": {}, From 6653d7355656fd43ab40c87216781a9d4b829b13 Mon Sep 17 00:00:00 2001 From: alekszievr <44192193+alekszievr@users.noreply.github.com> Date: Wed, 15 Jan 2025 10:45:55 +0100 Subject: [PATCH 2/2] Feat/cog 950 improve metric selection (#435) * QA eval dataset as argument, with hotpot and 2wikimultihop as options. Json schema validation for datasets. 

* Load dataset file by filename, outsource utilities

* restructure metric selection

* Add comprehensiveness, diversity and empowerment metrics

* add promptfoo as an option

* refactor RAG solution in eval

* LLM as a judge metrics implemented in a uniform way

* Use requests.get instead of wget

* clean up promptfoo config template

* minor fixes

* get promptfoo path instead of hardcoding

* minor fixes

* Add LLM as a judge prompts

* Minor refactor and logger usage
---
 .../llm/prompts/llm_judge_prompts.py  |   9 +
 evals/deepeval_metrics.py             |  47 +++++-
 evals/eval_on_hotpot.py               |  91 +++++-----
 evals/promptfoo_config_template.yaml  |   7 +
 evals/promptfoo_metrics.py            |  53 ++++++
 evals/promptfoo_wrapper.py            | 157 ++++++++++++++++++
 evals/promptfooprompt.json            |  10 ++
 evals/qa_dataset_utils.py             |   4 +-
 evals/qa_metrics_utils.py             |  51 ++++++
 9 files changed, 375 insertions(+), 54 deletions(-)
 create mode 100644 cognee/infrastructure/llm/prompts/llm_judge_prompts.py
 create mode 100644 evals/promptfoo_config_template.yaml
 create mode 100644 evals/promptfoo_metrics.py
 create mode 100644 evals/promptfoo_wrapper.py
 create mode 100644 evals/promptfooprompt.json
 create mode 100644 evals/qa_metrics_utils.py

diff --git a/cognee/infrastructure/llm/prompts/llm_judge_prompts.py b/cognee/infrastructure/llm/prompts/llm_judge_prompts.py
new file mode 100644
index 000000000..9b94ebdad
--- /dev/null
+++ b/cognee/infrastructure/llm/prompts/llm_judge_prompts.py
@@ -0,0 +1,9 @@
+# LLM-as-a-judge metrics as described here: https://arxiv.org/abs/2404.16130
+
+llm_judge_prompts = {
+    "correctness": "Determine whether the actual output is factually correct based on the expected output.",
+    "comprehensiveness": "Determine how much detail the answer provides to cover all the aspects and details of the question.",
+    "diversity": "Determine how varied and rich the answer is in providing different perspectives and insights on the question.",
+    "empowerment": "Determine how well the answer helps the reader understand and make informed judgements about the topic.",
+    "directness": "Determine how specifically and clearly the answer addresses the question.",
+}
diff --git a/evals/deepeval_metrics.py b/evals/deepeval_metrics.py
index 9ce1e9e4f..51d6c9181 100644
--- a/evals/deepeval_metrics.py
+++ b/evals/deepeval_metrics.py
@@ -2,14 +2,57 @@
 from deepeval.test_case import LLMTestCase, LLMTestCaseParams
 
 from evals.official_hotpot_metrics import exact_match_score, f1_score
+from cognee.infrastructure.llm.prompts.llm_judge_prompts import llm_judge_prompts
 
 correctness_metric = GEval(
     name="Correctness",
     model="gpt-4o-mini",
     evaluation_params=[LLMTestCaseParams.ACTUAL_OUTPUT, LLMTestCaseParams.EXPECTED_OUTPUT],
-    evaluation_steps=[
-        "Determine whether the actual output is factually correct based on the expected output."
+ evaluation_steps=[llm_judge_prompts["correctness"]], +) + +comprehensiveness_metric = GEval( + name="Comprehensiveness", + model="gpt-4o-mini", + evaluation_params=[ + LLMTestCaseParams.INPUT, + LLMTestCaseParams.ACTUAL_OUTPUT, + LLMTestCaseParams.EXPECTED_OUTPUT, + ], + evaluation_steps=[llm_judge_prompts["comprehensiveness"]], +) + +diversity_metric = GEval( + name="Diversity", + model="gpt-4o-mini", + evaluation_params=[ + LLMTestCaseParams.INPUT, + LLMTestCaseParams.ACTUAL_OUTPUT, + LLMTestCaseParams.EXPECTED_OUTPUT, + ], + evaluation_steps=[llm_judge_prompts["diversity"]], +) + +empowerment_metric = GEval( + name="Empowerment", + model="gpt-4o-mini", + evaluation_params=[ + LLMTestCaseParams.INPUT, + LLMTestCaseParams.ACTUAL_OUTPUT, + LLMTestCaseParams.EXPECTED_OUTPUT, + ], + evaluation_steps=[llm_judge_prompts["empowerment"]], +) + +directness_metric = GEval( + name="Directness", + model="gpt-4o-mini", + evaluation_params=[ + LLMTestCaseParams.INPUT, + LLMTestCaseParams.ACTUAL_OUTPUT, + LLMTestCaseParams.EXPECTED_OUTPUT, ], + evaluation_steps=[llm_judge_prompts["directness"]], ) diff --git a/evals/eval_on_hotpot.py b/evals/eval_on_hotpot.py index ee2435e6b..54dcaffd0 100644 --- a/evals/eval_on_hotpot.py +++ b/evals/eval_on_hotpot.py @@ -1,37 +1,21 @@ import argparse import asyncio import statistics -import deepeval.metrics from deepeval.dataset import EvaluationDataset from deepeval.test_case import LLMTestCase from tqdm import tqdm - +import logging import cognee -import evals.deepeval_metrics from cognee.api.v1.search import SearchType from cognee.infrastructure.llm.get_llm_client import get_llm_client from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt from evals.qa_dataset_utils import load_qa_dataset +from evals.qa_metrics_utils import get_metric - -async def answer_without_cognee(instance): - args = { - "question": instance["question"], - "context": instance["context"], - } - user_prompt = render_prompt("context_for_question.txt", args) - system_prompt = read_query_prompt("answer_hotpot_question.txt") - - llm_client = get_llm_client() - answer_prediction = await llm_client.acreate_structured_output( - text_input=user_prompt, - system_prompt=system_prompt, - response_model=str, - ) - return answer_prediction +logger = logging.getLogger(__name__) -async def answer_with_cognee(instance): +async def get_context_with_cognee(instance): await cognee.prune.prune_data() await cognee.prune.prune_system(metadata=True) @@ -45,9 +29,21 @@ async def answer_with_cognee(instance): ) search_results = search_results + search_results_second + search_results_str = "\n".join([context_item["text"] for context_item in search_results]) + + return search_results_str + + +async def get_context_without_cognee(instance): + return instance["context"] + + +async def answer_qa_instance(instance, context_provider): + context = await context_provider(instance) + args = { "question": instance["question"], - "context": search_results, + "context": context, } user_prompt = render_prompt("context_for_question.txt", args) system_prompt = read_query_prompt("answer_hotpot_using_cognee_search.txt") @@ -62,7 +58,7 @@ async def answer_with_cognee(instance): return answer_prediction -async def eval_answers(instances, answers, eval_metric): +async def deepeval_answers(instances, answers, eval_metric): test_cases = [] for instance, answer in zip(instances, answers): @@ -77,18 +73,13 @@ async def eval_answers(instances, answers, eval_metric): return eval_results -async def eval_on_QA_dataset( - 
dataset_name_or_filename: str, answer_provider, num_samples, eval_metric -): - dataset = load_qa_dataset(dataset_name_or_filename) - - instances = dataset if not num_samples else dataset[:num_samples] +async def deepeval_on_instances(instances, context_provider, eval_metric): answers = [] for instance in tqdm(instances, desc="Getting answers"): - answer = await answer_provider(instance) + answer = await answer_qa_instance(instance, context_provider) answers.append(answer) - eval_results = await eval_answers(instances, answers, eval_metric) + eval_results = await deepeval_answers(instances, answers, eval_metric) avg_score = statistics.mean( [result.metrics_data[0].score for result in eval_results.test_results] ) @@ -96,36 +87,36 @@ async def eval_on_QA_dataset( return avg_score +async def eval_on_QA_dataset( + dataset_name_or_filename: str, context_provider, num_samples, eval_metric_name +): + dataset = load_qa_dataset(dataset_name_or_filename) + + eval_metric = get_metric(eval_metric_name) + instances = dataset if not num_samples else dataset[:num_samples] + + if eval_metric_name.startswith("promptfoo"): + return await eval_metric.measure(instances, context_provider) + else: + return await deepeval_on_instances(instances, context_provider, eval_metric) + + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--dataset", type=str, help="Which dataset to evaluate on") + parser.add_argument("--dataset", type=str, required=True, help="Which dataset to evaluate on") parser.add_argument("--with_cognee", action="store_true") parser.add_argument("--num_samples", type=int, default=500) - parser.add_argument( - "--metric", - type=str, - default="correctness_metric", - help="Valid options are Deepeval metrics (e.g. AnswerRelevancyMetric) \ - and metrics defined in evals/deepeval_metrics.py, e.g. 
f1_score_metric", - ) + parser.add_argument("--metric_name", type=str, default="Correctness") args = parser.parse_args() - try: - metric_cls = getattr(deepeval.metrics, args.metric) - metric = metric_cls() - except AttributeError: - metric = getattr(evals.deepeval_metrics, args.metric) - if isinstance(metric, type): - metric = metric() - if args.with_cognee: - answer_provider = answer_with_cognee + context_provider = get_context_with_cognee else: - answer_provider = answer_without_cognee + context_provider = get_context_without_cognee avg_score = asyncio.run( - eval_on_QA_dataset(args.dataset, answer_provider, args.num_samples, metric) + eval_on_QA_dataset(args.dataset, context_provider, args.num_samples, args.metric_name) ) - print(f"Average {args.metric}: {avg_score}") + logger.info(f"Average {args.metric_name}: {avg_score}") diff --git a/evals/promptfoo_config_template.yaml b/evals/promptfoo_config_template.yaml new file mode 100644 index 000000000..f2201fca2 --- /dev/null +++ b/evals/promptfoo_config_template.yaml @@ -0,0 +1,7 @@ +# yaml-language-server: $schema=https://promptfoo.dev/config-schema.json + +# Learn more about building a configuration: https://promptfoo.dev/docs/configuration/guide + +description: "My eval" +providers: + - id: openai:gpt-4o-mini diff --git a/evals/promptfoo_metrics.py b/evals/promptfoo_metrics.py new file mode 100644 index 000000000..addd0030a --- /dev/null +++ b/evals/promptfoo_metrics.py @@ -0,0 +1,53 @@ +from evals.promptfoo_wrapper import PromptfooWrapper +import os +import yaml +import json +import shutil + + +class PromptfooMetric: + def __init__(self, judge_prompt): + promptfoo_path = shutil.which("promptfoo") + self.wrapper = PromptfooWrapper(promptfoo_path=promptfoo_path) + self.judge_prompt = judge_prompt + + async def measure(self, instances, context_provider): + with open(os.path.join(os.getcwd(), "evals/promptfoo_config_template.yaml"), "r") as file: + config = yaml.safe_load(file) + + config["defaultTest"] = [{"assert": {"type": "llm_rubric", "value": self.judge_prompt}}] + + # Fill config file with test cases + tests = [] + for instance in instances: + context = await context_provider(instance) + test = { + "vars": { + "name": instance["question"][:15], + "question": instance["question"], + "context": context, + } + } + tests.append(test) + config["tests"] = tests + + # Write the updated YAML back, preserving formatting and structure + updated_yaml_file_path = os.path.join(os.getcwd(), "config_with_context.yaml") + with open(updated_yaml_file_path, "w") as file: + yaml.dump(config, file) + + self.wrapper.run_eval( + prompt_file=os.path.join(os.getcwd(), "evals/promptfooprompt.json"), + config_file=os.path.join(os.getcwd(), "config_with_context.yaml"), + out_format="json", + ) + + file_path = os.path.join(os.getcwd(), "benchmark_results.json") + + # Read and parse the JSON file + with open(file_path, "r") as file: + results = json.load(file) + + self.score = results["results"]["prompts"][0]["metrics"]["score"] + + return self.score diff --git a/evals/promptfoo_wrapper.py b/evals/promptfoo_wrapper.py new file mode 100644 index 000000000..97a03bbf8 --- /dev/null +++ b/evals/promptfoo_wrapper.py @@ -0,0 +1,157 @@ +import subprocess +import json +import logging +import os +from typing import List, Optional, Dict, Generator +import shutil +import platform +from dotenv import load_dotenv + +logger = logging.getLogger(__name__) + +# Load environment variables from .env file +load_dotenv() + + +class PromptfooWrapper: + """ + A Python wrapper class 
around the promptfoo CLI tool, allowing you to: + - Evaluate prompts against different language models. + - Compare responses from multiple models. + - Pass configuration and prompt files. + - Retrieve the outputs in a structured format, including binary output if needed. + + This class assumes you have the promptfoo CLI installed and accessible in your environment. + For more details on promptfoo, see: https://github.com/promptfoo/promptfoo + """ + + def __init__(self, promptfoo_path: str = ""): + """ + Initialize the wrapper with the path to the promptfoo executable. + + :param promptfoo_path: Path to the promptfoo binary (default: 'promptfoo') + """ + self.promptfoo_path = promptfoo_path + logger.debug(f"Initialized PromptfooWrapper with binary at: {self.promptfoo_path}") + + def _validate_path(self, file_path: Optional[str]) -> None: + """ + Validate that a file path is accessible if provided. + Raise FileNotFoundError if it does not exist. + """ + if file_path and not os.path.isfile(file_path): + logger.error(f"File not found: {file_path}") + raise FileNotFoundError(f"File not found: {file_path}") + + def _get_node_bin_dir(self) -> str: + """ + Determine the Node.js binary directory dynamically for macOS and Linux. + """ + node_executable = shutil.which("node") + if not node_executable: + logger.error("Node.js is not installed or not found in the system PATH.") + raise EnvironmentError("Node.js is not installed or not in PATH.") + + # Determine the Node.js binary directory + node_bin_dir = os.path.dirname(node_executable) + + # Special handling for macOS, where Homebrew installs Node in /usr/local or /opt/homebrew + if platform.system() == "Darwin": # macOS + logger.debug("Running on macOS") + brew_prefix = os.popen("brew --prefix node").read().strip() + if brew_prefix and os.path.exists(brew_prefix): + node_bin_dir = os.path.join(brew_prefix, "bin") + logger.debug(f"Detected Node.js binary directory using Homebrew: {node_bin_dir}") + + # For Linux, Node.js installed via package managers should work out of the box + logger.debug(f"Detected Node.js binary directory: {node_bin_dir}") + return node_bin_dir + + def _run_command( + self, + cmd: List[str], + filename, + ) -> Generator[Dict, None, None]: + """ + Run a given command using subprocess and parse the output. 
+ """ + logger.debug(f"Running command: {' '.join(cmd)}") + + # Make a copy of the current environment + env = os.environ.copy() + + try: + node_bin_dir = self._get_node_bin_dir() + print(node_bin_dir) + env["PATH"] = f"{node_bin_dir}:{env['PATH']}" + + except EnvironmentError as e: + logger.error(f"Failed to set Node.js binary directory: {e}") + raise + + # Add node's bin directory to the PATH + # node_bin_dir = "/Users/vasilije/Library/Application Support/JetBrains/PyCharm2024.2/node/versions/20.15.0/bin" + # # env["PATH"] = f"{node_bin_dir}:{env['PATH']}" + + result = subprocess.run(cmd, capture_output=True, text=True, check=False, env=env) + + print(result.stderr) + with open(filename, "r", encoding="utf-8") as file: + read_data = json.load(file) + print(f"{filename} created and written.") + + # Log raw stdout for debugging + logger.debug(f"Raw command output:\n{result.stdout}") + + # Use the parse_promptfoo_output function to yield parsed results + return read_data + + def run_eval( + self, + prompt_file: Optional[str] = None, + config_file: Optional[str] = None, + eval_file: Optional[str] = None, + out_format: str = "json", + extra_args: Optional[List[str]] = None, + binary_output: bool = False, + ) -> Dict: + """ + Run the `promptfoo eval` command with the provided parameters and return parsed results. + + :param prompt_file: Path to a file containing one or more prompts. + :param config_file: Path to a config file specifying models, scoring methods, etc. + :param eval_file: Path to an eval file with test data. + :param out_format: Output format, e.g., 'json', 'yaml', or 'table'. + :param extra_args: Additional command-line arguments for fine-tuning evaluation. + :param binary_output: If True, interpret output as binary data instead of text. + :return: List of parsed results (each result is a dictionary). + """ + self._validate_path(prompt_file) + self._validate_path(config_file) + self._validate_path(eval_file) + + filename = "benchmark_results" + + filename = os.path.join(os.getcwd(), f"{filename}.json") + # Create an empty JSON file + with open(filename, "w") as file: + json.dump({}, file) + + cmd = [self.promptfoo_path, "eval"] + if prompt_file: + cmd.extend(["--prompts", prompt_file]) + if config_file: + cmd.extend(["--config", config_file]) + if eval_file: + cmd.extend(["--eval", eval_file]) + cmd.extend(["--output", filename]) + if extra_args: + cmd.extend(extra_args) + + # Log the constructed command for debugging + logger.debug(f"Constructed command: {' '.join(cmd)}") + + # Collect results from the generator + results = self._run_command(cmd, filename=filename) + logger.debug(f"Parsed results: {json.dumps(results, indent=4)}") + return results diff --git a/evals/promptfooprompt.json b/evals/promptfooprompt.json new file mode 100644 index 000000000..fb6351406 --- /dev/null +++ b/evals/promptfooprompt.json @@ -0,0 +1,10 @@ +[ + { + "role": "system", + "content": "Answer the question using the provided context. Be as brief as possible." + }, + { + "role": "user", + "content": "The question is: `{{ question }}` \n And here is the context: `{{ context }}`" + } +] diff --git a/evals/qa_dataset_utils.py b/evals/qa_dataset_utils.py index c570455c4..ac97a180c 100644 --- a/evals/qa_dataset_utils.py +++ b/evals/qa_dataset_utils.py @@ -55,7 +55,7 @@ def download_qa_dataset(dataset_name: str, filepath: Path): print(f"Failed to download {dataset_name}. 
Status code: {response.status_code}") -def load_qa_dataset(dataset_name_or_filename: str): +def load_qa_dataset(dataset_name_or_filename: str) -> list[dict]: if dataset_name_or_filename in qa_datasets: dataset_name = dataset_name_or_filename filename = qa_datasets[dataset_name]["filename"] @@ -77,6 +77,6 @@ def load_qa_dataset(dataset_name_or_filename: str): try: validate(instance=dataset, schema=qa_json_schema) except ValidationError as e: - print("File is not a valid QA dataset:", e.message) + raise ValidationError(f"Invalid QA dataset: {e.message}") return dataset diff --git a/evals/qa_metrics_utils.py b/evals/qa_metrics_utils.py new file mode 100644 index 000000000..107fe429d --- /dev/null +++ b/evals/qa_metrics_utils.py @@ -0,0 +1,51 @@ +from evals.deepeval_metrics import ( + correctness_metric, + comprehensiveness_metric, + diversity_metric, + empowerment_metric, + directness_metric, + f1_score_metric, + em_score_metric, +) +from evals.promptfoo_metrics import PromptfooMetric +from deepeval.metrics import AnswerRelevancyMetric +import deepeval.metrics +from cognee.infrastructure.llm.prompts.llm_judge_prompts import llm_judge_prompts + +native_deepeval_metrics = {"AnswerRelevancy": AnswerRelevancyMetric} + +custom_deepeval_metrics = { + "Correctness": correctness_metric, + "Comprehensiveness": comprehensiveness_metric, + "Diversity": diversity_metric, + "Empowerment": empowerment_metric, + "Directness": directness_metric, + "F1": f1_score_metric, + "EM": em_score_metric, +} + +promptfoo_metrics = { + "promptfoo.correctness": PromptfooMetric(llm_judge_prompts["correctness"]), + "promptfoo.comprehensiveness": PromptfooMetric(llm_judge_prompts["comprehensiveness"]), + "promptfoo.diversity": PromptfooMetric(llm_judge_prompts["diversity"]), + "promptfoo.empowerment": PromptfooMetric(llm_judge_prompts["empowerment"]), + "promptfoo.directness": PromptfooMetric(llm_judge_prompts["directness"]), +} + +qa_metrics = native_deepeval_metrics | custom_deepeval_metrics | promptfoo_metrics + + +def get_metric(metric_name: str): + if metric_name in qa_metrics: + metric = qa_metrics[metric_name] + else: + try: + metric_cls = getattr(deepeval.metrics, metric_name) + metric = metric_cls() + except AttributeError: + raise Exception(f"Metric {metric_name} not supported") + + if isinstance(metric, type): + metric = metric() + + return metric
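
Usage sketch (illustrative): the new metric selection can also be driven from Python rather than through the CLI entry point in eval_on_hotpot.py. The snippet below is a minimal sketch, assuming the evals package is importable, OPENAI_API_KEY is set, and that "hotpot_qa" is a valid dataset key for load_qa_dataset(); the dataset name and sample count are assumptions, not taken from this patch.

    import asyncio

    from evals.eval_on_hotpot import eval_on_QA_dataset, get_context_without_cognee

    # "Correctness" resolves via get_metric() to the custom deepeval GEval metric;
    # a "promptfoo.*" name would route to the PromptfooMetric wrapper instead.
    avg_score = asyncio.run(
        eval_on_QA_dataset(
            "hotpot_qa",                 # assumed dataset key
            get_context_without_cognee,  # use the dataset's own context, no cognee graph
            num_samples=5,
            eval_metric_name="Correctness",
        )
    )
    print(f"Average Correctness over 5 samples: {avg_score}")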