diff --git a/evals/eval_on_hotpot.py b/evals/eval_on_hotpot.py
index c6bb86ba..da102c8e 100644
--- a/evals/eval_on_hotpot.py
+++ b/evals/eval_on_hotpot.py
@@ -9,7 +9,7 @@
 from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt
 from evals.qa_dataset_utils import load_qa_dataset
 from evals.qa_metrics_utils import get_metrics
-from evals.qa_context_provider_utils import qa_context_providers, create_cognee_context_getter
+from evals.qa_context_provider_utils import qa_context_providers, valid_pipeline_slices
 
 logger = logging.getLogger(__name__)
 
@@ -97,7 +97,7 @@ async def eval_on_QA_dataset(
 async def incremental_eval_on_QA_dataset(
     dataset_name_or_filename: str, num_samples, metric_name_list
 ):
-    pipeline_slice_names = ["base", "extract_chunks", "extract_graph", "summarize"]
+    pipeline_slice_names = valid_pipeline_slices.keys()
 
     incremental_results = {}
     for pipeline_slice_name in pipeline_slice_names:
diff --git a/evals/qa_context_provider_utils.py b/evals/qa_context_provider_utils.py
index 0591d7c9..6397d105 100644
--- a/evals/qa_context_provider_utils.py
+++ b/evals/qa_context_provider_utils.py
@@ -5,6 +5,9 @@
 from cognee.tasks.completion.graph_query_completion import retrieved_edges_to_string
 from functools import partial
 from cognee.api.v1.cognify.cognify_v2 import get_default_tasks
+import logging
+
+logger = logging.getLogger(__name__)
 
 
 async def get_raw_context(instance: dict) -> str:
@@ -24,6 +27,34 @@ async def cognify_instance(instance: dict, task_indices: list[int] = None):
     await cognee.cognify("QA", tasks=selected_tasks)
 
 
+def _insight_to_string(triplet: tuple) -> str:
+    if not (isinstance(triplet, tuple) and len(triplet) == 3):
+        logger.warning("Invalid input: Expected a tuple of length 3.")
+        return ""
+
+    node1, edge, node2 = triplet
+
+    if not (isinstance(node1, dict) and isinstance(edge, dict) and isinstance(node2, dict)):
+        logger.warning("Invalid input: Each element in the tuple must be a dictionary.")
+        return ""
+
+    node1_name = node1["name"] if "name" in node1 else "N/A"
+    node1_description = node1["description"] if "description" in node1 else node1["text"]
+    node1_string = f"name: {node1_name}, description: {node1_description}"
+    node2_name = node2["name"] if "name" in node2 else "N/A"
+    node2_description = node2["description"] if "description" in node2 else node2["text"]
+    node2_string = f"name: {node2_name}, description: {node2_description}"
+
+    edge_string = edge.get("relationship_name", "")
+
+    if not edge_string:
+        logger.warning("Missing required field: 'relationship_name' in edge dictionary.")
+        return ""
+
+    triplet_str = f"{node1_string} -- {edge_string} -- {node2_string}"
+    return triplet_str
+
+
 async def get_context_with_cognee(
     instance: dict,
     task_indices: list[int] = None,
@@ -33,9 +64,24 @@
 
     search_results = []
     for search_type in search_types:
-        search_results += await cognee.search(search_type, query_text=instance["question"])
+        raw_search_results = await cognee.search(search_type, query_text=instance["question"])
 
-    search_results_str = "\n".join([context_item["text"] for context_item in search_results])
+        if search_type == SearchType.INSIGHTS:
+            res_list = [_insight_to_string(edge) for edge in raw_search_results]
+        else:
+            res_list = [
+                context_item.get("text", "")
+                for context_item in raw_search_results
+                if isinstance(context_item, dict)
+            ]
+            if all(not text for text in res_list):
+                logger.warning(
+                    "res_list contains only empty strings: No valid 'text' entries found in raw_search_results."
+                )
+
+        search_results += res_list
+
+    search_results_str = "\n".join(search_results)
 
     return search_results_str
 
@@ -47,11 +93,7 @@ def create_cognee_context_getter(
 
 
 async def get_context_with_simple_rag(instance: dict) -> str:
-    await cognee.prune.prune_data()
-    await cognee.prune.prune_system(metadata=True)
-
-    for title, sentences in instance["context"]:
-        await cognee.add("\n".join(sentences), dataset_name="QA")
+    await cognify_instance(instance)
 
     vector_engine = get_vector_engine()
     found_chunks = await vector_engine.search("document_chunk_text", instance["question"], limit=5)
@@ -72,10 +114,14 @@ async def get_context_with_brute_force_triplet_search(instance: dict) -> str:
 
 
 valid_pipeline_slices = {
-    "base": [0, 1, 5],
-    "extract_chunks": [0, 1, 2, 5],
-    "extract_graph": [0, 1, 2, 3, 5],
-    "summarize": [0, 1, 2, 3, 4, 5],
+    "extract_graph": {
+        "slice": [0, 1, 2, 3, 5],
+        "search_types": [SearchType.INSIGHTS, SearchType.CHUNKS],
+    },
+    "summarize": {
+        "slice": [0, 1, 2, 3, 4, 5],
+        "search_types": [SearchType.INSIGHTS, SearchType.SUMMARIES, SearchType.CHUNKS],
+    },
 }
 
 qa_context_providers = {
@@ -84,6 +130,8 @@ async def get_context_with_brute_force_triplet_search(instance: dict) -> str:
     "simple_rag": get_context_with_simple_rag,
     "brute_force": get_context_with_brute_force_triplet_search,
 } | {
-    name: create_cognee_context_getter(task_indices=slice)
-    for name, slice in valid_pipeline_slices.items()
+    name: create_cognee_context_getter(
+        task_indices=value["slice"], search_types=value["search_types"]
+    )
+    for name, value in valid_pipeline_slices.items()
 }
diff --git a/evals/run_qa_eval.py b/evals/run_qa_eval.py
index 97bea184..f9f35d61 100644
--- a/evals/run_qa_eval.py
+++ b/evals/run_qa_eval.py
@@ -1,5 +1,5 @@
 import asyncio
-from evals.eval_on_hotpot import eval_on_QA_dataset
+from evals.eval_on_hotpot import eval_on_QA_dataset, incremental_eval_on_QA_dataset
 from evals.qa_eval_utils import get_combinations, save_results_as_image
 import argparse
 from pathlib import Path
@@ -15,19 +15,26 @@ async def run_evals_on_paramset(paramset: dict, out_path: str):
         num_samples = params["num_samples"]
         rag_option = params["rag_option"]
 
-        result = await eval_on_QA_dataset(
-            dataset,
-            rag_option,
-            num_samples,
-            paramset["metric_names"],
-        )
-
         if dataset not in results:
             results[dataset] = {}
         if num_samples not in results[dataset]:
             results[dataset][num_samples] = {}
 
-        results[dataset][num_samples][rag_option] = result
+        if rag_option == "cognee_incremental":
+            result = await incremental_eval_on_QA_dataset(
+                dataset,
+                num_samples,
+                paramset["metric_names"],
+            )
+            results[dataset][num_samples] |= result
+        else:
+            result = await eval_on_QA_dataset(
+                dataset,
+                rag_option,
+                num_samples,
+                paramset["metric_names"],
+            )
+            results[dataset][num_samples][rag_option] = result
 
         with open(json_path, "w") as file:
             json.dump(results, file, indent=1)
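
Usage sketch (not part of the patch; the triplet below is hypothetical): assuming _insight_to_string is importable from evals.qa_context_provider_utils, it renders an INSIGHTS triplet of node/edge dicts like this:

    from evals.qa_context_provider_utils import _insight_to_string

    # Hypothetical INSIGHTS triplet: (node dict, edge dict, node dict),
    # matching the shapes the helper validates in the hunk above.
    triplet = (
        {"name": "Scott Derrickson", "description": "American film director"},
        {"relationship_name": "directed"},
        {"name": "Doctor Strange", "description": "2016 superhero film"},
    )

    # Prints:
    # name: Scott Derrickson, description: American film director -- directed -- name: Doctor Strange, description: 2016 superhero film
    print(_insight_to_string(triplet))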