From acfbcadffc2ee79cfb358cf7c95119802b24f398 Mon Sep 17 00:00:00 2001
From: BrianShen <96436972+brianshen3@users.noreply.github.com>
Date: Thu, 16 Jan 2025 09:14:34 -0600
Subject: [PATCH] feat: added logic to handle labeled answers (#742)

* added logic to handle labeled answers

* metrics fix

* updated is_page logic to be pair of doc + page_num. Added new versioning

* added notes for rules on what is correct
---
Review notes. These are not part of the commit message: text between the
three-dash line and the first "diff --git" header is skipped when the patch
is applied.

NOTE(review): the line structure of this patch was rebuilt from a
whitespace-collapsed copy. Indentation and the position of blank context
lines were chosen to agree with the hunk line counts and the diffstat;
confirm them against the repository before applying.

  1. pyproject.toml: raises the kolena lower bound from 1.50.0 to 1.51.1.
     Presumably required by the new include_extracted_properties argument
     used in upload_results.py; confirm against the kolena changelog.

  2. is_doc_retrieved: now takes a list of names and returns True as soon
     as any name is a substring of any retrieved locator. An empty
     doc_names list therefore yields False.

  3. is_page_retrieved: the rule changes from "every relevant page was
     retrieved" (set.issubset) to "at least one (doc_name, page_number)
     pair was retrieved". The doc_names parameter is kept in the signature
     but the new body never reads it. Names are compared by exact equality
     on the last "/" component of the locator with every ".pdf" occurrence
     removed, whereas is_doc_retrieved uses a substring test, so the two
     metrics can disagree about the same document.

  4. extract_doc_names / extract_relevant_pages: return [] when
     labeling_task is None, and treat a missing or falsy
     "retrieved_contents" value as empty. Items are read through attribute
     access (.locator, .page_number), which assumes objects rather than
     plain dicts -- TODO confirm. Only None is guarded; any other
     non-dict placeholder for a missing cell would fail at .get() --
     verify what download_dataset returns.

  5. compute_metrics: assigns "doc_names" and "relevant_pages" directly on
     the df_dataset passed in by the caller. In the unlabeled branch the
     existing "relevant_pages" column is overwritten with
     (doc_name, page) tuples, so running the function twice on the same
     frame would wrap the tuples again. "doc_names" is already an element
     of ground_truth_columns, so columns_to_merge lists it twice.

  6. upload_results.py: the pd.read_json call is only re-wrapped across
     lines; the functional change is passing
     include_extracted_properties=True to download_dataset, presumably the
     source of the "labeling_task" column -- confirm.

 .../pyproject.toml                            |  2 +-
 .../retrieval_augmented_generation/metrics.py | 55 +++++++++++++++----
 .../upload_results.py                         |  7 ++-
 3 files changed, 50 insertions(+), 14 deletions(-)

diff --git a/examples/dataset/retrieval_augmented_generation/pyproject.toml b/examples/dataset/retrieval_augmented_generation/pyproject.toml
index 7ccaa0b8f..5cb0c576b 100644
--- a/examples/dataset/retrieval_augmented_generation/pyproject.toml
+++ b/examples/dataset/retrieval_augmented_generation/pyproject.toml
@@ -9,7 +9,7 @@ license = "Apache-2.0"
 requires-python = ">=3.8,<3.12"
 
 dependencies = [
-    "kolena>=1.50.0,<2",
+    "kolena>=1.51.1,<2",
     "s3fs>=2024.10.0",
 ]
 
diff --git a/examples/dataset/retrieval_augmented_generation/retrieval_augmented_generation/metrics.py b/examples/dataset/retrieval_augmented_generation/retrieval_augmented_generation/metrics.py
index 3e1614c1f..486b6491f 100644
--- a/examples/dataset/retrieval_augmented_generation/retrieval_augmented_generation/metrics.py
+++ b/examples/dataset/retrieval_augmented_generation/retrieval_augmented_generation/metrics.py
@@ -15,31 +15,64 @@ from retrieval_augmented_generation.constants import ID_FIELDS
 
 
 
-def is_doc_retrieved(retrieved_contents: list, doc_name: str) -> bool:
-    return any([doc_name in content.locator for content in retrieved_contents])
+def is_doc_retrieved(retrieved_contents: list, doc_names: list) -> bool:
+    # NOTE: if one doc is retrieved, it is considered correct
+    return any(any(doc_name in content.locator for content in retrieved_contents) for doc_name in doc_names)
 
 
-def is_page_retrieved(retrieved_contents: list, doc_name: str, relevant_pages: list) -> bool:
-    retrieved_pages = [content.page_number for content in retrieved_contents if doc_name in content.locator]
+def is_page_retrieved(retrieved_contents: list, doc_names: list, relevant_pages: list) -> bool:
+    # NOTE: if one page of any doc is retrieved, it is considered correct
+    retrieved_pairs = [
+        (content.locator.split("/")[-1].replace(".pdf", ""), content.page_number) for content in retrieved_contents
+    ]
+    return any(pair in retrieved_pairs for pair in relevant_pages)
 
-    # NOTE: all relevant pages must be retrieved to be considered correct.
-    return set(relevant_pages).issubset(retrieved_pages)
+
+def extract_doc_names(labeling_task: dict) -> list:
+    if labeling_task is None:
+        return []
+    return [
+        content.locator.split("/")[-1].replace(".pdf", "")
+        for content in labeling_task.get("retrieved_contents", []) or []
+    ]
+
+
+def extract_relevant_pages(labeling_task: dict) -> list:
+    if labeling_task is None:
+        return []
+    contents = labeling_task.get("retrieved_contents", []) or []
+    # Create pairs of (doc_name, page_number)
+    return [(content.locator.split("/")[-1].replace(".pdf", ""), content.page_number) for content in contents]
 
 
 def compute_metrics(df_dataset: pd.DataFrame, df_results: pd.DataFrame) -> pd.DataFrame:
-    ground_truth_columns = ["doc_name", "relevant_pages", "financebench_id"]
+    ground_truth_columns = ["doc_names", "relevant_pages", "financebench_id"]
+
+    is_labeled = "labeling_task" in df_dataset.columns
+    if is_labeled:
+        # Extract document names and page numbers from labeling task
+        df_dataset["doc_names"] = df_dataset["labeling_task"].apply(extract_doc_names)
+        df_dataset["relevant_pages"] = df_dataset["labeling_task"].apply(extract_relevant_pages)
+    else:
+        df_dataset["doc_names"] = df_dataset["doc_name"].apply(lambda x: [x])
+        # Create pairs of (doc_name, page_number) for non-labeled data
+        df_dataset["relevant_pages"] = df_dataset.apply(
+            lambda row: [(row["doc_name"], page) for page in row["relevant_pages"]],
+            axis=1,
+        )
     assert set(ground_truth_columns).issubset(
         df_dataset.columns,
     ), f"ground truth columns {ground_truth_columns} cannot be found in dataset dataframe."
-
-    df = df_results.merge(df_dataset[ground_truth_columns], on=ID_FIELDS, how="left")
+    # Include doc_names in the merge
+    columns_to_merge = ground_truth_columns + ["doc_names"]
+    df = df_results.merge(df_dataset[columns_to_merge], on=ID_FIELDS, how="left")
 
     metrics = []
     for record in df.itertuples():
         metrics.append(
             dict(
-                is_doc_retrieved=is_doc_retrieved(record.retrieved_contents, record.doc_name),
-                is_page_retrieved=is_page_retrieved(record.retrieved_contents, record.doc_name, record.relevant_pages),
+                is_doc_retrieved=is_doc_retrieved(record.retrieved_contents, record.doc_names),
+                is_page_retrieved=is_page_retrieved(record.retrieved_contents, record.doc_names, record.relevant_pages),
             ),
         )
 
diff --git a/examples/dataset/retrieval_augmented_generation/retrieval_augmented_generation/upload_results.py b/examples/dataset/retrieval_augmented_generation/retrieval_augmented_generation/upload_results.py
index dba90d9ef..f5f24002d 100644
--- a/examples/dataset/retrieval_augmented_generation/retrieval_augmented_generation/upload_results.py
+++ b/examples/dataset/retrieval_augmented_generation/retrieval_augmented_generation/upload_results.py
@@ -45,10 +45,13 @@ def to_documents(retrieved_contents: list[dict[str, Any]]) -> list:
 
 def run(args: Namespace) -> None:
     model_name = MODEL_NAME[args.model]
-    df_results = pd.read_json(f"{S3_BUCKET}/{DATASET}/results/raw/{model_name}.jsonl", lines=True)
+    df_results = pd.read_json(
+        f"{S3_BUCKET}/{DATASET}/results/raw/{model_name}.jsonl",
+        lines=True,
+    )
     df_results["retrieved_contents"] = df_results["retrieved_contents"].apply(to_documents)
     if args.evaluate:
-        df_dataset = download_dataset(args.dataset_name)
+        df_dataset = download_dataset(args.dataset_name, include_extracted_properties=True)
         df_metrics = compute_metrics(df_dataset, df_results)
         df_results = pd.concat([df_results, df_metrics], axis=1)
     upload_results(args.dataset_name, model_name, df_results)